diff --git a/backend/tests/test_service.py b/backend/tests/test_service.py index 08974b1..0248e18 100644 --- a/backend/tests/test_service.py +++ b/backend/tests/test_service.py @@ -4,7 +4,7 @@ from collections.abc import Generator from contextlib import contextmanager from datetime import datetime, timedelta, timezone from types import SimpleNamespace -from unittest.mock import ANY, patch +from unittest.mock import ANY, MagicMock, patch import pytest from fastapi import FastAPI, status @@ -12,14 +12,17 @@ from fastapi.requests import Request from fastapi.responses import Response from settings import settings -from trading_journal import service +from trading_journal import dto, service +from trading_journal.crud import Session # --- Auth middleware --------------------------------------------------------- class FakeDBFactory: @contextmanager - def get_session_ctx_manager(self) -> Generator[SimpleNamespace, None, None]: - yield SimpleNamespace(name="fakesession") + def get_session_ctx_manager(self) -> Generator[Session, None, None]: + fake_session = MagicMock(spec=Session) + fake_session.name = "FakeDBSession" + yield fake_session def verify_json_response(response: Response, expected_status: int, expected_detail: str) -> None: @@ -189,7 +192,7 @@ def test_auth_middleware_reject_inactive_user() -> None: ) with ( - patch("trading_journal.security.hash_session_token_sha256", return_value=valid_session.session_token_hash) as mock_hash, + patch("trading_journal.security.hash_session_token_sha256", return_value=valid_session.session_token_hash), patch("trading_journal.crud.get_login_session_by_token_hash", return_value=valid_session), ): response = asyncio.run(middleware.dispatch(request, call_next)) @@ -197,57 +200,902 @@ def test_auth_middleware_reject_inactive_user() -> None: verify_json_response(response, status.HTTP_401_UNAUTHORIZED, "Unauthorized") +def test_auth_middleware_allows_valid_token_and_no_update_expires() -> None: + app = FastAPI() + app.state.db_factory = FakeDBFactory() + middleware = service.AuthMiddleWare(app) + fake_token_orig = "validtoken" + + scope = { + "type": "http", + "method": "GET", + "path": f"/{settings.api_base}/protected", + "headers": [(b"cookie", f"session_token={fake_token_orig}".encode()), (b"user-agent", b"test-agent")], + "client": ("testclient", 50000), + "app": app, + } + request = Request(scope) + + async def call_next(req: Request, expected: Request = request) -> Response: + assert req is expected + assert hasattr(req.state, "user_id") + assert req.state.user_id == 1 + return Response(status_code=status.HTTP_204_NO_CONTENT) + + active_user = SimpleNamespace( + id=1, + username="activeuser", + is_active=True, + ) + valid_session = SimpleNamespace( + id=1, + user_id=1, + session_token_hash="validtokenhash", + expires_at=(datetime.now(timezone.utc) + timedelta(days=1)), + user=active_user, + ) + + with ( + patch("trading_journal.security.hash_session_token_sha256", return_value=valid_session.session_token_hash), + patch("trading_journal.crud.get_login_session_by_token_hash", return_value=valid_session), + patch("trading_journal.crud.update_login_session") as mock_update, + ): + response = asyncio.run(middleware.dispatch(request, call_next)) + assert response.status_code == status.HTTP_204_NO_CONTENT + mock_update.assert_called_once() + _, kwargs = mock_update.call_args + update_session = kwargs.get("update_session") + assert update_session is not None + assert update_session.expires_at == valid_session.expires_at + + +def test_auth_middleware_allows_valid_token_and_updates_expires() -> None: + app = FastAPI() + app.state.db_factory = FakeDBFactory() + middleware = service.AuthMiddleWare(app) + fake_token_orig = "validtoken" + + scope = { + "type": "http", + "method": "GET", + "path": f"/{settings.api_base}/protected", + "headers": [(b"cookie", f"session_token={fake_token_orig}".encode()), (b"user-agent", b"test-agent")], + "client": ("testclient", 50000), + "app": app, + } + request = Request(scope) + + async def call_next(req: Request, expected: Request = request) -> Response: + assert req is expected + assert hasattr(req.state, "user_id") + assert req.state.user_id == 1 + return Response(status_code=status.HTTP_204_NO_CONTENT) + + active_user = SimpleNamespace( + id=1, + username="activeuser", + is_active=True, + ) + valid_session = SimpleNamespace( + id=1, + user_id=1, + session_token_hash="validtokenhash", + expires_at=(datetime.now(timezone.utc) + timedelta(minutes=10)), + user=active_user, + ) + + with ( + patch("trading_journal.security.hash_session_token_sha256", return_value=valid_session.session_token_hash), + patch("trading_journal.crud.get_login_session_by_token_hash", return_value=valid_session), + patch("trading_journal.crud.update_login_session") as mock_update, + ): + response = asyncio.run(middleware.dispatch(request, call_next)) + assert response.status_code == status.HTTP_204_NO_CONTENT + mock_update.assert_called_once() + _, kwargs = mock_update.call_args + update_session = kwargs.get("update_session") + assert update_session is not None + assert (update_session.expires_at - datetime.now(timezone.utc)).total_seconds() > settings.session_expiry_seconds - 1 + assert (update_session.last_seen_at - datetime.now(timezone.utc)).total_seconds() < 1 + assert update_session.last_used_ip == "testclient" + assert update_session.user_agent == "test-agent" + + # --- User services ----------------------------------------------------------- -def test_register_user_success(): - pytest.fail("TODO: mock crud/security, assert UserRead username") +def test_register_user_success() -> None: + user_in = dto.UserCreate(username="newuser", password="newpassword") + user_in_with_hashed_password = { + "username": user_in.username, + "password_hash": "hashednewpassword", + } + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_user_by_username", return_value=None) as mock_get, + patch( + "trading_journal.crud.create_user", + return_value=SimpleNamespace(id=1, username=user_in.username, is_active=True), + ) as mock_create, + patch("trading_journal.security.hash_password", return_value=user_in_with_hashed_password["password_hash"]), + ): + user_out = service.register_user_service(db, user_in) + assert user_out.id is not None + assert user_out.username == user_in.username + mock_get.assert_called_once_with(db, user_in.username) + mock_create.assert_called_once_with(db, user_data=user_in_with_hashed_password) -def test_register_user_exists_raises(): - pytest.fail("TODO: mock get_user_by_username to return obj and expect UserAlreadyExistsError") +def test_register_user_exists_raises() -> None: + user_in = dto.UserCreate(username="existinguser", password="newpassword") + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch( + "trading_journal.crud.get_user_by_username", + return_value=SimpleNamespace(id=1, username=user_in.username, is_active=True), + ) as mock_get, + ): + with pytest.raises(service.UserAlreadyExistsError) as exc_info: + service.register_user_service(db, user_in) + assert str(exc_info.value) == "username already exists" + mock_get.assert_called_once_with(db, user_in.username) -def test_authenticate_user_success(): - pytest.fail("TODO: mock crud/security, expect token + SessionsCreate DTO") +def test_authenticate_user_success() -> None: + user_in = dto.UserLogin(username="validuser", password="validpassword") + stored_user = SimpleNamespace(id=1, username=user_in.username, is_active=True, password_hash="hashedpassword") + expected_login_session = dto.SessionsCreate( + user_id=stored_user.id, + expires_at=datetime.now(timezone.utc) + timedelta(seconds=settings.session_expiry_seconds), + ) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch( + "trading_journal.crud.get_user_by_username", + return_value=stored_user, + ) as mock_get, + patch("trading_journal.security.verify_password", return_value=True) as mock_verify, + patch("trading_journal.security.generate_session_token", return_value="newsessiontoken") as mock_token, + patch("trading_journal.security.hash_session_token_sha256", return_value="newsessiontokenhash") as mock_hash_session_token, + patch( + "trading_journal.crud.create_login_session", + return_value=SimpleNamespace(user_id=stored_user.id, expires_at=expected_login_session.expires_at), + ) as mock_create_session, + ): + user_out = service.authenticate_user_service(db, user_in) + assert user_out is not None + login_session, token = user_out + # assert fields instead of direct equality to avoid pydantic/model issues + assert getattr(login_session, "user_id", None) == stored_user.id + assert isinstance(getattr(login_session, "expires_at", None), datetime) + assert abs((login_session.expires_at - expected_login_session.expires_at).total_seconds()) < 2 + assert token == "newsessiontoken" + assert login_session.user_id == stored_user.id + mock_get.assert_called_once_with(db, user_in.username) + mock_verify.assert_called_once_with(user_in.password, stored_user.password_hash) + mock_token.assert_called_once() + mock_hash_session_token.assert_called_once_with("newsessiontoken") + mock_create_session.assert_called_once_with( + session=db, + user_id=stored_user.id, + session_token_hash="newsessiontokenhash", + session_length_seconds=settings.session_expiry_seconds, + ) -def test_authenticate_user_invalid_password_returns_none(): - pytest.fail("TODO: mock verify_password False") +def test_authenticate_user_not_found_returns_none() -> None: + user_in = dto.UserLogin(username="nonexistentuser", password="anypassword") + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch( + "trading_journal.crud.get_user_by_username", + return_value=None, + ) as mock_get, + ): + user_out = service.authenticate_user_service(db, user_in) + assert user_out is None + mock_get.assert_called_once_with(db, user_in.username) + + +def test_authenticate_user_invalid_password_returns_none() -> None: + user_in = dto.UserLogin(username="validuser", password="invalidpassword") + stored_user = SimpleNamespace(id=1, username=user_in.username, is_active=True, password_hash="hashedpassword") + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch( + "trading_journal.crud.get_user_by_username", + return_value=stored_user, + ) as mock_get, + patch("trading_journal.security.verify_password", return_value=False) as mock_verify, + ): + user_out = service.authenticate_user_service(db, user_in) + assert user_out is None + mock_get.assert_called_once_with(db, user_in.username) + mock_verify.assert_called_once_with(user_in.password, stored_user.password_hash) # --- Exchange services ------------------------------------------------------- -def test_create_exchange_duplicate_raises(): - pytest.fail("TODO: mock get_exchange_by_name_and_user_id and expect ExchangeAlreadyExistsError") +def test_create_exchange_duplicate_raises() -> None: + exchange_in = dto.ExchangesCreate(user_id=1, name="NYSE", notes="Test exchange") + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch( + "trading_journal.crud.get_exchange_by_name_and_user_id", + return_value=SimpleNamespace(id=1, user_id=1, name=exchange_in.name, notes="Existing exchange"), + ) as mock_get, + ): + with pytest.raises(service.ExchangeAlreadyExistsError) as exc_info: + service.create_exchange_service(db, user_id=exchange_in.user_id, name=exchange_in.name, notes=exchange_in.notes) + assert str(exc_info.value) == "Exchange with the same name already exists for this user" + mock_get.assert_called_once_with(db, exchange_in.name, exchange_in.user_id) -def test_update_exchange_not_found(): - pytest.fail("TODO: mock get_exchange_by_id None and expect ExchangeNotFoundError") +def test_create_exchange_success() -> None: + exchange_in = dto.ExchangesCreate(user_id=1, name="NASDAQ", notes="New exchange") + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch( + "trading_journal.crud.get_exchange_by_name_and_user_id", + return_value=None, + ) as mock_get, + patch( + "trading_journal.crud.create_exchange", + return_value=SimpleNamespace(id=2, user_id=exchange_in.user_id, name=exchange_in.name, notes=exchange_in.notes), + ) as mock_create, + ): + exchange_out = service.create_exchange_service(db, user_id=exchange_in.user_id, name=exchange_in.name, notes=exchange_in.notes) + assert exchange_out.name == exchange_in.name + assert exchange_out.notes == exchange_in.notes + mock_get.assert_called_once_with(db, exchange_in.name, exchange_in.user_id) + mock_create.assert_called_once_with(db, exchange_data=exchange_in) + + +def test_get_exchanges_by_user_id() -> None: + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch( + "trading_journal.crud.get_all_exchanges_by_user_id", + return_value=[ + SimpleNamespace(id=1, user_id=1, name="NYSE", notes="First exchange"), + SimpleNamespace(id=2, user_id=1, name="NASDAQ", notes="Second exchange"), + ], + ) as mock_get, + ): + exchanges = service.get_exchanges_by_user_service(db, user_id=1) + assert len(exchanges) == 2 + assert exchanges[0].name == "NYSE" + assert exchanges[1].name == "NASDAQ" + mock_get.assert_called_once_with(db, 1) + + +def test_get_exchanges_by_user_no_exchanges() -> None: + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch( + "trading_journal.crud.get_all_exchanges_by_user_id", + return_value=[], + ) as mock_get, + ): + exchanges = service.get_exchanges_by_user_service(db, user_id=1) + assert len(exchanges) == 0 + mock_get.assert_called_once_with(db, 1) + + +def test_update_exchange_not_found() -> None: + exchange_update = dto.ExchangesBase(name="UpdatedName", notes="Updated notes") + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch( + "trading_journal.crud.get_exchange_by_id", + return_value=None, + ) as mock_get, + ): + with pytest.raises(service.ExchangeNotFoundError) as exc_info: + service.update_exchanges_service(db, exchange_id=1, user_id=1, name=exchange_update.name, notes=exchange_update.notes) + assert str(exc_info.value) == "Exchange not found" + mock_get.assert_called_once_with(db, 1) + + +def test_update_exchange_owner_mismatch_raises() -> None: + exchange_update = dto.ExchangesBase(name="UpdatedName", notes="Updated notes") + existing_exchange = SimpleNamespace(id=1, user_id=2, name="OldName", notes="Old notes") + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch( + "trading_journal.crud.get_exchange_by_id", + return_value=existing_exchange, + ) as mock_get, + ): + with pytest.raises(service.ExchangeNotFoundError) as exc_info: + service.update_exchanges_service(db, exchange_id=1, user_id=1, name=exchange_update.name, notes=exchange_update.notes) + assert str(exc_info.value) == "Exchange not found" + mock_get.assert_called_once_with(db, 1) + + +def test_update_exchange_duplication() -> None: + exchange_update = dto.ExchangesBase(name="DuplicateName", notes="Updated notes") + existing_exchange = SimpleNamespace(id=1, user_id=1, name="OldName", notes="Old notes") + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch( + "trading_journal.crud.get_exchange_by_id", + return_value=existing_exchange, + ) as mock_get, + patch( + "trading_journal.crud.get_exchange_by_name_and_user_id", + return_value=SimpleNamespace(id=2, user_id=1, name="DuplicateName", notes="Another exchange"), + ) as mock_get_by_name, + ): + with pytest.raises(service.ExchangeAlreadyExistsError) as exc_info: + service.update_exchanges_service(db, exchange_id=1, user_id=1, name=exchange_update.name, notes=exchange_update.notes) + assert str(exc_info.value) == "Another exchange with the same name already exists for this user" + mock_get.assert_called_once_with(db, 1) + mock_get_by_name.assert_called_once_with(db, "DuplicateName", 1) + + +def test_update_exchange_success() -> None: + exchange_update = dto.ExchangesBase(name="UpdatedName", notes="Updated notes") + existing_exchange = SimpleNamespace(id=1, user_id=1, name="OldName", notes="Old notes") + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch( + "trading_journal.crud.get_exchange_by_id", + return_value=existing_exchange, + ) as mock_get, + patch( + "trading_journal.crud.get_exchange_by_name_and_user_id", + return_value=None, + ) as mock_get_by_name, + patch( + "trading_journal.crud.update_exchange", + return_value=SimpleNamespace(id=1, user_id=1, name=exchange_update.name, notes=exchange_update.notes), + ) as mock_update, + ): + exchange_out = service.update_exchanges_service(db, exchange_id=1, user_id=1, name=exchange_update.name, notes=exchange_update.notes) + assert exchange_out.name == exchange_update.name + assert exchange_out.notes == exchange_update.notes + mock_get.assert_called_once_with(db, 1) + mock_get_by_name.assert_called_once_with(db, "UpdatedName", 1) + mock_update.assert_called_once_with(db, 1, update_data=exchange_update) # --- Cycle services ---------------------------------------------------------- -def test_validate_cycle_update_rules(): - pytest.fail("TODO: call _validate_cycle_update_data with invalid combos") +def test_get_cycle_by_id_not_found_raises() -> None: + user_id = 1 + cycle_id = 1 + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_cycle_by_id", return_value=None) as mock_get, + ): + with pytest.raises(service.CycleNotFoundError) as exc_info: + service.get_cycle_by_id_service(db, user_id=user_id, cycle_id=cycle_id) + assert str(exc_info.value) == "Cycle not found" + mock_get.assert_called_once_with(db, cycle_id) -def test_update_cycle_owner_mismatch_raises(): - pytest.fail("TODO: mock get_cycle_by_id owned by other user, expect CycleNotFoundError") +def test_get_cycle_by_id_owner_mismatch_raises() -> None: + user_id = 1 + cycle_id = 1 + cycle = SimpleNamespace(id=cycle_id, user_id=2) # Owned by different user + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_cycle_by_id", return_value=cycle) as mock_get, + ): + with pytest.raises(service.CycleNotFoundError) as exc_info: + service.get_cycle_by_id_service(db, user_id=user_id, cycle_id=cycle_id) + assert str(exc_info.value) == "Cycle not found" + mock_get.assert_called_once_with(db, cycle_id) + + +def test_get_cycle_by_id_success() -> None: + user_id = 1 + cycle_id = 1 + cycle = SimpleNamespace( + id=cycle_id, + friendly_name="Test Cycle", + status="OPEN", + funding_source="MIXED", + user_id=user_id, + symbol="AAPL", + exchange_id=1, + underlying_currency="USD", + start_date=datetime.now(timezone.utc).date(), + trades=[], + exchange=None, + ) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_cycle_by_id", return_value=cycle) as mock_get, + ): + cycle_out = service.get_cycle_by_id_service(db, user_id=user_id, cycle_id=cycle_id) + assert cycle_out.id == cycle_id + assert cycle_out.user_id == user_id + assert cycle_out.friendly_name == "Test Cycle" + assert cycle_out.status == "OPEN" + assert cycle_out.funding_source == "MIXED" + assert cycle_out.symbol == "AAPL" + assert cycle_out.exchange_id == 1 + assert cycle_out.underlying_currency == "USD" + assert cycle_out.trades == [] + mock_get.assert_called_once_with(db, cycle_id) + + +def test_get_cycles_by_user_no_cycles() -> None: + user_id = 1 + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_cycles_by_user_id", return_value=[]) as mock_get, + ): + cycles = service.get_cycles_by_user_service(db, user_id=user_id) + assert isinstance(cycles, list) + assert len(cycles) == 0 + mock_get.assert_called_once_with(db, user_id) + + +def test_get_cycles_by_user_with_cycles() -> None: + user_id = 1 + cycle1 = SimpleNamespace( + id=1, + friendly_name="Cycle 1", + status="OPEN", + funding_source="MIXED", + user_id=user_id, + symbol="AAPL", + exchange_id=1, + underlying_currency="USD", + start_date=datetime.now(timezone.utc).date(), + trades=[], + exchange=None, + ) + cycle2 = SimpleNamespace( + id=2, + friendly_name="Cycle 2", + status="CLOSED", + funding_source="LOAN", + user_id=user_id, + symbol="TSLA", + exchange_id=2, + underlying_currency="USD", + start_date=datetime.now(timezone.utc).date() - timedelta(days=30), + trades=[], + exchange=None, + ) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_cycles_by_user_id", return_value=[cycle1, cycle2]) as mock_get, + ): + cycles = service.get_cycles_by_user_service(db, user_id=user_id) + assert isinstance(cycles, list) + assert len(cycles) == 2 + assert cycles[0].id == 1 + assert cycles[0].friendly_name == "Cycle 1" + assert cycles[1].id == 2 + assert cycles[1].friendly_name == "Cycle 2" + mock_get.assert_called_once_with(db, user_id) + + +def test_update_cycle_closed_status_mismatch_raises() -> None: + cycle_data = dto.CycleUpdate(id=1, friendly_name="Updated Cycle", status="CLOSED") + with ( + FakeDBFactory().get_session_ctx_manager() as db, + ): + with pytest.raises(service.InvalidCycleDataError) as exc_info: + service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) + assert str(exc_info.value) == "end_date is required when status is CLOSED" + + +def test_update_cycle_open_status_mismatch_raises() -> None: + cycle_data = dto.CycleUpdate(id=1, friendly_name="Updated Cycle", status="OPEN", end_date=datetime.now(timezone.utc).date()) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + ): + with pytest.raises(service.InvalidCycleDataError) as exc_info: + service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) + assert str(exc_info.value) == "end_date must be empty when status is OPEN" + + +def test_update_cycle_invalid_capital_exposure_raises() -> None: + cycle_data = dto.CycleUpdate(id=1, friendly_name="Updated Cycle", status="OPEN", capital_exposure_cents=-100) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + ): + with pytest.raises(service.InvalidCycleDataError) as exc_info: + service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) + assert str(exc_info.value) == "capital_exposure_cents must be non-negative" + + +def test_update_cycle_no_cash_no_loan_raises() -> None: + cycle_data = dto.CycleUpdate(id=1, friendly_name="Updated Cycle", status="OPEN", funding_source="LOAN", loan_amount_cents=None) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + ): + with pytest.raises(service.InvalidCycleDataError) as exc_info: + service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) + assert str(exc_info.value) == "loan_amount_cents and loan_interest_rate_tenth_bps are required when funding_source is not CASH" + + +def test_update_cycle_loan_missing_interest_raises() -> None: + cycle_data = dto.CycleUpdate( + id=1, + friendly_name="Updated Cycle", + status="OPEN", + funding_source="LOAN", + loan_amount_cents=10000, + ) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + ): + with pytest.raises(service.InvalidCycleDataError) as exc_info: + service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) + assert str(exc_info.value) == "loan_amount_cents and loan_interest_rate_tenth_bps are required when funding_source is not CASH" + + +def test_update_cycle_loan_negative_loan_raises() -> None: + cycle_data = dto.CycleUpdate( + id=1, + friendly_name="Updated Cycle", + status="OPEN", + funding_source="LOAN", + loan_amount_cents=-10000, + loan_interest_rate_tenth_bps=50, + ) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + ): + with pytest.raises(service.InvalidCycleDataError) as exc_info: + service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) + assert str(exc_info.value) == "loan_amount_cents must be non-negative" + + +def test_update_cycle_loan_negative_interest_raises() -> None: + cycle_data = dto.CycleUpdate( + id=1, + friendly_name="Updated Cycle", + status="OPEN", + funding_source="LOAN", + loan_amount_cents=10000, + loan_interest_rate_tenth_bps=-50, + ) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + ): + with pytest.raises(service.InvalidCycleDataError) as exc_info: + service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) + assert str(exc_info.value) == "loan_interest_rate_tenth_bps must be non-negative" + + +def test_update_cycle_not_found_raises() -> None: + cycle_data = dto.CycleUpdate(id=1, friendly_name="Updated Cycle", status="OPEN") + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_cycle_by_id", return_value=None) as mock_get, + ): + with pytest.raises(service.CycleNotFoundError) as exc_info: + service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) + assert str(exc_info.value) == "Cycle not found" + mock_get.assert_called_once_with(db, cycle_data.id) + + +def test_update_cycle_owner_mismatch_raises() -> None: + cycle_data = dto.CycleUpdate(id=1, friendly_name="Updated Cycle", status="OPEN") + existing_cycle = SimpleNamespace(id=1, user_id=2) # Owned by different user + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_cycle_by_id", return_value=existing_cycle) as mock_get, + ): + with pytest.raises(service.CycleNotFoundError) as exc_info: + service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) + assert str(exc_info.value) == "Cycle not found" + mock_get.assert_called_once_with(db, cycle_data.id) + + +def test_update_cycle_success() -> None: + cycle_data = dto.CycleUpdate(id=1, friendly_name="Updated Cycle", status="OPEN", funding_source="CASH", capital_exposure_cents=5000) + existing_cycle = SimpleNamespace( + id=1, + user_id=1, + friendly_name="Old Cycle", + symbol="AAPL", + exchange_id=1, + underlying_currency="USD", + start_date=datetime.now(timezone.utc).date(), + status="OPEN", + funding_source="MIXED", + capital_exposure_cents=10000, + loan_amount_cents=2000, + loan_interest_rate_tenth_bps=50, + ) + updated_cycle = SimpleNamespace( + id=1, + user_id=1, + symbol="AAPL", + exchange_id=1, + underlying_currency="USD", + start_date=existing_cycle.start_date, + friendly_name="Updated Cycle", + status=cycle_data.status, + funding_source=cycle_data.funding_source, + capital_exposure_cents=cycle_data.capital_exposure_cents, + loan_amount_cents=None, + loan_interest_rate_tenth_bps=None, + ) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_cycle_by_id", return_value=existing_cycle) as mock_get, + patch("trading_journal.crud.update_cycle", return_value=updated_cycle) as mock_update, + ): + cycle_out = service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) + assert cycle_out.id == updated_cycle.id + assert cycle_out.friendly_name == updated_cycle.friendly_name + assert cycle_out.status == updated_cycle.status + assert cycle_out.funding_source == updated_cycle.funding_source + assert cycle_out.capital_exposure_cents == updated_cycle.capital_exposure_cents + assert cycle_out.loan_amount_cents is None + assert cycle_out.loan_interest_rate_tenth_bps is None + mock_get.assert_called_once_with(db, cycle_data.id) + update_cycle_base = dto.CycleBase( + friendly_name=cycle_data.friendly_name, + status=cycle_data.status, + funding_source=cycle_data.funding_source, + capital_exposure_cents=cycle_data.capital_exposure_cents, + loan_amount_cents=getattr(cycle_data, "loan_amount_cents", None), + loan_interest_rate_tenth_bps=getattr(cycle_data, "loan_interest_rate_tenth_bps", None), + end_date=getattr(cycle_data, "end_date", None), + ) + mock_update.assert_called_once_with(db, cycle_data.id, update_data=update_cycle_base) # --- Trade services ---------------------------------------------------------- -def test_create_trade_invalid_sell_requires_expiry(): - pytest.fail("TODO: build SELL_PUT without expiry/strike, expect InvalidTradeDataError") +def test_create_trade_short_option_no_strike() -> None: + trade_data = dto.TradeCreate( + user_id=1, + symbol="AAPL", + exchange_id=1, + underlying_currency=dto.UnderlyingCurrency.USD, + trade_type=dto.TradeType.SELL_PUT, + trade_strategy=dto.TradeStrategy.WHEEL, + trade_date=datetime.now(timezone.utc).date(), + quantity=-1, + price_cents=5000, + commission_cents=100, + cycle_id=1, + friendly_name="Short Call", + notes="Test trade", + quantity_multiplier=100, + expiry_date=datetime.now(timezone.utc).date() + timedelta(days=30), + strike_price_cents=None, # Missing strike price + ) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + ): + with pytest.raises(service.InvalidTradeDataError) as exc_info: + service.create_trade_service(db, 1, trade_data) + assert str(exc_info.value) == "Invalid trade data: expiry_date and strike_price_cents are required for SELL_PUT and SELL_CALL trades" -def test_create_trade_appends_cashflow_and_calls_crud(): - pytest.fail("TODO: mock crud.create_trade, assert net_cash_flow_cents in result") +def test_create_trade_success() -> None: + trade_data = dto.TradeCreate( + user_id=1, + symbol="AAPL", + exchange_id=1, + underlying_currency=dto.UnderlyingCurrency.USD, + trade_type=dto.TradeType.SELL_PUT, + trade_strategy=dto.TradeStrategy.WHEEL, + trade_date=datetime.now(timezone.utc).date(), + strike_price_cents=15000, + expiry_date=datetime.now(timezone.utc).date() + timedelta(days=30), + quantity=1, + price_cents=5000, + commission_cents=100, + cycle_id=1, + friendly_name="Sell put", + notes="Test trade", + quantity_multiplier=1, + ) + created_trade = SimpleNamespace( + id=1, + user_id=trade_data.user_id, + symbol=trade_data.symbol, + exchange_id=trade_data.exchange_id, + underlying_currency=trade_data.underlying_currency, + trade_type=trade_data.trade_type, + trade_strategy=trade_data.trade_strategy, + trade_date=trade_data.trade_date, + quantity=trade_data.quantity, + price_cents=trade_data.price_cents, + commission_cents=trade_data.commission_cents, + cycle_id=trade_data.cycle_id, + friendly_name=trade_data.friendly_name, + notes=trade_data.notes, + quantity_multiplier=trade_data.quantity_multiplier, + expiry_date=trade_data.expiry_date, + strike_price_cents=trade_data.strike_price_cents, + ) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.create_trade", return_value=created_trade) as mock_create_trade, + ): + trade_out = service.create_trade_service(db, user_id=1, trade_data=trade_data) + assert trade_out.id == created_trade.id + assert trade_out.user_id == created_trade.user_id + assert trade_out.symbol == created_trade.symbol + assert trade_out.trade_type == created_trade.trade_type + mock_create_trade.assert_called_once() + _, kwargs = mock_create_trade.call_args + passed_trade = kwargs.get("trade_data") or (mock_create_trade.call_args[0][1] if len(mock_create_trade.call_args[0]) > 1 else None) + assert passed_trade is not None + # expected for SELL_PUT: gross = quantity * price * quantity_multiplier (positive), net = gross - commission + expected_gross = trade_data.quantity * trade_data.price_cents * (trade_data.quantity_multiplier or 1) + expected_net = expected_gross - trade_data.commission_cents + assert getattr(passed_trade, "gross_cash_flow_cents", None) == expected_gross + assert getattr(passed_trade, "net_cash_flow_cents", None) == expected_net -def test_get_trade_by_id_missing_raises(): - pytest.fail("TODO: mock get_trade_by_id None, expect TradeNotFoundError") +def test_get_trade_by_id_not_found_when_missing() -> None: + with FakeDBFactory().get_session_ctx_manager() as db, patch("trading_journal.crud.get_trade_by_id", return_value=None) as mock_get: + with pytest.raises(service.TradeNotFoundError) as exc_info: + service.get_trade_by_id_service(db, user_id=1, trade_id=1) + assert str(exc_info.value) == "Trade not found" + mock_get.assert_called_once_with(db, 1) -def test_update_trade_friendly_name_not_found(): - pytest.fail("TODO: mock get_trade_by_id None, expect TradeNotFoundError") +def test_get_trade_by_id_not_found_owner_mismatch() -> None: + existing_trade = SimpleNamespace(id=2, user_id=2) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_trade_by_id", return_value=existing_trade) as mock_get, + ): + with pytest.raises(service.TradeNotFoundError) as exc_info: + service.get_trade_by_id_service(db, user_id=1, trade_id=2) + assert str(exc_info.value) == "Trade not found" + mock_get.assert_called_once_with(db, 2) -def test_update_trade_note_sets_empty_string_when_none(): - pytest.fail("TODO: mock update_trade_note to return note '', assert DTO note") +def test_get_trade_by_id_success() -> None: + # build a trade-like object compatible with dto.TradeRead/model_validate + trade_obj = SimpleNamespace( + id=10, + user_id=1, + friendly_name="Test Trade", + symbol="AAPL", + exchange_id=1, + underlying_currency=dto.UnderlyingCurrency.USD, + trade_type=dto.TradeType.LONG_SPOT, + trade_strategy=dto.TradeStrategy.SPOT, + trade_date=datetime.now(timezone.utc).date(), + trade_time_utc=None, + expiry_date=None, + strike_price_cents=None, + quantity=1, + quantity_multiplier=1, + price_cents=1000, + gross_cash_flow_cents=-1000, + commission_cents=10, + net_cash_flow_cents=-1010, + is_invalidated=False, + invalidated_at=None, + replaced_by_trade_id=None, + notes="ok", + cycle_id=None, + ) + with FakeDBFactory().get_session_ctx_manager() as db, patch("trading_journal.crud.get_trade_by_id", return_value=trade_obj) as mock_get: + res = service.get_trade_by_id_service(db, user_id=1, trade_id=10) + assert res.id == trade_obj.id + assert res.user_id == trade_obj.user_id + assert res.symbol == trade_obj.symbol + assert res.trade_type == trade_obj.trade_type + mock_get.assert_called_once_with(db, 10) + + +def test_update_trade_friendly_name_not_found() -> None: + with FakeDBFactory().get_session_ctx_manager() as db, patch("trading_journal.crud.get_trade_by_id", return_value=None) as mock_get: + with pytest.raises(service.TradeNotFoundError) as exc_info: + service.update_trade_friendly_name_service(db, user_id=1, trade_id=10, friendly_name="New Name") + assert str(exc_info.value) == "Trade not found" + mock_get.assert_called_once_with(db, 10) + + +def test_update_trade_friendly_name_owner_mismatch_raises() -> None: + existing_trade = SimpleNamespace(id=10, user_id=2) # owned by another user + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_trade_by_id", return_value=existing_trade) as mock_get, + ): + with pytest.raises(service.TradeNotFoundError) as exc_info: + service.update_trade_friendly_name_service(db, user_id=1, trade_id=10, friendly_name="New Name") + assert str(exc_info.value) == "Trade not found" + mock_get.assert_called_once_with(db, 10) + + +def test_update_trade_friendly_name_success() -> None: + existing_trade = SimpleNamespace( + id=10, + user_id=1, + friendly_name="Old Name", + symbol="AAPL", + exchange_id=1, + underlying_currency=dto.UnderlyingCurrency.USD, + trade_type=dto.TradeType.LONG_SPOT, + trade_strategy=dto.TradeStrategy.SPOT, + trade_date=datetime.now(timezone.utc).date(), + trade_time_utc=None, + expiry_date=None, + strike_price_cents=None, + quantity=1, + quantity_multiplier=1, + price_cents=1000, + gross_cash_flow_cents=-1000, + commission_cents=10, + net_cash_flow_cents=-1010, + is_invalidated=False, + invalidated_at=None, + replaced_by_trade_id=None, + notes="ok", + cycle_id=None, + ) + updated_trade = SimpleNamespace(**{**existing_trade.__dict__, "friendly_name": "New Friendly"}) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_trade_by_id", return_value=existing_trade) as mock_get, + patch("trading_journal.crud.update_trade_friendly_name", return_value=updated_trade) as mock_update, + ): + res = service.update_trade_friendly_name_service(db, user_id=1, trade_id=10, friendly_name="New Friendly") + assert res.friendly_name == "New Friendly" + mock_get.assert_called_once_with(db, 10) + mock_update.assert_called_once_with(db, 10, "New Friendly") + + +def test_update_trade_note_not_found() -> None: + with FakeDBFactory().get_session_ctx_manager() as db, patch("trading_journal.crud.get_trade_by_id", return_value=None) as mock_get: + with pytest.raises(service.TradeNotFoundError) as exc_info: + service.update_trade_note_service(db, user_id=1, trade_id=20, note="x") + assert str(exc_info.value) == "Trade not found" + mock_get.assert_called_once_with(db, 20) + + +def test_update_trade_note_owner_mismatch_raises() -> None: + existing_trade = SimpleNamespace(id=20, user_id=2) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_trade_by_id", return_value=existing_trade) as mock_get, + ): + with pytest.raises(service.TradeNotFoundError) as exc_info: + service.update_trade_note_service(db, user_id=1, trade_id=20, note="x") + assert str(exc_info.value) == "Trade not found" + mock_get.assert_called_once_with(db, 20) + + +def test_update_trade_note_success_and_none_becomes_empty() -> None: + existing_trade = SimpleNamespace( + id=20, + user_id=1, + friendly_name="Trade", + symbol="AAPL", + exchange_id=1, + underlying_currency=dto.UnderlyingCurrency.USD, + trade_type=dto.TradeType.LONG_SPOT, + trade_strategy=dto.TradeStrategy.SPOT, + trade_date=datetime.now(timezone.utc).date(), + trade_time_utc=None, + expiry_date=None, + strike_price_cents=None, + quantity=1, + quantity_multiplier=1, + price_cents=1000, + gross_cash_flow_cents=-1000, + commission_cents=10, + net_cash_flow_cents=-1010, + is_invalidated=False, + invalidated_at=None, + replaced_by_trade_id=None, + notes="old", + cycle_id=None, + ) + updated_trade = SimpleNamespace(**{**existing_trade.__dict__, "notes": ""}) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_trade_by_id", return_value=existing_trade) as mock_get, + patch("trading_journal.crud.update_trade_note", return_value=updated_trade) as mock_update, + ): + res = service.update_trade_note_service(db, user_id=1, trade_id=20, note=None) + assert res.notes == "" + mock_get.assert_called_once_with(db, 20) + mock_update.assert_called_once_with(db, 20, "") diff --git a/backend/trading_journal/service.py b/backend/trading_journal/service.py index 4855c3e..79d487e 100644 --- a/backend/trading_journal/service.py +++ b/backend/trading_journal/service.py @@ -229,6 +229,7 @@ def update_exchanges_service(db_session: Session, user_id: int, exchange_id: int # Cycle Service def create_cycle_service(db_session: Session, user_id: int, cycle_data: CycleBase) -> CycleRead: + raise NotImplementedError("Cycle creation not implemented") cycle_data_dict = cycle_data.model_dump() cycle_data_dict["user_id"] = user_id cycle_data_with_user_id: CycleCreate = CycleCreate.model_validate(cycle_data_dict)