diff --git a/backend/ruff.toml b/backend/ruff.toml index 7571904..223ab8e 100644 --- a/backend/ruff.toml +++ b/backend/ruff.toml @@ -24,3 +24,4 @@ ignore = [ [lint.extend-per-file-ignores] "test*.py" = ["S101", "S105", "S106", "PT011", "PLR2004"] "models*.py" = ["FA102"] +"dto.py" = ["TC001", "TC003"] diff --git a/backend/tests/test_crud.py b/backend/tests/test_crud.py index 9f343f6..fbbcc43 100644 --- a/backend/tests/test_crud.py +++ b/backend/tests/test_crud.py @@ -56,7 +56,9 @@ def make_exchange(session: Session, user_id: int, name: str = "NASDAQ") -> int: return cast("int", exchange.id) -def make_cycle(session: Session, user_id: int, exchange_id: int, friendly_name: str = "Test Cycle") -> int: +def make_cycle( + session: Session, user_id: int, exchange_id: int, friendly_name: str = "Test Cycle" +) -> int: cycle = models.Cycles( user_id=user_id, friendly_name=friendly_name, @@ -72,7 +74,9 @@ def make_cycle(session: Session, user_id: int, exchange_id: int, friendly_name: return cast("int", cycle.id) -def make_trade(session: Session, user_id: int, cycle_id: int, friendly_name: str = "Test Trade") -> int: +def make_trade( + session: Session, user_id: int, cycle_id: int, friendly_name: str = "Test Trade" +) -> int: cycle: models.Cycles | None = session.get(models.Cycles, cycle_id) assert cycle is not None exchange_id = cycle.exchange_id @@ -137,13 +141,17 @@ def _ensure_utc_aware(dt: datetime | None) -> datetime | None: return dt.astimezone(timezone.utc) -def _validate_timestamp(actual: datetime, expected: datetime, tolerance: timedelta) -> None: +def _validate_timestamp( + actual: datetime, expected: datetime, tolerance: timedelta +) -> None: actual_utc = _ensure_utc_aware(actual) expected_utc = _ensure_utc_aware(expected) assert actual_utc is not None assert expected_utc is not None delta = abs(actual_utc - expected_utc) - assert delta <= tolerance, f"Timestamps differ by {delta}, which exceeds tolerance of {tolerance}" + assert delta <= tolerance, ( + f"Timestamps differ by {delta}, which exceeds tolerance of {tolerance}" + ) # Trades @@ -470,7 +478,9 @@ def test_update_trade_friendly_name(session: Session) -> None: trade_id = make_trade(session, user_id, cycle_id) new_friendly_name = "Updated Trade Name" - updated_trade = crud.update_trade_friendly_name(session, trade_id, new_friendly_name) + updated_trade = crud.update_trade_friendly_name( + session, trade_id, new_friendly_name + ) assert updated_trade is not None assert updated_trade.id == trade_id assert updated_trade.friendly_name == new_friendly_name @@ -624,7 +634,9 @@ def test_get_cycles_by_user_id(session: Session) -> None: def test_update_cycle(session: Session) -> None: user_id = make_user(session) exchange_id = make_exchange(session, user_id) - cycle_id = make_cycle(session, user_id, exchange_id, friendly_name="Initial Cycle Name") + cycle_id = make_cycle( + session, user_id, exchange_id, friendly_name="Initial Cycle Name" + ) update_data = { "friendly_name": "Updated Cycle Name", @@ -646,14 +658,20 @@ def test_update_cycle(session: Session) -> None: def test_update_cycle_immutable_fields(session: Session) -> None: user_id = make_user(session) exchange_id = make_exchange(session, user_id) - cycle_id = make_cycle(session, user_id, exchange_id, friendly_name="Initial Cycle Name") + cycle_id = make_cycle( + session, user_id, exchange_id, friendly_name="Initial Cycle Name" + ) # Attempt to update immutable fields update_data = { "id": cycle_id + 1, # Trying to change the ID "user_id": user_id + 1, # Trying to change the user_id - "start_date": datetime(2020, 1, 1, tzinfo=timezone.utc).date(), # Trying to change start_date - "created_at": datetime(2020, 1, 1, tzinfo=timezone.utc), # Trying to change created_at + "start_date": datetime( + 2020, 1, 1, tzinfo=timezone.utc + ).date(), # Trying to change start_date + "created_at": datetime( + 2020, 1, 1, tzinfo=timezone.utc + ), # Trying to change created_at "friendly_name": "Valid Update", # Valid field to update } @@ -685,7 +703,10 @@ def test_create_cycle_loan_event(session: Session) -> None: assert loan_event.id is not None assert loan_event.cycle_id == cycle_id assert loan_event.loan_amount_cents == loan_data["loan_amount_cents"] - assert loan_event.loan_interest_rate_tenth_bps == loan_data["loan_interest_rate_tenth_bps"] + assert ( + loan_event.loan_interest_rate_tenth_bps + == loan_data["loan_interest_rate_tenth_bps"] + ) assert loan_event.notes == loan_data["notes"] assert loan_event.effective_date == now.date() _validate_timestamp(loan_event.created_at, now, timedelta(seconds=1)) @@ -695,12 +716,41 @@ def test_create_cycle_loan_event(session: Session) -> None: assert actual_loan_event is not None assert actual_loan_event.cycle_id == cycle_id assert actual_loan_event.loan_amount_cents == loan_data["loan_amount_cents"] - assert actual_loan_event.loan_interest_rate_tenth_bps == loan_data["loan_interest_rate_tenth_bps"] + assert ( + actual_loan_event.loan_interest_rate_tenth_bps + == loan_data["loan_interest_rate_tenth_bps"] + ) assert actual_loan_event.notes == loan_data["notes"] assert actual_loan_event.effective_date == now.date() _validate_timestamp(actual_loan_event.created_at, now, timedelta(seconds=1)) +def test_create_cycle_loan_event_same_date_error(session: Session) -> None: + user_id = make_user(session) + exchange_id = make_exchange(session, user_id) + cycle_id = make_cycle(session, user_id, exchange_id) + + loan_data_1 = { + "cycle_id": cycle_id, + "loan_amount_cents": 100000, + "loan_interest_rate_tenth_bps": 5000, + "effective_date": datetime(2023, 1, 1, tzinfo=timezone.utc).date(), + "notes": "First loan event", + } + loan_data_2 = { + "cycle_id": cycle_id, + "loan_amount_cents": 150000, + "loan_interest_rate_tenth_bps": 4500, + "effective_date": datetime(2023, 1, 1, tzinfo=timezone.utc).date(), + "notes": "Second loan event same date", + } + + crud.create_cycle_loan_event(session, loan_data_1) + with pytest.raises(ValueError) as excinfo: + crud.create_cycle_loan_event(session, loan_data_2) + assert "create_cycle_loan_event integrity error" in str(excinfo.value) + + def test_get_cycle_loan_events_by_cycle_id(session: Session) -> None: user_id = make_user(session) exchange_id = make_exchange(session, user_id) @@ -729,34 +779,77 @@ def test_get_cycle_loan_events_by_cycle_id(session: Session) -> None: notes = [event.notes for event in loan_events] assert loan_events[0].notes == loan_data_2["notes"] assert loan_events[0].effective_date == yesterday - assert notes == ["Second loan event", "First loan event"] # Ordered by effective_date desc + assert notes == [ + "Second loan event", + "First loan event", + ] # Ordered by effective_date desc -def test_get_cycle_loan_events_by_cycle_id_same_date(session: Session) -> None: +def test_get_cycle_loan_event_by_cycle_id_and_effective_date(session: Session) -> None: user_id = make_user(session) exchange_id = make_exchange(session, user_id) cycle_id = make_cycle(session, user_id, exchange_id) - loan_data_1 = { + effective_date = datetime(2023, 1, 1, tzinfo=timezone.utc).date() + loan_data = { "cycle_id": cycle_id, "loan_amount_cents": 100000, "loan_interest_rate_tenth_bps": 5000, - "notes": "First loan event", + "effective_date": effective_date, + "notes": "Loan event for specific date", } - loan_data_2 = { + + crud.create_cycle_loan_event(session, loan_data) + loan_event = crud.get_loan_event_by_cycle_id_and_effective_date( + session, cycle_id, effective_date + ) + assert loan_event is not None + assert loan_event.cycle_id == cycle_id + assert loan_event.effective_date == effective_date + assert loan_event.notes == loan_data["notes"] + + +def test_update_cycle_loan_event(session: Session) -> None: + user_id = make_user(session) + exchange_id = make_exchange(session, user_id) + cycle_id = make_cycle(session, user_id, exchange_id) + + loan_data = { "cycle_id": cycle_id, - "loan_amount_cents": 150000, - "loan_interest_rate_tenth_bps": 4500, - "notes": "Second loan event", + "loan_amount_cents": 100000, + "loan_interest_rate_tenth_bps": 5000, + "notes": "Initial loan event", } - crud.create_cycle_loan_event(session, loan_data_1) - crud.create_cycle_loan_event(session, loan_data_2) + loan_event = crud.create_cycle_loan_event(session, loan_data) + assert loan_event is not None - loan_events = crud.get_loan_events_by_cycle_id(session, cycle_id) - assert len(loan_events) == 2 - notes = [event.notes for event in loan_events] - assert notes == ["First loan event", "Second loan event"] # Ordered by id desc when effective_date is same + update_data = { + "loan_amount_cents": 120000, + "loan_interest_rate_tenth_bps": 4500, + "notes": "Updated loan event", + } + event_id = loan_event.id or 0 + + updated_loan_event = crud.update_cycle_loan_event(session, event_id, update_data) + assert updated_loan_event is not None + assert updated_loan_event.id == loan_event.id + assert updated_loan_event.loan_amount_cents == update_data["loan_amount_cents"] + assert ( + updated_loan_event.loan_interest_rate_tenth_bps + == update_data["loan_interest_rate_tenth_bps"] + ) + assert updated_loan_event.notes == update_data["notes"] + + session.refresh(updated_loan_event) + actual_loan_event = session.get(models.CycleLoanChangeEvents, loan_event.id) + assert actual_loan_event is not None + assert actual_loan_event.loan_amount_cents == update_data["loan_amount_cents"] + assert ( + actual_loan_event.loan_interest_rate_tenth_bps + == update_data["loan_interest_rate_tenth_bps"] + ) + assert actual_loan_event.notes == update_data["notes"] def test_create_cycle_loan_event_single_field(session: Session) -> None: @@ -802,7 +895,12 @@ def test_create_cycle_daily_accrual(session: Session) -> None: "notes": "Daily interest accrual", } - accrual = crud.create_cycle_daily_accrual(session, cycle_id, accrual_data["accrual_date"], accrual_data["accrued_interest_cents"]) + accrual = crud.create_cycle_daily_accrual( + session, + cycle_id, + accrual_data["accrual_date"], + accrual_data["accrued_interest_cents"], + ) assert accrual.id is not None assert accrual.cycle_id == cycle_id assert accrual.accrual_date == accrual_data["accrual_date"] @@ -835,8 +933,18 @@ def test_get_cycle_daily_accruals_by_cycle_id(session: Session) -> None: "accrued_interest_cents": 150, } - crud.create_cycle_daily_accrual(session, cycle_id, accrual_data_1["accrual_date"], accrual_data_1["accrued_interest_cents"]) - crud.create_cycle_daily_accrual(session, cycle_id, accrual_data_2["accrual_date"], accrual_data_2["accrued_interest_cents"]) + crud.create_cycle_daily_accrual( + session, + cycle_id, + accrual_data_1["accrual_date"], + accrual_data_1["accrued_interest_cents"], + ) + crud.create_cycle_daily_accrual( + session, + cycle_id, + accrual_data_2["accrual_date"], + accrual_data_2["accrued_interest_cents"], + ) accruals = crud.get_cycle_daily_accruals_by_cycle_id(session, cycle_id) assert len(accruals) == 2 @@ -863,18 +971,37 @@ def test_get_cycle_daily_accruals_by_cycle_id_and_date(session: Session) -> None "accrued_interest_cents": 150, } - crud.create_cycle_daily_accrual(session, cycle_id, accrual_data_1["accrual_date"], accrual_data_1["accrued_interest_cents"]) - crud.create_cycle_daily_accrual(session, cycle_id, accrual_data_2["accrual_date"], accrual_data_2["accrued_interest_cents"]) + crud.create_cycle_daily_accrual( + session, + cycle_id, + accrual_data_1["accrual_date"], + accrual_data_1["accrued_interest_cents"], + ) + crud.create_cycle_daily_accrual( + session, + cycle_id, + accrual_data_2["accrual_date"], + accrual_data_2["accrued_interest_cents"], + ) - accruals_today = crud.get_cycle_daily_accrual_by_cycle_id_and_date(session, cycle_id, today) + accruals_today = crud.get_cycle_daily_accrual_by_cycle_id_and_date( + session, cycle_id, today + ) assert accruals_today is not None assert accruals_today.accrual_date == today - assert accruals_today.accrual_amount_cents == accrual_data_2["accrued_interest_cents"] + assert ( + accruals_today.accrual_amount_cents == accrual_data_2["accrued_interest_cents"] + ) - accruals_yesterday = crud.get_cycle_daily_accrual_by_cycle_id_and_date(session, cycle_id, yesterday) + accruals_yesterday = crud.get_cycle_daily_accrual_by_cycle_id_and_date( + session, cycle_id, yesterday + ) assert accruals_yesterday is not None assert accruals_yesterday.accrual_date == yesterday - assert accruals_yesterday.accrual_amount_cents == accrual_data_1["accrued_interest_cents"] + assert ( + accruals_yesterday.accrual_amount_cents + == accrual_data_1["accrued_interest_cents"] + ) # Exchanges @@ -1031,7 +1158,9 @@ def test_update_user_immutable_fields(session: Session) -> None: update_data = { "id": user_id + 1, # Trying to change the ID "username": "newusername", # Trying to change the username - "created_at": datetime(2020, 1, 1, tzinfo=timezone.utc), # Trying to change created_at + "created_at": datetime( + 2020, 1, 1, tzinfo=timezone.utc + ), # Trying to change created_at "password_hash": "validupdate", # Valid field to update } @@ -1065,7 +1194,9 @@ def test_create_login_session_with_invalid_user(session: Session) -> None: def test_get_login_session_by_token_and_user_id(session: Session) -> None: now = datetime.now(timezone.utc) created_session = make_login_session(session, now) - fetched_session = crud.get_login_session_by_token_hash_and_user_id(session, created_session.session_token_hash, created_session.user_id) + fetched_session = crud.get_login_session_by_token_hash_and_user_id( + session, created_session.session_token_hash, created_session.user_id + ) assert fetched_session is not None assert fetched_session.id == created_session.id assert fetched_session.user_id == created_session.user_id @@ -1075,7 +1206,9 @@ def test_get_login_session_by_token_and_user_id(session: Session) -> None: def test_get_login_session_by_token(session: Session) -> None: now = datetime.now(timezone.utc) created_session = make_login_session(session, now) - fetched_session = crud.get_login_session_by_token_hash(session, created_session.session_token_hash) + fetched_session = crud.get_login_session_by_token_hash( + session, created_session.session_token_hash + ) assert fetched_session is not None assert fetched_session.id == created_session.id assert fetched_session.user_id == created_session.user_id @@ -1090,9 +1223,13 @@ def test_update_login_session(session: Session) -> None: "last_seen_at": now + timedelta(hours=1), "last_used_ip": "192.168.1.1", } - updated_session = crud.update_login_session(session, created_session.session_token_hash, update_data) + updated_session = crud.update_login_session( + session, created_session.session_token_hash, update_data + ) assert updated_session is not None - assert _ensure_utc_aware(updated_session.last_seen_at) == update_data["last_seen_at"] + assert ( + _ensure_utc_aware(updated_session.last_seen_at) == update_data["last_seen_at"] + ) assert updated_session.last_used_ip == update_data["last_used_ip"] @@ -1101,5 +1238,7 @@ def test_delete_login_session(session: Session) -> None: created_session = make_login_session(session, now) crud.delete_login_session(session, created_session.session_token_hash) - deleted_session = crud.get_login_session_by_token_hash_and_user_id(session, created_session.session_token_hash, created_session.user_id) + deleted_session = crud.get_login_session_by_token_hash_and_user_id( + session, created_session.session_token_hash, created_session.user_id + ) assert deleted_session is None diff --git a/backend/tests/test_service.py b/backend/tests/test_service.py index 0248e18..cf0fc8b 100644 --- a/backend/tests/test_service.py +++ b/backend/tests/test_service.py @@ -25,9 +25,15 @@ class FakeDBFactory: yield fake_session -def verify_json_response(response: Response, expected_status: int, expected_detail: str) -> None: +def verify_json_response( + response: Response, expected_status: int, expected_detail: str +) -> None: assert response.status_code == expected_status - body_bytes = response.body.tobytes() if isinstance(response.body, memoryview) else response.body + body_bytes = ( + response.body.tobytes() + if isinstance(response.body, memoryview) + else response.body + ) body_text = body_bytes.decode("utf-8") body_json = json.loads(body_text) assert body_json.get("detail") == expected_detail @@ -93,7 +99,9 @@ def test_auth_middleware_no_db() -> None: pytest.fail("call_next should not be called for invalid token") response = asyncio.run(middleware.dispatch(request, call_next)) - verify_json_response(response, status.HTTP_500_INTERNAL_SERVER_ERROR, "db factory not configured") + verify_json_response( + response, status.HTTP_500_INTERNAL_SERVER_ERROR, "db factory not configured" + ) def test_auth_middleware_rejects_invalid_token() -> None: @@ -114,7 +122,9 @@ def test_auth_middleware_rejects_invalid_token() -> None: async def call_next(req: Request) -> Response: # noqa: ARG001 pytest.fail("call_next should not be called for invalid token") - with patch("trading_journal.crud.get_login_session_by_token_hash", return_value=None): + with patch( + "trading_journal.crud.get_login_session_by_token_hash", return_value=None + ): response = asyncio.run(middleware.dispatch(request, call_next)) verify_json_response(response, status.HTTP_401_UNAUTHORIZED, "Unauthorized") @@ -147,8 +157,14 @@ def test_auth_middleware_rejects_expired_token() -> None: ) with ( - patch("trading_journal.security.hash_session_token_sha256", return_value=expired_session.session_token_hash) as mock_hash, - patch("trading_journal.crud.get_login_session_by_token_hash", return_value=expired_session), + patch( + "trading_journal.security.hash_session_token_sha256", + return_value=expired_session.session_token_hash, + ) as mock_hash, + patch( + "trading_journal.crud.get_login_session_by_token_hash", + return_value=expired_session, + ), patch("trading_journal.crud.delete_login_session") as mock_delete, ): response = asyncio.run(middleware.dispatch(request, call_next)) @@ -192,8 +208,14 @@ def test_auth_middleware_reject_inactive_user() -> None: ) with ( - patch("trading_journal.security.hash_session_token_sha256", return_value=valid_session.session_token_hash), - patch("trading_journal.crud.get_login_session_by_token_hash", return_value=valid_session), + patch( + "trading_journal.security.hash_session_token_sha256", + return_value=valid_session.session_token_hash, + ), + patch( + "trading_journal.crud.get_login_session_by_token_hash", + return_value=valid_session, + ), ): response = asyncio.run(middleware.dispatch(request, call_next)) @@ -210,7 +232,10 @@ def test_auth_middleware_allows_valid_token_and_no_update_expires() -> None: "type": "http", "method": "GET", "path": f"/{settings.api_base}/protected", - "headers": [(b"cookie", f"session_token={fake_token_orig}".encode()), (b"user-agent", b"test-agent")], + "headers": [ + (b"cookie", f"session_token={fake_token_orig}".encode()), + (b"user-agent", b"test-agent"), + ], "client": ("testclient", 50000), "app": app, } @@ -236,8 +261,14 @@ def test_auth_middleware_allows_valid_token_and_no_update_expires() -> None: ) with ( - patch("trading_journal.security.hash_session_token_sha256", return_value=valid_session.session_token_hash), - patch("trading_journal.crud.get_login_session_by_token_hash", return_value=valid_session), + patch( + "trading_journal.security.hash_session_token_sha256", + return_value=valid_session.session_token_hash, + ), + patch( + "trading_journal.crud.get_login_session_by_token_hash", + return_value=valid_session, + ), patch("trading_journal.crud.update_login_session") as mock_update, ): response = asyncio.run(middleware.dispatch(request, call_next)) @@ -259,7 +290,10 @@ def test_auth_middleware_allows_valid_token_and_updates_expires() -> None: "type": "http", "method": "GET", "path": f"/{settings.api_base}/protected", - "headers": [(b"cookie", f"session_token={fake_token_orig}".encode()), (b"user-agent", b"test-agent")], + "headers": [ + (b"cookie", f"session_token={fake_token_orig}".encode()), + (b"user-agent", b"test-agent"), + ], "client": ("testclient", 50000), "app": app, } @@ -285,8 +319,14 @@ def test_auth_middleware_allows_valid_token_and_updates_expires() -> None: ) with ( - patch("trading_journal.security.hash_session_token_sha256", return_value=valid_session.session_token_hash), - patch("trading_journal.crud.get_login_session_by_token_hash", return_value=valid_session), + patch( + "trading_journal.security.hash_session_token_sha256", + return_value=valid_session.session_token_hash, + ), + patch( + "trading_journal.crud.get_login_session_by_token_hash", + return_value=valid_session, + ), patch("trading_journal.crud.update_login_session") as mock_update, ): response = asyncio.run(middleware.dispatch(request, call_next)) @@ -295,8 +335,12 @@ def test_auth_middleware_allows_valid_token_and_updates_expires() -> None: _, kwargs = mock_update.call_args update_session = kwargs.get("update_session") assert update_session is not None - assert (update_session.expires_at - datetime.now(timezone.utc)).total_seconds() > settings.session_expiry_seconds - 1 - assert (update_session.last_seen_at - datetime.now(timezone.utc)).total_seconds() < 1 + assert ( + update_session.expires_at - datetime.now(timezone.utc) + ).total_seconds() > settings.session_expiry_seconds - 1 + assert ( + update_session.last_seen_at - datetime.now(timezone.utc) + ).total_seconds() < 1 assert update_session.last_used_ip == "testclient" assert update_session.user_agent == "test-agent" @@ -310,12 +354,19 @@ def test_register_user_success() -> None: } with ( FakeDBFactory().get_session_ctx_manager() as db, - patch("trading_journal.crud.get_user_by_username", return_value=None) as mock_get, + patch( + "trading_journal.crud.get_user_by_username", return_value=None + ) as mock_get, patch( "trading_journal.crud.create_user", - return_value=SimpleNamespace(id=1, username=user_in.username, is_active=True), + return_value=SimpleNamespace( + id=1, username=user_in.username, is_active=True + ), ) as mock_create, - patch("trading_journal.security.hash_password", return_value=user_in_with_hashed_password["password_hash"]), + patch( + "trading_journal.security.hash_password", + return_value=user_in_with_hashed_password["password_hash"], + ), ): user_out = service.register_user_service(db, user_in) assert user_out.id is not None @@ -330,7 +381,9 @@ def test_register_user_exists_raises() -> None: FakeDBFactory().get_session_ctx_manager() as db, patch( "trading_journal.crud.get_user_by_username", - return_value=SimpleNamespace(id=1, username=user_in.username, is_active=True), + return_value=SimpleNamespace( + id=1, username=user_in.username, is_active=True + ), ) as mock_get, ): with pytest.raises(service.UserAlreadyExistsError) as exc_info: @@ -341,10 +394,13 @@ def test_register_user_exists_raises() -> None: def test_authenticate_user_success() -> None: user_in = dto.UserLogin(username="validuser", password="validpassword") - stored_user = SimpleNamespace(id=1, username=user_in.username, is_active=True, password_hash="hashedpassword") + stored_user = SimpleNamespace( + id=1, username=user_in.username, is_active=True, password_hash="hashedpassword" + ) expected_login_session = dto.SessionsCreate( user_id=stored_user.id, - expires_at=datetime.now(timezone.utc) + timedelta(seconds=settings.session_expiry_seconds), + expires_at=datetime.now(timezone.utc) + + timedelta(seconds=settings.session_expiry_seconds), ) with ( FakeDBFactory().get_session_ctx_manager() as db, @@ -352,12 +408,22 @@ def test_authenticate_user_success() -> None: "trading_journal.crud.get_user_by_username", return_value=stored_user, ) as mock_get, - patch("trading_journal.security.verify_password", return_value=True) as mock_verify, - patch("trading_journal.security.generate_session_token", return_value="newsessiontoken") as mock_token, - patch("trading_journal.security.hash_session_token_sha256", return_value="newsessiontokenhash") as mock_hash_session_token, + patch( + "trading_journal.security.verify_password", return_value=True + ) as mock_verify, + patch( + "trading_journal.security.generate_session_token", + return_value="newsessiontoken", + ) as mock_token, + patch( + "trading_journal.security.hash_session_token_sha256", + return_value="newsessiontokenhash", + ) as mock_hash_session_token, patch( "trading_journal.crud.create_login_session", - return_value=SimpleNamespace(user_id=stored_user.id, expires_at=expected_login_session.expires_at), + return_value=SimpleNamespace( + user_id=stored_user.id, expires_at=expected_login_session.expires_at + ), ) as mock_create_session, ): user_out = service.authenticate_user_service(db, user_in) @@ -366,7 +432,14 @@ def test_authenticate_user_success() -> None: # assert fields instead of direct equality to avoid pydantic/model issues assert getattr(login_session, "user_id", None) == stored_user.id assert isinstance(getattr(login_session, "expires_at", None), datetime) - assert abs((login_session.expires_at - expected_login_session.expires_at).total_seconds()) < 2 + assert ( + abs( + ( + login_session.expires_at - expected_login_session.expires_at + ).total_seconds() + ) + < 2 + ) assert token == "newsessiontoken" assert login_session.user_id == stored_user.id mock_get.assert_called_once_with(db, user_in.username) @@ -397,14 +470,18 @@ def test_authenticate_user_not_found_returns_none() -> None: def test_authenticate_user_invalid_password_returns_none() -> None: user_in = dto.UserLogin(username="validuser", password="invalidpassword") - stored_user = SimpleNamespace(id=1, username=user_in.username, is_active=True, password_hash="hashedpassword") + stored_user = SimpleNamespace( + id=1, username=user_in.username, is_active=True, password_hash="hashedpassword" + ) with ( FakeDBFactory().get_session_ctx_manager() as db, patch( "trading_journal.crud.get_user_by_username", return_value=stored_user, ) as mock_get, - patch("trading_journal.security.verify_password", return_value=False) as mock_verify, + patch( + "trading_journal.security.verify_password", return_value=False + ) as mock_verify, ): user_out = service.authenticate_user_service(db, user_in) assert user_out is None @@ -419,12 +496,22 @@ def test_create_exchange_duplicate_raises() -> None: FakeDBFactory().get_session_ctx_manager() as db, patch( "trading_journal.crud.get_exchange_by_name_and_user_id", - return_value=SimpleNamespace(id=1, user_id=1, name=exchange_in.name, notes="Existing exchange"), + return_value=SimpleNamespace( + id=1, user_id=1, name=exchange_in.name, notes="Existing exchange" + ), ) as mock_get, ): with pytest.raises(service.ExchangeAlreadyExistsError) as exc_info: - service.create_exchange_service(db, user_id=exchange_in.user_id, name=exchange_in.name, notes=exchange_in.notes) - assert str(exc_info.value) == "Exchange with the same name already exists for this user" + service.create_exchange_service( + db, + user_id=exchange_in.user_id, + name=exchange_in.name, + notes=exchange_in.notes, + ) + assert ( + str(exc_info.value) + == "Exchange with the same name already exists for this user" + ) mock_get.assert_called_once_with(db, exchange_in.name, exchange_in.user_id) @@ -438,10 +525,20 @@ def test_create_exchange_success() -> None: ) as mock_get, patch( "trading_journal.crud.create_exchange", - return_value=SimpleNamespace(id=2, user_id=exchange_in.user_id, name=exchange_in.name, notes=exchange_in.notes), + return_value=SimpleNamespace( + id=2, + user_id=exchange_in.user_id, + name=exchange_in.name, + notes=exchange_in.notes, + ), ) as mock_create, ): - exchange_out = service.create_exchange_service(db, user_id=exchange_in.user_id, name=exchange_in.name, notes=exchange_in.notes) + exchange_out = service.create_exchange_service( + db, + user_id=exchange_in.user_id, + name=exchange_in.name, + notes=exchange_in.notes, + ) assert exchange_out.name == exchange_in.name assert exchange_out.notes == exchange_in.notes mock_get.assert_called_once_with(db, exchange_in.name, exchange_in.user_id) @@ -455,7 +552,9 @@ def test_get_exchanges_by_user_id() -> None: "trading_journal.crud.get_all_exchanges_by_user_id", return_value=[ SimpleNamespace(id=1, user_id=1, name="NYSE", notes="First exchange"), - SimpleNamespace(id=2, user_id=1, name="NASDAQ", notes="Second exchange"), + SimpleNamespace( + id=2, user_id=1, name="NASDAQ", notes="Second exchange" + ), ], ) as mock_get, ): @@ -489,14 +588,22 @@ def test_update_exchange_not_found() -> None: ) as mock_get, ): with pytest.raises(service.ExchangeNotFoundError) as exc_info: - service.update_exchanges_service(db, exchange_id=1, user_id=1, name=exchange_update.name, notes=exchange_update.notes) + service.update_exchanges_service( + db, + exchange_id=1, + user_id=1, + name=exchange_update.name, + notes=exchange_update.notes, + ) assert str(exc_info.value) == "Exchange not found" mock_get.assert_called_once_with(db, 1) def test_update_exchange_owner_mismatch_raises() -> None: exchange_update = dto.ExchangesBase(name="UpdatedName", notes="Updated notes") - existing_exchange = SimpleNamespace(id=1, user_id=2, name="OldName", notes="Old notes") + existing_exchange = SimpleNamespace( + id=1, user_id=2, name="OldName", notes="Old notes" + ) with ( FakeDBFactory().get_session_ctx_manager() as db, patch( @@ -505,14 +612,22 @@ def test_update_exchange_owner_mismatch_raises() -> None: ) as mock_get, ): with pytest.raises(service.ExchangeNotFoundError) as exc_info: - service.update_exchanges_service(db, exchange_id=1, user_id=1, name=exchange_update.name, notes=exchange_update.notes) + service.update_exchanges_service( + db, + exchange_id=1, + user_id=1, + name=exchange_update.name, + notes=exchange_update.notes, + ) assert str(exc_info.value) == "Exchange not found" mock_get.assert_called_once_with(db, 1) def test_update_exchange_duplication() -> None: exchange_update = dto.ExchangesBase(name="DuplicateName", notes="Updated notes") - existing_exchange = SimpleNamespace(id=1, user_id=1, name="OldName", notes="Old notes") + existing_exchange = SimpleNamespace( + id=1, user_id=1, name="OldName", notes="Old notes" + ) with ( FakeDBFactory().get_session_ctx_manager() as db, patch( @@ -521,19 +636,32 @@ def test_update_exchange_duplication() -> None: ) as mock_get, patch( "trading_journal.crud.get_exchange_by_name_and_user_id", - return_value=SimpleNamespace(id=2, user_id=1, name="DuplicateName", notes="Another exchange"), + return_value=SimpleNamespace( + id=2, user_id=1, name="DuplicateName", notes="Another exchange" + ), ) as mock_get_by_name, ): with pytest.raises(service.ExchangeAlreadyExistsError) as exc_info: - service.update_exchanges_service(db, exchange_id=1, user_id=1, name=exchange_update.name, notes=exchange_update.notes) - assert str(exc_info.value) == "Another exchange with the same name already exists for this user" + service.update_exchanges_service( + db, + exchange_id=1, + user_id=1, + name=exchange_update.name, + notes=exchange_update.notes, + ) + assert ( + str(exc_info.value) + == "Another exchange with the same name already exists for this user" + ) mock_get.assert_called_once_with(db, 1) mock_get_by_name.assert_called_once_with(db, "DuplicateName", 1) def test_update_exchange_success() -> None: exchange_update = dto.ExchangesBase(name="UpdatedName", notes="Updated notes") - existing_exchange = SimpleNamespace(id=1, user_id=1, name="OldName", notes="Old notes") + existing_exchange = SimpleNamespace( + id=1, user_id=1, name="OldName", notes="Old notes" + ) with ( FakeDBFactory().get_session_ctx_manager() as db, patch( @@ -546,10 +674,18 @@ def test_update_exchange_success() -> None: ) as mock_get_by_name, patch( "trading_journal.crud.update_exchange", - return_value=SimpleNamespace(id=1, user_id=1, name=exchange_update.name, notes=exchange_update.notes), + return_value=SimpleNamespace( + id=1, user_id=1, name=exchange_update.name, notes=exchange_update.notes + ), ) as mock_update, ): - exchange_out = service.update_exchanges_service(db, exchange_id=1, user_id=1, name=exchange_update.name, notes=exchange_update.notes) + exchange_out = service.update_exchanges_service( + db, + exchange_id=1, + user_id=1, + name=exchange_update.name, + notes=exchange_update.notes, + ) assert exchange_out.name == exchange_update.name assert exchange_out.notes == exchange_update.notes mock_get.assert_called_once_with(db, 1) @@ -605,7 +741,9 @@ def test_get_cycle_by_id_success() -> None: FakeDBFactory().get_session_ctx_manager() as db, patch("trading_journal.crud.get_cycle_by_id", return_value=cycle) as mock_get, ): - cycle_out = service.get_cycle_by_id_service(db, user_id=user_id, cycle_id=cycle_id) + cycle_out = service.get_cycle_by_id_service( + db, user_id=user_id, cycle_id=cycle_id + ) assert cycle_out.id == cycle_id assert cycle_out.user_id == user_id assert cycle_out.friendly_name == "Test Cycle" @@ -622,7 +760,9 @@ def test_get_cycles_by_user_no_cycles() -> None: user_id = 1 with ( FakeDBFactory().get_session_ctx_manager() as db, - patch("trading_journal.crud.get_cycles_by_user_id", return_value=[]) as mock_get, + patch( + "trading_journal.crud.get_cycles_by_user_id", return_value=[] + ) as mock_get, ): cycles = service.get_cycles_by_user_service(db, user_id=user_id) assert isinstance(cycles, list) @@ -660,7 +800,9 @@ def test_get_cycles_by_user_with_cycles() -> None: ) with ( FakeDBFactory().get_session_ctx_manager() as db, - patch("trading_journal.crud.get_cycles_by_user_id", return_value=[cycle1, cycle2]) as mock_get, + patch( + "trading_journal.crud.get_cycles_by_user_id", return_value=[cycle1, cycle2] + ) as mock_get, ): cycles = service.get_cycles_by_user_service(db, user_id=user_id) assert isinstance(cycles, list) @@ -683,7 +825,12 @@ def test_update_cycle_closed_status_mismatch_raises() -> None: def test_update_cycle_open_status_mismatch_raises() -> None: - cycle_data = dto.CycleUpdate(id=1, friendly_name="Updated Cycle", status="OPEN", end_date=datetime.now(timezone.utc).date()) + cycle_data = dto.CycleUpdate( + id=1, + friendly_name="Updated Cycle", + status="OPEN", + end_date=datetime.now(timezone.utc).date(), + ) with ( FakeDBFactory().get_session_ctx_manager() as db, ): @@ -693,7 +840,9 @@ def test_update_cycle_open_status_mismatch_raises() -> None: def test_update_cycle_invalid_capital_exposure_raises() -> None: - cycle_data = dto.CycleUpdate(id=1, friendly_name="Updated Cycle", status="OPEN", capital_exposure_cents=-100) + cycle_data = dto.CycleUpdate( + id=1, friendly_name="Updated Cycle", status="OPEN", capital_exposure_cents=-100 + ) with ( FakeDBFactory().get_session_ctx_manager() as db, ): @@ -703,13 +852,22 @@ def test_update_cycle_invalid_capital_exposure_raises() -> None: def test_update_cycle_no_cash_no_loan_raises() -> None: - cycle_data = dto.CycleUpdate(id=1, friendly_name="Updated Cycle", status="OPEN", funding_source="LOAN", loan_amount_cents=None) + cycle_data = dto.CycleUpdate( + id=1, + friendly_name="Updated Cycle", + status="OPEN", + funding_source="LOAN", + loan_amount_cents=None, + ) with ( FakeDBFactory().get_session_ctx_manager() as db, ): with pytest.raises(service.InvalidCycleDataError) as exc_info: service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) - assert str(exc_info.value) == "loan_amount_cents and loan_interest_rate_tenth_bps are required when funding_source is not CASH" + assert ( + str(exc_info.value) + == "loan_amount_cents and loan_interest_rate_tenth_bps are required when funding_source is not CASH" + ) def test_update_cycle_loan_missing_interest_raises() -> None: @@ -725,7 +883,10 @@ def test_update_cycle_loan_missing_interest_raises() -> None: ): with pytest.raises(service.InvalidCycleDataError) as exc_info: service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) - assert str(exc_info.value) == "loan_amount_cents and loan_interest_rate_tenth_bps are required when funding_source is not CASH" + assert ( + str(exc_info.value) + == "loan_amount_cents and loan_interest_rate_tenth_bps are required when funding_source is not CASH" + ) def test_update_cycle_loan_negative_loan_raises() -> None: @@ -759,7 +920,9 @@ def test_update_cycle_loan_negative_interest_raises() -> None: ): with pytest.raises(service.InvalidCycleDataError) as exc_info: service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) - assert str(exc_info.value) == "loan_interest_rate_tenth_bps must be non-negative" + assert ( + str(exc_info.value) == "loan_interest_rate_tenth_bps must be non-negative" + ) def test_update_cycle_not_found_raises() -> None: @@ -779,7 +942,9 @@ def test_update_cycle_owner_mismatch_raises() -> None: existing_cycle = SimpleNamespace(id=1, user_id=2) # Owned by different user with ( FakeDBFactory().get_session_ctx_manager() as db, - patch("trading_journal.crud.get_cycle_by_id", return_value=existing_cycle) as mock_get, + patch( + "trading_journal.crud.get_cycle_by_id", return_value=existing_cycle + ) as mock_get, ): with pytest.raises(service.CycleNotFoundError) as exc_info: service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) @@ -788,7 +953,16 @@ def test_update_cycle_owner_mismatch_raises() -> None: def test_update_cycle_success() -> None: - cycle_data = dto.CycleUpdate(id=1, friendly_name="Updated Cycle", status="OPEN", funding_source="CASH", capital_exposure_cents=5000) + today = datetime.now(timezone.utc).date() + cycle_data = dto.CycleUpdate( + id=1, + friendly_name="Updated Cycle", + status="OPEN", + funding_source="MIXED", + capital_exposure_cents=5000, + loan_amount_cents=2000, + loan_interest_rate_tenth_bps=50, + ) existing_cycle = SimpleNamespace( id=1, user_id=1, @@ -796,7 +970,7 @@ def test_update_cycle_success() -> None: symbol="AAPL", exchange_id=1, underlying_currency="USD", - start_date=datetime.now(timezone.utc).date(), + start_date=today, status="OPEN", funding_source="MIXED", capital_exposure_cents=10000, @@ -814,13 +988,21 @@ def test_update_cycle_success() -> None: status=cycle_data.status, funding_source=cycle_data.funding_source, capital_exposure_cents=cycle_data.capital_exposure_cents, - loan_amount_cents=None, - loan_interest_rate_tenth_bps=None, + loan_amount_cents=cycle_data.loan_amount_cents, + loan_interest_rate_tenth_bps=cycle_data.loan_interest_rate_tenth_bps, ) with ( FakeDBFactory().get_session_ctx_manager() as db, - patch("trading_journal.crud.get_cycle_by_id", return_value=existing_cycle) as mock_get, - patch("trading_journal.crud.update_cycle", return_value=updated_cycle) as mock_update, + patch( + "trading_journal.crud.get_cycle_by_id", return_value=existing_cycle + ) as mock_get, + patch( + "trading_journal.crud.update_cycle", return_value=updated_cycle + ) as mock_update, + patch( + "trading_journal.crud.get_loan_event_by_cycle_id_and_effective_date", + return_value=None, + ) as mock_get_loan_event, ): cycle_out = service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) assert cycle_out.id == updated_cycle.id @@ -828,8 +1010,11 @@ def test_update_cycle_success() -> None: assert cycle_out.status == updated_cycle.status assert cycle_out.funding_source == updated_cycle.funding_source assert cycle_out.capital_exposure_cents == updated_cycle.capital_exposure_cents - assert cycle_out.loan_amount_cents is None - assert cycle_out.loan_interest_rate_tenth_bps is None + assert cycle_out.loan_amount_cents == updated_cycle.loan_amount_cents + assert ( + cycle_out.loan_interest_rate_tenth_bps + == updated_cycle.loan_interest_rate_tenth_bps + ) mock_get.assert_called_once_with(db, cycle_data.id) update_cycle_base = dto.CycleBase( friendly_name=cycle_data.friendly_name, @@ -837,10 +1022,15 @@ def test_update_cycle_success() -> None: funding_source=cycle_data.funding_source, capital_exposure_cents=cycle_data.capital_exposure_cents, loan_amount_cents=getattr(cycle_data, "loan_amount_cents", None), - loan_interest_rate_tenth_bps=getattr(cycle_data, "loan_interest_rate_tenth_bps", None), + loan_interest_rate_tenth_bps=getattr( + cycle_data, "loan_interest_rate_tenth_bps", None + ), end_date=getattr(cycle_data, "end_date", None), ) - mock_update.assert_called_once_with(db, cycle_data.id, update_data=update_cycle_base) + mock_update.assert_called_once_with( + db, cycle_data.id, update_data=update_cycle_base + ) + mock_get_loan_event.assert_called_once_with(db, cycle_data.id, today) # --- Trade services ---------------------------------------------------------- @@ -868,7 +1058,10 @@ def test_create_trade_short_option_no_strike() -> None: ): with pytest.raises(service.InvalidTradeDataError) as exc_info: service.create_trade_service(db, 1, trade_data) - assert str(exc_info.value) == "Invalid trade data: expiry_date and strike_price_cents are required for SELL_PUT and SELL_CALL trades" + assert ( + str(exc_info.value) + == "Invalid trade data: expiry_date and strike_price_cents are required for SELL_PUT and SELL_CALL trades" + ) def test_create_trade_success() -> None: @@ -911,7 +1104,9 @@ def test_create_trade_success() -> None: ) with ( FakeDBFactory().get_session_ctx_manager() as db, - patch("trading_journal.crud.create_trade", return_value=created_trade) as mock_create_trade, + patch( + "trading_journal.crud.create_trade", return_value=created_trade + ) as mock_create_trade, ): trade_out = service.create_trade_service(db, user_id=1, trade_data=trade_data) assert trade_out.id == created_trade.id @@ -920,17 +1115,28 @@ def test_create_trade_success() -> None: assert trade_out.trade_type == created_trade.trade_type mock_create_trade.assert_called_once() _, kwargs = mock_create_trade.call_args - passed_trade = kwargs.get("trade_data") or (mock_create_trade.call_args[0][1] if len(mock_create_trade.call_args[0]) > 1 else None) + passed_trade = kwargs.get("trade_data") or ( + mock_create_trade.call_args[0][1] + if len(mock_create_trade.call_args[0]) > 1 + else None + ) assert passed_trade is not None # expected for SELL_PUT: gross = quantity * price * quantity_multiplier (positive), net = gross - commission - expected_gross = trade_data.quantity * trade_data.price_cents * (trade_data.quantity_multiplier or 1) + expected_gross = ( + trade_data.quantity + * trade_data.price_cents + * (trade_data.quantity_multiplier or 1) + ) expected_net = expected_gross - trade_data.commission_cents assert getattr(passed_trade, "gross_cash_flow_cents", None) == expected_gross assert getattr(passed_trade, "net_cash_flow_cents", None) == expected_net def test_get_trade_by_id_not_found_when_missing() -> None: - with FakeDBFactory().get_session_ctx_manager() as db, patch("trading_journal.crud.get_trade_by_id", return_value=None) as mock_get: + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_trade_by_id", return_value=None) as mock_get, + ): with pytest.raises(service.TradeNotFoundError) as exc_info: service.get_trade_by_id_service(db, user_id=1, trade_id=1) assert str(exc_info.value) == "Trade not found" @@ -941,7 +1147,9 @@ def test_get_trade_by_id_not_found_owner_mismatch() -> None: existing_trade = SimpleNamespace(id=2, user_id=2) with ( FakeDBFactory().get_session_ctx_manager() as db, - patch("trading_journal.crud.get_trade_by_id", return_value=existing_trade) as mock_get, + patch( + "trading_journal.crud.get_trade_by_id", return_value=existing_trade + ) as mock_get, ): with pytest.raises(service.TradeNotFoundError) as exc_info: service.get_trade_by_id_service(db, user_id=1, trade_id=2) @@ -976,7 +1184,12 @@ def test_get_trade_by_id_success() -> None: notes="ok", cycle_id=None, ) - with FakeDBFactory().get_session_ctx_manager() as db, patch("trading_journal.crud.get_trade_by_id", return_value=trade_obj) as mock_get: + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch( + "trading_journal.crud.get_trade_by_id", return_value=trade_obj + ) as mock_get, + ): res = service.get_trade_by_id_service(db, user_id=1, trade_id=10) assert res.id == trade_obj.id assert res.user_id == trade_obj.user_id @@ -986,9 +1199,14 @@ def test_get_trade_by_id_success() -> None: def test_update_trade_friendly_name_not_found() -> None: - with FakeDBFactory().get_session_ctx_manager() as db, patch("trading_journal.crud.get_trade_by_id", return_value=None) as mock_get: + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_trade_by_id", return_value=None) as mock_get, + ): with pytest.raises(service.TradeNotFoundError) as exc_info: - service.update_trade_friendly_name_service(db, user_id=1, trade_id=10, friendly_name="New Name") + service.update_trade_friendly_name_service( + db, user_id=1, trade_id=10, friendly_name="New Name" + ) assert str(exc_info.value) == "Trade not found" mock_get.assert_called_once_with(db, 10) @@ -997,10 +1215,14 @@ def test_update_trade_friendly_name_owner_mismatch_raises() -> None: existing_trade = SimpleNamespace(id=10, user_id=2) # owned by another user with ( FakeDBFactory().get_session_ctx_manager() as db, - patch("trading_journal.crud.get_trade_by_id", return_value=existing_trade) as mock_get, + patch( + "trading_journal.crud.get_trade_by_id", return_value=existing_trade + ) as mock_get, ): with pytest.raises(service.TradeNotFoundError) as exc_info: - service.update_trade_friendly_name_service(db, user_id=1, trade_id=10, friendly_name="New Name") + service.update_trade_friendly_name_service( + db, user_id=1, trade_id=10, friendly_name="New Name" + ) assert str(exc_info.value) == "Trade not found" mock_get.assert_called_once_with(db, 10) @@ -1031,20 +1253,32 @@ def test_update_trade_friendly_name_success() -> None: notes="ok", cycle_id=None, ) - updated_trade = SimpleNamespace(**{**existing_trade.__dict__, "friendly_name": "New Friendly"}) + updated_trade = SimpleNamespace( + **{**existing_trade.__dict__, "friendly_name": "New Friendly"} + ) with ( FakeDBFactory().get_session_ctx_manager() as db, - patch("trading_journal.crud.get_trade_by_id", return_value=existing_trade) as mock_get, - patch("trading_journal.crud.update_trade_friendly_name", return_value=updated_trade) as mock_update, + patch( + "trading_journal.crud.get_trade_by_id", return_value=existing_trade + ) as mock_get, + patch( + "trading_journal.crud.update_trade_friendly_name", + return_value=updated_trade, + ) as mock_update, ): - res = service.update_trade_friendly_name_service(db, user_id=1, trade_id=10, friendly_name="New Friendly") + res = service.update_trade_friendly_name_service( + db, user_id=1, trade_id=10, friendly_name="New Friendly" + ) assert res.friendly_name == "New Friendly" mock_get.assert_called_once_with(db, 10) mock_update.assert_called_once_with(db, 10, "New Friendly") def test_update_trade_note_not_found() -> None: - with FakeDBFactory().get_session_ctx_manager() as db, patch("trading_journal.crud.get_trade_by_id", return_value=None) as mock_get: + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_trade_by_id", return_value=None) as mock_get, + ): with pytest.raises(service.TradeNotFoundError) as exc_info: service.update_trade_note_service(db, user_id=1, trade_id=20, note="x") assert str(exc_info.value) == "Trade not found" @@ -1055,7 +1289,9 @@ def test_update_trade_note_owner_mismatch_raises() -> None: existing_trade = SimpleNamespace(id=20, user_id=2) with ( FakeDBFactory().get_session_ctx_manager() as db, - patch("trading_journal.crud.get_trade_by_id", return_value=existing_trade) as mock_get, + patch( + "trading_journal.crud.get_trade_by_id", return_value=existing_trade + ) as mock_get, ): with pytest.raises(service.TradeNotFoundError) as exc_info: service.update_trade_note_service(db, user_id=1, trade_id=20, note="x") @@ -1092,8 +1328,12 @@ def test_update_trade_note_success_and_none_becomes_empty() -> None: updated_trade = SimpleNamespace(**{**existing_trade.__dict__, "notes": ""}) with ( FakeDBFactory().get_session_ctx_manager() as db, - patch("trading_journal.crud.get_trade_by_id", return_value=existing_trade) as mock_get, - patch("trading_journal.crud.update_trade_note", return_value=updated_trade) as mock_update, + patch( + "trading_journal.crud.get_trade_by_id", return_value=existing_trade + ) as mock_get, + patch( + "trading_journal.crud.update_trade_note", return_value=updated_trade + ) as mock_update, ): res = service.update_trade_note_service(db, user_id=1, trade_id=20, note=None) assert res.notes == "" diff --git a/backend/trading_journal/crud.py b/backend/trading_journal/crud.py index 66e2c4f..b771433 100644 --- a/backend/trading_journal/crud.py +++ b/backend/trading_journal/crud.py @@ -52,7 +52,9 @@ def _data_to_dict(data: AnyModel) -> dict[str, AnyModel]: # Trades -def create_trade(session: Session, trade_data: Mapping[str, Any] | BaseModel) -> models.Trades: +def create_trade( + session: Session, trade_data: Mapping[str, Any] | BaseModel +) -> models.Trades: data = _data_to_dict(trade_data) allowed = _allowed_columns(models.Trades) payload = {k: v for k, v in data.items() if k in allowed} @@ -72,13 +74,19 @@ def create_trade(session: Session, trade_data: Mapping[str, Any] | BaseModel) -> raise ValueError("exchange.user_id does not match trade.user_id") if "underlying_currency" not in payload: raise ValueError("underlying_currency is required") - payload["underlying_currency"] = _check_enum(models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency") + payload["underlying_currency"] = _check_enum( + models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency" + ) if "trade_type" not in payload: raise ValueError("trade_type is required") - payload["trade_type"] = _check_enum(models.TradeType, payload["trade_type"], "trade_type") + payload["trade_type"] = _check_enum( + models.TradeType, payload["trade_type"], "trade_type" + ) if "trade_strategy" not in payload: raise ValueError("trade_strategy is required") - payload["trade_strategy"] = _check_enum(models.TradeStrategy, payload["trade_strategy"], "trade_strategy") + payload["trade_strategy"] = _check_enum( + models.TradeStrategy, payload["trade_strategy"], "trade_strategy" + ) # trade_time_utc is the creation moment: always set to now (caller shouldn't provide) now = datetime.now(timezone.utc) payload.pop("trade_time_utc", None) @@ -105,7 +113,8 @@ def create_trade(session: Session, trade_data: Mapping[str, Any] | BaseModel) -> "symbol": payload["symbol"], "exchange_id": payload["exchange_id"], "underlying_currency": payload["underlying_currency"], - "friendly_name": "Auto-created Cycle by trade " + payload.get("friendly_name", ""), + "friendly_name": "Auto-created Cycle by trade " + + payload.get("friendly_name", ""), "status": models.CycleStatus.OPEN, "start_date": payload["trade_date"], } @@ -119,7 +128,9 @@ def create_trade(session: Session, trade_data: Mapping[str, Any] | BaseModel) -> if cycle is None: raise ValueError("cycle_id does not exist") - payload.pop("exchange_id", None) # ignore exchange_id if provided; use cycle's exchange_id + payload.pop( + "exchange_id", None + ) # ignore exchange_id if provided; use cycle's exchange_id payload["exchange_id"] = cycle.exchange_id if cycle.user_id != user_id: raise ValueError("cycle.user_id does not match trade.user_id") @@ -147,7 +158,9 @@ def get_trade_by_id(session: Session, trade_id: int) -> models.Trades | None: return session.get(models.Trades, trade_id) -def get_trade_by_user_id_and_friendly_name(session: Session, user_id: int, friendly_name: str) -> models.Trades | None: +def get_trade_by_user_id_and_friendly_name( + session: Session, user_id: int, friendly_name: str +) -> models.Trades | None: statement = select(models.Trades).where( models.Trades.user_id == user_id, models.Trades.friendly_name == friendly_name, @@ -162,7 +175,9 @@ def get_trades_by_user_id(session: Session, user_id: int) -> list[models.Trades] return list(session.exec(statement).all()) -def update_trade_friendly_name(session: Session, trade_id: int, friendly_name: str) -> models.Trades: +def update_trade_friendly_name( + session: Session, trade_id: int, friendly_name: str +) -> models.Trades: trade: models.Trades | None = session.get(models.Trades, trade_id) if trade is None: raise ValueError("trade_id does not exist") @@ -210,7 +225,9 @@ def invalidate_trade(session: Session, trade_id: int) -> models.Trades: return trade -def replace_trade(session: Session, old_trade_id: int, new_trade_data: Mapping[str, Any] | BaseModel) -> models.Trades: +def replace_trade( + session: Session, old_trade_id: int, new_trade_data: Mapping[str, Any] | BaseModel +) -> models.Trades: invalidate_trade(session, old_trade_id) data = _data_to_dict(new_trade_data) data["replaced_by_trade_id"] = old_trade_id @@ -218,7 +235,9 @@ def replace_trade(session: Session, old_trade_id: int, new_trade_data: Mapping[s # Cycles -def create_cycle(session: Session, cycle_data: Mapping[str, Any] | BaseModel) -> models.Cycles: +def create_cycle( + session: Session, cycle_data: Mapping[str, Any] | BaseModel +) -> models.Cycles: data = _data_to_dict(cycle_data) allowed = _allowed_columns(models.Cycles) payload = {k: v for k, v in data.items() if k in allowed} @@ -236,7 +255,9 @@ def create_cycle(session: Session, cycle_data: Mapping[str, Any] | BaseModel) -> raise ValueError("exchange.user_id does not match cycle.user_id") if "underlying_currency" not in payload: raise ValueError("underlying_currency is required") - payload["underlying_currency"] = _check_enum(models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency") + payload["underlying_currency"] = _check_enum( + models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency" + ) if "status" not in payload: raise ValueError("status is required") payload["status"] = _check_enum(models.CycleStatus, payload["status"], "status") @@ -268,7 +289,9 @@ def get_cycles_by_user_id(session: Session, user_id: int) -> list[models.Cycles] return list(session.exec(statement).all()) -def update_cycle(session: Session, cycle_id: int, update_data: Mapping[str, Any] | BaseModel) -> models.Cycles: +def update_cycle( + session: Session, cycle_id: int, update_data: Mapping[str, Any] | BaseModel +) -> models.Cycles: cycle: models.Cycles | None = session.get(models.Cycles, cycle_id) if cycle is None: raise ValueError("cycle_id does not exist") @@ -304,7 +327,9 @@ def update_cycle(session: Session, cycle_id: int, update_data: Mapping[str, Any] # Cycle loan and interest -def create_cycle_loan_event(session: Session, loan_data: Mapping[str, Any] | BaseModel) -> models.CycleLoanChangeEvents: +def create_cycle_loan_event( + session: Session, loan_data: Mapping[str, Any] | BaseModel +) -> models.CycleLoanChangeEvents: data = _data_to_dict(loan_data) allowed = _allowed_columns(models.CycleLoanChangeEvents) payload = {k: v for k, v in data.items() if k in allowed} @@ -314,7 +339,9 @@ def create_cycle_loan_event(session: Session, loan_data: Mapping[str, Any] | Bas if cycle is None: raise ValueError("cycle_id does not exist") - payload["effective_date"] = payload.get("effective_date") or datetime.now(timezone.utc).date() + payload["effective_date"] = ( + payload.get("effective_date") or datetime.now(timezone.utc).date() + ) payload["created_at"] = datetime.now(timezone.utc) cle = models.CycleLoanChangeEvents(**payload) session.add(cle) @@ -327,7 +354,9 @@ def create_cycle_loan_event(session: Session, loan_data: Mapping[str, Any] | Bas return cle -def get_loan_events_by_cycle_id(session: Session, cycle_id: int) -> list[models.CycleLoanChangeEvents]: +def get_loan_events_by_cycle_id( + session: Session, cycle_id: int +) -> list[models.CycleLoanChangeEvents]: eff_col = cast("ColumnElement", models.CycleLoanChangeEvents.effective_date) id_col = cast("ColumnElement", models.CycleLoanChangeEvents.id) statement = ( @@ -340,7 +369,46 @@ def get_loan_events_by_cycle_id(session: Session, cycle_id: int) -> list[models. return list(session.exec(statement).all()) -def create_cycle_daily_accrual(session: Session, cycle_id: int, accrual_date: date, accrual_amount_cents: int) -> models.CycleDailyAccrual: +def get_loan_event_by_cycle_id_and_effective_date( + session: Session, cycle_id: int, effective_date: date +) -> models.CycleLoanChangeEvents | None: + statement = select(models.CycleLoanChangeEvents).where( + models.CycleLoanChangeEvents.cycle_id == cycle_id, + models.CycleLoanChangeEvents.effective_date == effective_date, + ) + return session.exec(statement).first() + + +def update_cycle_loan_event( + session: Session, event_id: int, update_data: Mapping[str, Any] | BaseModel +) -> models.CycleLoanChangeEvents: + event: models.CycleLoanChangeEvents | None = session.get( + models.CycleLoanChangeEvents, event_id + ) + if event is None: + raise ValueError("event_id does not exist") + data = _data_to_dict(update_data) + + allowed = _allowed_columns(models.CycleLoanChangeEvents) + for k, v in data.items(): + if k in {"id", "cycle_id", "effective_date", "created_at"}: + raise ValueError(f"field {k!r} is immutable") + if k not in allowed: + continue + setattr(event, k, v) + session.add(event) + try: + session.flush() + except IntegrityError as e: + session.rollback() + raise ValueError("update_cycle_loan_event integrity error") from e + session.refresh(event) + return event + + +def create_cycle_daily_accrual( + session: Session, cycle_id: int, accrual_date: date, accrual_amount_cents: int +) -> models.CycleDailyAccrual: cycle = session.get(models.Cycles, cycle_id) if cycle is None: raise ValueError("cycle_id does not exist") @@ -370,7 +438,9 @@ def create_cycle_daily_accrual(session: Session, cycle_id: int, accrual_date: da return row -def get_cycle_daily_accruals_by_cycle_id(session: Session, cycle_id: int) -> list[models.CycleDailyAccrual]: +def get_cycle_daily_accruals_by_cycle_id( + session: Session, cycle_id: int +) -> list[models.CycleDailyAccrual]: date_col = cast("ColumnElement", models.CycleDailyAccrual.accrual_date) statement = ( select(models.CycleDailyAccrual) @@ -382,7 +452,9 @@ def get_cycle_daily_accruals_by_cycle_id(session: Session, cycle_id: int) -> lis return list(session.exec(statement).all()) -def get_cycle_daily_accrual_by_cycle_id_and_date(session: Session, cycle_id: int, accrual_date: date) -> models.CycleDailyAccrual | None: +def get_cycle_daily_accrual_by_cycle_id_and_date( + session: Session, cycle_id: int, accrual_date: date +) -> models.CycleDailyAccrual | None: statement = select(models.CycleDailyAccrual).where( models.CycleDailyAccrual.cycle_id == cycle_id, models.CycleDailyAccrual.accrual_date == accrual_date, @@ -394,7 +466,9 @@ def get_cycle_daily_accrual_by_cycle_id_and_date(session: Session, cycle_id: int IMMUTABLE_EXCHANGE_FIELDS = {"id"} -def create_exchange(session: Session, exchange_data: Mapping[str, Any] | BaseModel) -> models.Exchanges: +def create_exchange( + session: Session, exchange_data: Mapping[str, Any] | BaseModel +) -> models.Exchanges: data = _data_to_dict(exchange_data) allowed = _allowed_columns(models.Exchanges) payload = {k: v for k, v in data.items() if k in allowed} @@ -416,7 +490,9 @@ def get_exchange_by_id(session: Session, exchange_id: int) -> models.Exchanges | return session.get(models.Exchanges, exchange_id) -def get_exchange_by_name_and_user_id(session: Session, name: str, user_id: int) -> models.Exchanges | None: +def get_exchange_by_name_and_user_id( + session: Session, name: str, user_id: int +) -> models.Exchanges | None: statement = select(models.Exchanges).where( models.Exchanges.name == name, models.Exchanges.user_id == user_id, @@ -429,14 +505,18 @@ def get_all_exchanges(session: Session) -> list[models.Exchanges]: return list(session.exec(statement).all()) -def get_all_exchanges_by_user_id(session: Session, user_id: int) -> list[models.Exchanges]: +def get_all_exchanges_by_user_id( + session: Session, user_id: int +) -> list[models.Exchanges]: statement = select(models.Exchanges).where( models.Exchanges.user_id == user_id, ) return list(session.exec(statement).all()) -def update_exchange(session: Session, exchange_id: int, update_data: Mapping[str, Any] | BaseModel) -> models.Exchanges: +def update_exchange( + session: Session, exchange_id: int, update_data: Mapping[str, Any] | BaseModel +) -> models.Exchanges: exchange: models.Exchanges | None = session.get(models.Exchanges, exchange_id) if exchange is None: raise ValueError("exchange_id does not exist") @@ -473,7 +553,9 @@ def delete_exchange(session: Session, exchange_id: int) -> None: IMMUTABLE_USER_FIELDS = {"id", "username", "created_at"} -def create_user(session: Session, user_data: Mapping[str, Any] | BaseModel) -> models.Users: +def create_user( + session: Session, user_data: Mapping[str, Any] | BaseModel +) -> models.Users: data = _data_to_dict(user_data) allowed = _allowed_columns(models.Users) payload = {k: v for k, v in data.items() if k in allowed} @@ -504,7 +586,9 @@ def get_user_by_username(session: Session, username: str) -> models.Users | None return session.exec(statement).first() -def update_user(session: Session, user_id: int, update_data: Mapping[str, Any] | BaseModel) -> models.Users: +def update_user( + session: Session, user_id: int, update_data: Mapping[str, Any] | BaseModel +) -> models.Users: user: models.Users | None = session.get(models.Users, user_id) if user is None: raise ValueError("user_id does not exist") @@ -561,7 +645,9 @@ def create_login_session( return s -def get_login_session_by_token_hash_and_user_id(session: Session, session_token_hash: str, user_id: int) -> models.Sessions | None: +def get_login_session_by_token_hash_and_user_id( + session: Session, session_token_hash: str, user_id: int +) -> models.Sessions | None: statement = select(models.Sessions).where( models.Sessions.session_token_hash == session_token_hash, models.Sessions.user_id == user_id, @@ -571,7 +657,9 @@ def get_login_session_by_token_hash_and_user_id(session: Session, session_token_ return session.exec(statement).first() -def get_login_session_by_token_hash(session: Session, session_token_hash: str) -> models.Sessions | None: +def get_login_session_by_token_hash( + session: Session, session_token_hash: str +) -> models.Sessions | None: statement = select(models.Sessions).where( models.Sessions.session_token_hash == session_token_hash, models.Sessions.expires_at > datetime.now(timezone.utc), @@ -583,7 +671,11 @@ def get_login_session_by_token_hash(session: Session, session_token_hash: str) - IMMUTABLE_SESSION_FIELDS = {"id", "user_id", "session_token_hash", "created_at"} -def update_login_session(session: Session, session_token_hashed: str, update_session: Mapping[str, Any] | BaseModel) -> models.Sessions | None: +def update_login_session( + session: Session, + session_token_hashed: str, + update_session: Mapping[str, Any] | BaseModel, +) -> models.Sessions | None: login_session: models.Sessions | None = session.exec( select(models.Sessions).where( models.Sessions.session_token_hash == session_token_hashed, diff --git a/backend/trading_journal/dto.py b/backend/trading_journal/dto.py index 0dd227f..3557108 100644 --- a/backend/trading_journal/dto.py +++ b/backend/trading_journal/dto.py @@ -1,11 +1,15 @@ from __future__ import annotations -from datetime import date, datetime # noqa: TC003 +from datetime import date, datetime from pydantic import BaseModel from sqlmodel import SQLModel -from trading_journal.models import TradeStrategy, TradeType, UnderlyingCurrency # noqa: TC001 +from trading_journal.models import ( + TradeStrategy, + TradeType, + UnderlyingCurrency, +) class UserBase(SQLModel): @@ -90,6 +94,32 @@ class CycleRead(CycleCreate): id: int +class CycleLoanChangeEventBase(SQLModel): + cycle_id: int + effective_date: date + loan_amount_cents: int | None = None + loan_interest_rate_tenth_bps: int | None = None + related_trade_id: int | None = None + notes: str | None = None + created_at: datetime + + +class CycleLoanChangeEventCreate(CycleLoanChangeEventBase): + created_at: datetime + + +class CycleLoanChangeEventRead(CycleLoanChangeEventBase): + id: int + created_at: datetime + + +class CycleInterestAccrualBase(SQLModel): + cycle_id: int + accrual_date: date + accrual_amount_cents: int + created_at: datetime + + class TradeBase(SQLModel): friendly_name: str | None = None symbol: str diff --git a/backend/trading_journal/models.py b/backend/trading_journal/models.py index f060c26..41900f8 100644 --- a/backend/trading_journal/models.py +++ b/backend/trading_journal/models.py @@ -69,33 +69,51 @@ class FundingSource(str, Enum): class Trades(SQLModel, table=True): __tablename__ = "trades" # type: ignore[attr-defined] - __table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_trades_user_friendly_name"),) + __table_args__ = ( + UniqueConstraint( + "user_id", "friendly_name", name="uq_trades_user_friendly_name" + ), + ) id: int | None = Field(default=None, primary_key=True) user_id: int = Field(foreign_key="users.id", nullable=False, index=True) # allow null while user may omit friendly_name; uniqueness enforced per-user by constraint - friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) + friendly_name: str | None = Field( + default=None, sa_column=Column(Text, nullable=True) + ) symbol: str = Field(sa_column=Column(Text, nullable=False)) exchange_id: int = Field(foreign_key="exchanges.id", nullable=False, index=True) exchange: "Exchanges" = Relationship(back_populates="trades") - underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False)) + underlying_currency: UnderlyingCurrency = Field( + sa_column=Column(Text, nullable=False) + ) trade_type: TradeType = Field(sa_column=Column(Text, nullable=False)) trade_strategy: TradeStrategy = Field(sa_column=Column(Text, nullable=False)) trade_date: date = Field(sa_column=Column(Date, nullable=False)) - trade_time_utc: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False)) + trade_time_utc: datetime = Field( + sa_column=Column(DateTime(timezone=True), nullable=False) + ) expiry_date: date | None = Field(default=None, nullable=True) strike_price_cents: int | None = Field(default=None, nullable=True) quantity: int = Field(sa_column=Column(Integer, nullable=False)) - quantity_multiplier: int = Field(sa_column=Column(Integer, nullable=False), default=1) + quantity_multiplier: int = Field( + sa_column=Column(Integer, nullable=False), default=1 + ) price_cents: int = Field(sa_column=Column(Integer, nullable=False)) gross_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False)) commission_cents: int = Field(sa_column=Column(Integer, nullable=False)) net_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False)) is_invalidated: bool = Field(default=False, nullable=False) - invalidated_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True), nullable=True)) - replaced_by_trade_id: int | None = Field(default=None, foreign_key="trades.id", nullable=True) + invalidated_at: datetime | None = Field( + default=None, sa_column=Column(DateTime(timezone=True), nullable=True) + ) + replaced_by_trade_id: int | None = Field( + default=None, foreign_key="trades.id", nullable=True + ) notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) - cycle_id: int | None = Field(default=None, foreign_key="cycles.id", nullable=True, index=True) + cycle_id: int | None = Field( + default=None, foreign_key="cycles.id", nullable=True, index=True + ) cycle: "Cycles" = Relationship(back_populates="trades") @@ -107,15 +125,23 @@ class Trades(SQLModel, table=True): class Cycles(SQLModel, table=True): __tablename__ = "cycles" # type: ignore[attr-defined] - __table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_cycles_user_friendly_name"),) + __table_args__ = ( + UniqueConstraint( + "user_id", "friendly_name", name="uq_cycles_user_friendly_name" + ), + ) id: int | None = Field(default=None, primary_key=True) user_id: int = Field(foreign_key="users.id", nullable=False, index=True) - friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) + friendly_name: str | None = Field( + default=None, sa_column=Column(Text, nullable=True) + ) symbol: str = Field(sa_column=Column(Text, nullable=False)) exchange_id: int = Field(foreign_key="exchanges.id", nullable=False, index=True) exchange: "Exchanges" = Relationship(back_populates="cycles") - underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False)) + underlying_currency: UnderlyingCurrency = Field( + sa_column=Column(Text, nullable=False) + ) status: CycleStatus = Field(sa_column=Column(Text, nullable=False)) funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True)) capital_exposure_cents: int | None = Field(default=None, nullable=True) @@ -127,23 +153,51 @@ class Cycles(SQLModel, table=True): loan_amount_cents: int | None = Field(default=None, nullable=True) loan_interest_rate_tenth_bps: int | None = Field(default=None, nullable=True) - latest_interest_accrued_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True)) - total_accrued_amount_cents: int = Field(default=0, sa_column=Column(Integer, nullable=False)) + latest_interest_accrued_date: date | None = Field( + default=None, sa_column=Column(Date, nullable=True) + ) + total_accrued_amount_cents: int = Field( + default=0, sa_column=Column(Integer, nullable=False) + ) - loan_change_events: list["CycleLoanChangeEvents"] = Relationship(back_populates="cycle") + loan_change_events: list["CycleLoanChangeEvents"] = Relationship( + back_populates="cycle" + ) daily_accruals: list["CycleDailyAccrual"] = Relationship(back_populates="cycle") class CycleLoanChangeEvents(SQLModel, table=True): __tablename__ = "cycle_loan_change_events" # type: ignore[attr-defined] + __table_args__ = ( + UniqueConstraint( + "cycle_id", "effective_date", name="uq_cycle_loan_change_cycle_date" + ), + ) + id: int | None = Field(default=None, primary_key=True) - cycle_id: int = Field(sa_column=Column(Integer, ForeignKey("cycles.id", ondelete="CASCADE"), nullable=False, index=True)) + cycle_id: int = Field( + sa_column=Column( + Integer, + ForeignKey("cycles.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + ) effective_date: date = Field(sa_column=Column(Date, nullable=False)) - loan_amount_cents: int | None = Field(default=None, sa_column=Column(Integer, nullable=True)) - loan_interest_rate_tenth_bps: int | None = Field(default=None, sa_column=Column(Integer, nullable=True)) - related_trade_id: int | None = Field(default=None, sa_column=Column(Integer, ForeignKey("trades.id"), nullable=True, unique=True)) + loan_amount_cents: int | None = Field( + default=None, sa_column=Column(Integer, nullable=True) + ) + loan_interest_rate_tenth_bps: int | None = Field( + default=None, sa_column=Column(Integer, nullable=True) + ) + related_trade_id: int | None = Field( + default=None, + sa_column=Column(Integer, ForeignKey("trades.id"), nullable=True, unique=True), + ) # Not used for now. notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) - created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False)) + created_at: datetime = Field( + sa_column=Column(DateTime(timezone=True), nullable=False) + ) cycle: "Cycles" = Relationship(back_populates="loan_change_events") trade: Optional["Trades"] = Relationship(back_populates="related_loan_change_event") @@ -151,20 +205,35 @@ class CycleLoanChangeEvents(SQLModel, table=True): class CycleDailyAccrual(SQLModel, table=True): __tablename__ = "cycle_daily_accrual" # type: ignore[attr-defined] - __table_args__ = (UniqueConstraint("cycle_id", "accrual_date", name="uq_cycle_daily_accruals_cycle_date"),) + __table_args__ = ( + UniqueConstraint( + "cycle_id", "accrual_date", name="uq_cycle_daily_accruals_cycle_date" + ), + ) id: int | None = Field(default=None, primary_key=True) - cycle_id: int = Field(sa_column=Column(Integer, ForeignKey("cycles.id", ondelete="CASCADE"), nullable=False, index=True)) + cycle_id: int = Field( + sa_column=Column( + Integer, + ForeignKey("cycles.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + ) accrual_date: date = Field(sa_column=Column(Date, nullable=False)) accrual_amount_cents: int = Field(sa_column=Column(Integer, nullable=False)) - created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False)) + created_at: datetime = Field( + sa_column=Column(DateTime(timezone=True), nullable=False) + ) cycle: "Cycles" = Relationship(back_populates="daily_accruals") class Exchanges(SQLModel, table=True): __tablename__ = "exchanges" # type: ignore[attr-defined] - __table_args__ = (UniqueConstraint("user_id", "name", name="uq_exchanges_user_name"),) + __table_args__ = ( + UniqueConstraint("user_id", "name", name="uq_exchanges_user_name"), + ) id: int | None = Field(default=None, primary_key=True) user_id: int = Field(foreign_key="users.id", nullable=False, index=True) name: str = Field(sa_column=Column(Text, nullable=False)) @@ -190,10 +259,18 @@ class Sessions(SQLModel, table=True): id: int | None = Field(default=None, primary_key=True) user_id: int = Field(foreign_key="users.id", nullable=False, index=True) session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True)) - created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False)) - expires_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False, index=True)) - last_seen_at: datetime | None = Field(sa_column=Column(DateTime(timezone=True), nullable=True)) - last_used_ip: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) + created_at: datetime = Field( + sa_column=Column(DateTime(timezone=True), nullable=False) + ) + expires_at: datetime = Field( + sa_column=Column(DateTime(timezone=True), nullable=False, index=True) + ) + last_seen_at: datetime | None = Field( + sa_column=Column(DateTime(timezone=True), nullable=True) + ) + last_used_ip: str | None = Field( + default=None, sa_column=Column(Text, nullable=True) + ) user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) user: "Users" = Relationship(back_populates="sessions") diff --git a/backend/trading_journal/models_v1.py b/backend/trading_journal/models_v1.py index f060c26..41900f8 100644 --- a/backend/trading_journal/models_v1.py +++ b/backend/trading_journal/models_v1.py @@ -69,33 +69,51 @@ class FundingSource(str, Enum): class Trades(SQLModel, table=True): __tablename__ = "trades" # type: ignore[attr-defined] - __table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_trades_user_friendly_name"),) + __table_args__ = ( + UniqueConstraint( + "user_id", "friendly_name", name="uq_trades_user_friendly_name" + ), + ) id: int | None = Field(default=None, primary_key=True) user_id: int = Field(foreign_key="users.id", nullable=False, index=True) # allow null while user may omit friendly_name; uniqueness enforced per-user by constraint - friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) + friendly_name: str | None = Field( + default=None, sa_column=Column(Text, nullable=True) + ) symbol: str = Field(sa_column=Column(Text, nullable=False)) exchange_id: int = Field(foreign_key="exchanges.id", nullable=False, index=True) exchange: "Exchanges" = Relationship(back_populates="trades") - underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False)) + underlying_currency: UnderlyingCurrency = Field( + sa_column=Column(Text, nullable=False) + ) trade_type: TradeType = Field(sa_column=Column(Text, nullable=False)) trade_strategy: TradeStrategy = Field(sa_column=Column(Text, nullable=False)) trade_date: date = Field(sa_column=Column(Date, nullable=False)) - trade_time_utc: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False)) + trade_time_utc: datetime = Field( + sa_column=Column(DateTime(timezone=True), nullable=False) + ) expiry_date: date | None = Field(default=None, nullable=True) strike_price_cents: int | None = Field(default=None, nullable=True) quantity: int = Field(sa_column=Column(Integer, nullable=False)) - quantity_multiplier: int = Field(sa_column=Column(Integer, nullable=False), default=1) + quantity_multiplier: int = Field( + sa_column=Column(Integer, nullable=False), default=1 + ) price_cents: int = Field(sa_column=Column(Integer, nullable=False)) gross_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False)) commission_cents: int = Field(sa_column=Column(Integer, nullable=False)) net_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False)) is_invalidated: bool = Field(default=False, nullable=False) - invalidated_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True), nullable=True)) - replaced_by_trade_id: int | None = Field(default=None, foreign_key="trades.id", nullable=True) + invalidated_at: datetime | None = Field( + default=None, sa_column=Column(DateTime(timezone=True), nullable=True) + ) + replaced_by_trade_id: int | None = Field( + default=None, foreign_key="trades.id", nullable=True + ) notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) - cycle_id: int | None = Field(default=None, foreign_key="cycles.id", nullable=True, index=True) + cycle_id: int | None = Field( + default=None, foreign_key="cycles.id", nullable=True, index=True + ) cycle: "Cycles" = Relationship(back_populates="trades") @@ -107,15 +125,23 @@ class Trades(SQLModel, table=True): class Cycles(SQLModel, table=True): __tablename__ = "cycles" # type: ignore[attr-defined] - __table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_cycles_user_friendly_name"),) + __table_args__ = ( + UniqueConstraint( + "user_id", "friendly_name", name="uq_cycles_user_friendly_name" + ), + ) id: int | None = Field(default=None, primary_key=True) user_id: int = Field(foreign_key="users.id", nullable=False, index=True) - friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) + friendly_name: str | None = Field( + default=None, sa_column=Column(Text, nullable=True) + ) symbol: str = Field(sa_column=Column(Text, nullable=False)) exchange_id: int = Field(foreign_key="exchanges.id", nullable=False, index=True) exchange: "Exchanges" = Relationship(back_populates="cycles") - underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False)) + underlying_currency: UnderlyingCurrency = Field( + sa_column=Column(Text, nullable=False) + ) status: CycleStatus = Field(sa_column=Column(Text, nullable=False)) funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True)) capital_exposure_cents: int | None = Field(default=None, nullable=True) @@ -127,23 +153,51 @@ class Cycles(SQLModel, table=True): loan_amount_cents: int | None = Field(default=None, nullable=True) loan_interest_rate_tenth_bps: int | None = Field(default=None, nullable=True) - latest_interest_accrued_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True)) - total_accrued_amount_cents: int = Field(default=0, sa_column=Column(Integer, nullable=False)) + latest_interest_accrued_date: date | None = Field( + default=None, sa_column=Column(Date, nullable=True) + ) + total_accrued_amount_cents: int = Field( + default=0, sa_column=Column(Integer, nullable=False) + ) - loan_change_events: list["CycleLoanChangeEvents"] = Relationship(back_populates="cycle") + loan_change_events: list["CycleLoanChangeEvents"] = Relationship( + back_populates="cycle" + ) daily_accruals: list["CycleDailyAccrual"] = Relationship(back_populates="cycle") class CycleLoanChangeEvents(SQLModel, table=True): __tablename__ = "cycle_loan_change_events" # type: ignore[attr-defined] + __table_args__ = ( + UniqueConstraint( + "cycle_id", "effective_date", name="uq_cycle_loan_change_cycle_date" + ), + ) + id: int | None = Field(default=None, primary_key=True) - cycle_id: int = Field(sa_column=Column(Integer, ForeignKey("cycles.id", ondelete="CASCADE"), nullable=False, index=True)) + cycle_id: int = Field( + sa_column=Column( + Integer, + ForeignKey("cycles.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + ) effective_date: date = Field(sa_column=Column(Date, nullable=False)) - loan_amount_cents: int | None = Field(default=None, sa_column=Column(Integer, nullable=True)) - loan_interest_rate_tenth_bps: int | None = Field(default=None, sa_column=Column(Integer, nullable=True)) - related_trade_id: int | None = Field(default=None, sa_column=Column(Integer, ForeignKey("trades.id"), nullable=True, unique=True)) + loan_amount_cents: int | None = Field( + default=None, sa_column=Column(Integer, nullable=True) + ) + loan_interest_rate_tenth_bps: int | None = Field( + default=None, sa_column=Column(Integer, nullable=True) + ) + related_trade_id: int | None = Field( + default=None, + sa_column=Column(Integer, ForeignKey("trades.id"), nullable=True, unique=True), + ) # Not used for now. notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) - created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False)) + created_at: datetime = Field( + sa_column=Column(DateTime(timezone=True), nullable=False) + ) cycle: "Cycles" = Relationship(back_populates="loan_change_events") trade: Optional["Trades"] = Relationship(back_populates="related_loan_change_event") @@ -151,20 +205,35 @@ class CycleLoanChangeEvents(SQLModel, table=True): class CycleDailyAccrual(SQLModel, table=True): __tablename__ = "cycle_daily_accrual" # type: ignore[attr-defined] - __table_args__ = (UniqueConstraint("cycle_id", "accrual_date", name="uq_cycle_daily_accruals_cycle_date"),) + __table_args__ = ( + UniqueConstraint( + "cycle_id", "accrual_date", name="uq_cycle_daily_accruals_cycle_date" + ), + ) id: int | None = Field(default=None, primary_key=True) - cycle_id: int = Field(sa_column=Column(Integer, ForeignKey("cycles.id", ondelete="CASCADE"), nullable=False, index=True)) + cycle_id: int = Field( + sa_column=Column( + Integer, + ForeignKey("cycles.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + ) accrual_date: date = Field(sa_column=Column(Date, nullable=False)) accrual_amount_cents: int = Field(sa_column=Column(Integer, nullable=False)) - created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False)) + created_at: datetime = Field( + sa_column=Column(DateTime(timezone=True), nullable=False) + ) cycle: "Cycles" = Relationship(back_populates="daily_accruals") class Exchanges(SQLModel, table=True): __tablename__ = "exchanges" # type: ignore[attr-defined] - __table_args__ = (UniqueConstraint("user_id", "name", name="uq_exchanges_user_name"),) + __table_args__ = ( + UniqueConstraint("user_id", "name", name="uq_exchanges_user_name"), + ) id: int | None = Field(default=None, primary_key=True) user_id: int = Field(foreign_key="users.id", nullable=False, index=True) name: str = Field(sa_column=Column(Text, nullable=False)) @@ -190,10 +259,18 @@ class Sessions(SQLModel, table=True): id: int | None = Field(default=None, primary_key=True) user_id: int = Field(foreign_key="users.id", nullable=False, index=True) session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True)) - created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False)) - expires_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False, index=True)) - last_seen_at: datetime | None = Field(sa_column=Column(DateTime(timezone=True), nullable=True)) - last_used_ip: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) + created_at: datetime = Field( + sa_column=Column(DateTime(timezone=True), nullable=False) + ) + expires_at: datetime = Field( + sa_column=Column(DateTime(timezone=True), nullable=False, index=True) + ) + last_seen_at: datetime | None = Field( + sa_column=Column(DateTime(timezone=True), nullable=True) + ) + last_used_ip: str | None = Field( + default=None, sa_column=Column(Text, nullable=True) + ) user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) user: "Users" = Relationship(back_populates="sessions") diff --git a/backend/trading_journal/service.py b/backend/trading_journal/service.py index 79d487e..c83bfb7 100644 --- a/backend/trading_journal/service.py +++ b/backend/trading_journal/service.py @@ -13,6 +13,7 @@ from trading_journal import crud, security from trading_journal.dto import ( CycleBase, CycleCreate, + CycleLoanChangeEventBase, CycleRead, CycleUpdate, ExchangesBase, @@ -26,6 +27,17 @@ from trading_journal.dto import ( UserLogin, UserRead, ) +from trading_journal.service_error import ( + CycleLoanEventExistsError, + CycleNotFoundError, + ExchangeAlreadyExistsError, + ExchangeNotFoundError, + InvalidCycleDataError, + InvalidTradeDataError, + ServiceError, + TradeNotFoundError, + UserAlreadyExistsError, +) if TYPE_CHECKING: from sqlmodel import Session @@ -44,7 +56,9 @@ logger = logging.getLogger(__name__) class AuthMiddleWare(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: # noqa: PLR0911 + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: # noqa: PLR0911 if request.url.path in EXCEPT_PATHS: return await call_next(request) @@ -62,22 +76,44 @@ class AuthMiddleWare(BaseHTTPMiddleware): db_factory: Database | None = getattr(request.app.state, "db_factory", None) if db_factory is None: - return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "db factory not configured"}) + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={"detail": "db factory not configured"}, + ) try: with db_factory.get_session_ctx_manager() as request_session: hashed_token = security.hash_session_token_sha256(token) request.state.db_session = request_session - login_session: Sessions | None = crud.get_login_session_by_token_hash(request_session, hashed_token) + login_session: Sessions | None = crud.get_login_session_by_token_hash( + request_session, hashed_token + ) if not login_session: - return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"}) - session_expires_utc = login_session.expires_at.replace(tzinfo=timezone.utc) + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"detail": "Unauthorized"}, + ) + session_expires_utc = login_session.expires_at.replace( + tzinfo=timezone.utc + ) if session_expires_utc < datetime.now(timezone.utc): - crud.delete_login_session(request_session, login_session.session_token_hash) - return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"}) + crud.delete_login_session( + request_session, login_session.session_token_hash + ) + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"detail": "Unauthorized"}, + ) if login_session.user.is_active is False: - return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"}) - if session_expires_utc - datetime.now(timezone.utc) < timedelta(seconds=3600): - updated_expiry = datetime.now(timezone.utc) + timedelta(seconds=settings.settings.session_expiry_seconds) + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"detail": "Unauthorized"}, + ) + if session_expires_utc - datetime.now(timezone.utc) < timedelta( + seconds=3600 + ): + updated_expiry = datetime.now(timezone.utc) + timedelta( + seconds=settings.settings.session_expiry_seconds + ) else: updated_expiry = session_expires_utc updated_session: SessionsUpdate = SessionsUpdate( @@ -88,46 +124,19 @@ class AuthMiddleWare(BaseHTTPMiddleware): ) user_id = login_session.user_id request.state.user_id = user_id - crud.update_login_session(request_session, hashed_token, update_session=updated_session) + crud.update_login_session( + request_session, hashed_token, update_session=updated_session + ) except Exception: logger.exception("Failed to authenticate user: \n") - return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "Internal server error"}) + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={"detail": "Internal server error"}, + ) return await call_next(request) -class ServiceError(Exception): - pass - - -class UserAlreadyExistsError(ServiceError): - pass - - -class ExchangeAlreadyExistsError(ServiceError): - pass - - -class ExchangeNotFoundError(ServiceError): - pass - - -class CycleNotFoundError(ServiceError): - pass - - -class TradeNotFoundError(ServiceError): - pass - - -class InvalidTradeDataError(ServiceError): - pass - - -class InvalidCycleDataError(ServiceError): - pass - - # User service def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead: if crud.get_user_by_username(db_session, user_in.username): @@ -151,7 +160,9 @@ def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead: return user -def authenticate_user_service(db_session: Session, user_in: UserLogin) -> tuple[SessionsCreate, str] | None: +def authenticate_user_service( + db_session: Session, user_in: UserLogin +) -> tuple[SessionsCreate, str] | None: user = crud.get_user_by_username(db_session, user_in.username) if not user: return None @@ -176,10 +187,14 @@ def authenticate_user_service(db_session: Session, user_in: UserLogin) -> tuple[ # Exchanges service -def create_exchange_service(db_session: Session, user_id: int, name: str, notes: str | None) -> ExchangesCreate: +def create_exchange_service( + db_session: Session, user_id: int, name: str, notes: str | None +) -> ExchangesCreate: existing_exchange = crud.get_exchange_by_name_and_user_id(db_session, name, user_id) if existing_exchange: - raise ExchangeAlreadyExistsError("Exchange with the same name already exists for this user") + raise ExchangeAlreadyExistsError( + "Exchange with the same name already exists for this user" + ) exchange_data = ExchangesCreate( user_id=user_id, name=name, @@ -198,12 +213,20 @@ def create_exchange_service(db_session: Session, user_id: int, name: str, notes: return exchange_dto -def get_exchanges_by_user_service(db_session: Session, user_id: int) -> list[ExchangesRead]: +def get_exchanges_by_user_service( + db_session: Session, user_id: int +) -> list[ExchangesRead]: exchanges = crud.get_all_exchanges_by_user_id(db_session, user_id) return [ExchangesRead.model_validate(exchange) for exchange in exchanges] -def update_exchanges_service(db_session: Session, user_id: int, exchange_id: int, name: str | None, notes: str | None) -> ExchangesBase: +def update_exchanges_service( + db_session: Session, + user_id: int, + exchange_id: int, + name: str | None, + notes: str | None, +) -> ExchangesBase: existing_exchange = crud.get_exchange_by_id(db_session, exchange_id) if not existing_exchange: raise ExchangeNotFoundError("Exchange not found") @@ -211,16 +234,22 @@ def update_exchanges_service(db_session: Session, user_id: int, exchange_id: int raise ExchangeNotFoundError("Exchange not found") if name: - other_exchange = crud.get_exchange_by_name_and_user_id(db_session, name, user_id) + other_exchange = crud.get_exchange_by_name_and_user_id( + db_session, name, user_id + ) if other_exchange and other_exchange.id != existing_exchange.id: - raise ExchangeAlreadyExistsError("Another exchange with the same name already exists for this user") + raise ExchangeAlreadyExistsError( + "Another exchange with the same name already exists for this user" + ) exchange_data = ExchangesBase( name=name or existing_exchange.name, notes=notes or existing_exchange.notes, ) try: - exchange = crud.update_exchange(db_session, cast("int", existing_exchange.id), update_data=exchange_data) + exchange = crud.update_exchange( + db_session, cast("int", existing_exchange.id), update_data=exchange_data + ) except Exception as e: logger.exception("Failed to update exchange: \n") raise ServiceError("Failed to update exchange") from e @@ -228,7 +257,9 @@ def update_exchanges_service(db_session: Session, user_id: int, exchange_id: int # Cycle Service -def create_cycle_service(db_session: Session, user_id: int, cycle_data: CycleBase) -> CycleRead: +def create_cycle_service( + db_session: Session, user_id: int, cycle_data: CycleBase +) -> CycleRead: raise NotImplementedError("Cycle creation not implemented") cycle_data_dict = cycle_data.model_dump() cycle_data_dict["user_id"] = user_id @@ -237,7 +268,9 @@ def create_cycle_service(db_session: Session, user_id: int, cycle_data: CycleBas return CycleRead.model_validate(created_cycle) -def get_cycle_by_id_service(db_session: Session, user_id: int, cycle_id: int) -> CycleRead: +def get_cycle_by_id_service( + db_session: Session, user_id: int, cycle_id: int +) -> CycleRead: cycle = crud.get_cycle_by_id(db_session, cycle_id) if not cycle: raise CycleNotFoundError("Cycle not found") @@ -256,22 +289,65 @@ def _validate_cycle_update_data(cycle_data: CycleUpdate) -> tuple[bool, str]: # return False, "end_date is required when status is CLOSED" if cycle_data.status == "OPEN" and cycle_data.end_date is not None: return False, "end_date must be empty when status is OPEN" - if cycle_data.capital_exposure_cents is not None and cycle_data.capital_exposure_cents < 0: + if ( + cycle_data.capital_exposure_cents is not None + and cycle_data.capital_exposure_cents < 0 + ): return False, "capital_exposure_cents must be non-negative" if ( cycle_data.funding_source is not None and cycle_data.funding_source != "CASH" - and (cycle_data.loan_amount_cents is None or cycle_data.loan_interest_rate_tenth_bps is None) + and ( + cycle_data.loan_amount_cents is None + or cycle_data.loan_interest_rate_tenth_bps is None + ) ): - return False, "loan_amount_cents and loan_interest_rate_tenth_bps are required when funding_source is not CASH" + return ( + False, + "loan_amount_cents and loan_interest_rate_tenth_bps are required when funding_source is not CASH", + ) if cycle_data.loan_amount_cents is not None and cycle_data.loan_amount_cents < 0: return False, "loan_amount_cents must be non-negative" - if cycle_data.loan_interest_rate_tenth_bps is not None and cycle_data.loan_interest_rate_tenth_bps < 0: + if ( + cycle_data.loan_interest_rate_tenth_bps is not None + and cycle_data.loan_interest_rate_tenth_bps < 0 + ): return False, "loan_interest_rate_tenth_bps must be non-negative" return True, "" -def update_cycle_service(db_session: Session, user_id: int, cycle_data: CycleUpdate) -> CycleRead: +def _create_cycle_loan_event( + db_session: Session, + cycle_id: int, + loan_amount_cents: int | None, + loan_interest_rate_tenth_bps: int | None, +) -> None: + now = datetime.now(timezone.utc) + today = now.date() + existing_loan_event = crud.get_loan_event_by_cycle_id_and_effective_date( + db_session, cycle_id, today + ) + if existing_loan_event: + raise CycleLoanEventExistsError( + "A loan event with the same effective_date already exists for this cycle." + ) + loan_event_data = CycleLoanChangeEventBase( + cycle_id=cycle_id, + effective_date=today, + loan_amount_cents=loan_amount_cents, + loan_interest_rate_tenth_bps=loan_interest_rate_tenth_bps, + created_at=now, + ) + try: + crud.create_cycle_loan_event(db_session, loan_event_data) + except Exception as e: + logger.exception("Failed to create cycle loan event: \n") + raise ServiceError("Failed to create cycle loan event") from e + + +def update_cycle_service( + db_session: Session, user_id: int, cycle_data: CycleUpdate +) -> CycleRead: is_valid, err_msg = _validate_cycle_update_data(cycle_data) if not is_valid: raise InvalidCycleDataError(err_msg) @@ -281,22 +357,68 @@ def update_cycle_service(db_session: Session, user_id: int, cycle_data: CycleUpd raise CycleNotFoundError("Cycle not found") if existing_cycle.user_id != user_id: raise CycleNotFoundError("Cycle not found") + if ( + cycle_data.loan_amount_cents is not None + or cycle_data.loan_interest_rate_tenth_bps is not None + ): + _create_cycle_loan_event( + db_session, + cycle_id, + cycle_data.loan_amount_cents, + cycle_data.loan_interest_rate_tenth_bps, + ) provided_data_dict = cycle_data.model_dump(exclude_unset=True) cycle_data_with_user_id: CycleBase = CycleBase.model_validate(provided_data_dict) try: - updated_cycle = crud.update_cycle(db_session, cycle_id, update_data=cycle_data_with_user_id) + updated_cycle = crud.update_cycle( + db_session, cycle_id, update_data=cycle_data_with_user_id + ) except Exception as e: logger.exception("Failed to update cycle: \n") raise ServiceError("Failed to update cycle") from e return CycleRead.model_validate(updated_cycle) +def accural_interest_service(db_session: Session, cycle_id: int) -> None: + cycle = crud.get_cycle_by_id(db_session, cycle_id) + if not cycle: + logger.exception("Cycle not found for interest accrual") + raise CycleNotFoundError("Cycle not found") + if cycle.loan_amount_cents is None or cycle.loan_interest_rate_tenth_bps is None: + logger.info("Cycle has no loan, skipping interest accrual") + return + today = datetime.now(timezone.utc).date() + amount_cents = round( + cycle.loan_amount_cents * cycle.loan_interest_rate_tenth_bps / 100000 / 365 + ) + try: + crud.create_cycle_daily_accrual( + db_session, + cycle_id=cycle_id, + accrual_date=today, + accrual_amount_cents=amount_cents, + ) + except Exception as e: + logger.exception("Failed to create cycle interest accrual: \n") + raise ServiceError("Failed to create cycle interest accrual") from e + + +def flush_interest_accruals_service(db_session: Session) -> None: + pass + + # Trades service def _append_cashflows(trade_data: TradeCreate) -> TradeCreate: sign_multipler: int - if trade_data.trade_type in ("SELL_PUT", "SELL_CALL", "EXERCISE_CALL", "CLOSE_LONG_SPOT", "SHORT_SPOT"): + if trade_data.trade_type in ( + "SELL_PUT", + "SELL_CALL", + "EXERCISE_CALL", + "CLOSE_LONG_SPOT", + "SHORT_SPOT", + ): sign_multipler = 1 else: sign_multipler = -1 @@ -310,13 +432,18 @@ def _append_cashflows(trade_data: TradeCreate) -> TradeCreate: def _validate_trade_data(trade_data: TradeCreate) -> bool: return not ( - trade_data.trade_type in ("SELL_PUT", "SELL_CALL") and (trade_data.expiry_date is None or trade_data.strike_price_cents is None) + trade_data.trade_type in ("SELL_PUT", "SELL_CALL") + and (trade_data.expiry_date is None or trade_data.strike_price_cents is None) ) -def create_trade_service(db_session: Session, user_id: int, trade_data: TradeCreate) -> TradeRead: +def create_trade_service( + db_session: Session, user_id: int, trade_data: TradeCreate +) -> TradeRead: if not _validate_trade_data(trade_data): - raise InvalidTradeDataError("Invalid trade data: expiry_date and strike_price_cents are required for SELL_PUT and SELL_CALL trades") + raise InvalidTradeDataError( + "Invalid trade data: expiry_date and strike_price_cents are required for SELL_PUT and SELL_CALL trades" + ) trade_data_dict = trade_data.model_dump() trade_data_dict["user_id"] = user_id trade_data_with_user_id: TradeCreate = TradeCreate.model_validate(trade_data_dict) @@ -325,7 +452,9 @@ def create_trade_service(db_session: Session, user_id: int, trade_data: TradeCre return TradeRead.model_validate(created_trade) -def get_trade_by_id_service(db_session: Session, user_id: int, trade_id: int) -> TradeRead: +def get_trade_by_id_service( + db_session: Session, user_id: int, trade_id: int +) -> TradeRead: trade = crud.get_trade_by_id(db_session, trade_id) if not trade: raise TradeNotFoundError("Trade not found") @@ -334,21 +463,27 @@ def get_trade_by_id_service(db_session: Session, user_id: int, trade_id: int) -> return TradeRead.model_validate(trade) -def update_trade_friendly_name_service(db_session: Session, user_id: int, trade_id: int, friendly_name: str) -> TradeRead: +def update_trade_friendly_name_service( + db_session: Session, user_id: int, trade_id: int, friendly_name: str +) -> TradeRead: existing_trade = crud.get_trade_by_id(db_session, trade_id) if not existing_trade: raise TradeNotFoundError("Trade not found") if existing_trade.user_id != user_id: raise TradeNotFoundError("Trade not found") try: - updated_trade = crud.update_trade_friendly_name(db_session, trade_id, friendly_name) + updated_trade = crud.update_trade_friendly_name( + db_session, trade_id, friendly_name + ) except Exception as e: logger.exception("Failed to update trade friendly name: \n") raise ServiceError("Failed to update trade friendly name") from e return TradeRead.model_validate(updated_trade) -def update_trade_note_service(db_session: Session, user_id: int, trade_id: int, note: str | None) -> TradeRead: +def update_trade_note_service( + db_session: Session, user_id: int, trade_id: int, note: str | None +) -> TradeRead: existing_trade = crud.get_trade_by_id(db_session, trade_id) if not existing_trade: raise TradeNotFoundError("Trade not found") diff --git a/backend/trading_journal/service_error.py b/backend/trading_journal/service_error.py new file mode 100644 index 0000000..10f4553 --- /dev/null +++ b/backend/trading_journal/service_error.py @@ -0,0 +1,34 @@ +class ServiceError(Exception): + pass + + +class UserAlreadyExistsError(ServiceError): + pass + + +class ExchangeAlreadyExistsError(ServiceError): + pass + + +class ExchangeNotFoundError(ServiceError): + pass + + +class CycleNotFoundError(ServiceError): + pass + + +class TradeNotFoundError(ServiceError): + pass + + +class InvalidTradeDataError(ServiceError): + pass + + +class InvalidCycleDataError(ServiceError): + pass + + +class CycleLoanEventExistsError(ServiceError): + pass