diff --git a/backend/tests/test_service.py b/backend/tests/test_service.py index cf0fc8b..27ba8b3 100644 --- a/backend/tests/test_service.py +++ b/backend/tests/test_service.py @@ -25,15 +25,9 @@ class FakeDBFactory: yield fake_session -def verify_json_response( - response: Response, expected_status: int, expected_detail: str -) -> None: +def verify_json_response(response: Response, expected_status: int, expected_detail: str) -> None: assert response.status_code == expected_status - body_bytes = ( - response.body.tobytes() - if isinstance(response.body, memoryview) - else response.body - ) + body_bytes = response.body.tobytes() if isinstance(response.body, memoryview) else response.body body_text = body_bytes.decode("utf-8") body_json = json.loads(body_text) assert body_json.get("detail") == expected_detail @@ -99,9 +93,7 @@ def test_auth_middleware_no_db() -> None: pytest.fail("call_next should not be called for invalid token") response = asyncio.run(middleware.dispatch(request, call_next)) - verify_json_response( - response, status.HTTP_500_INTERNAL_SERVER_ERROR, "db factory not configured" - ) + verify_json_response(response, status.HTTP_500_INTERNAL_SERVER_ERROR, "db factory not configured") def test_auth_middleware_rejects_invalid_token() -> None: @@ -122,9 +114,7 @@ def test_auth_middleware_rejects_invalid_token() -> None: async def call_next(req: Request) -> Response: # noqa: ARG001 pytest.fail("call_next should not be called for invalid token") - with patch( - "trading_journal.crud.get_login_session_by_token_hash", return_value=None - ): + with patch("trading_journal.crud.get_login_session_by_token_hash", return_value=None): response = asyncio.run(middleware.dispatch(request, call_next)) verify_json_response(response, status.HTTP_401_UNAUTHORIZED, "Unauthorized") @@ -335,12 +325,8 @@ def test_auth_middleware_allows_valid_token_and_updates_expires() -> None: _, kwargs = mock_update.call_args update_session = kwargs.get("update_session") assert update_session is not None - assert ( - update_session.expires_at - datetime.now(timezone.utc) - ).total_seconds() > settings.session_expiry_seconds - 1 - assert ( - update_session.last_seen_at - datetime.now(timezone.utc) - ).total_seconds() < 1 + assert (update_session.expires_at - datetime.now(timezone.utc)).total_seconds() > settings.session_expiry_seconds - 1 + assert (update_session.last_seen_at - datetime.now(timezone.utc)).total_seconds() < 1 assert update_session.last_used_ip == "testclient" assert update_session.user_agent == "test-agent" @@ -354,14 +340,10 @@ def test_register_user_success() -> None: } with ( FakeDBFactory().get_session_ctx_manager() as db, - patch( - "trading_journal.crud.get_user_by_username", return_value=None - ) as mock_get, + patch("trading_journal.crud.get_user_by_username", return_value=None) as mock_get, patch( "trading_journal.crud.create_user", - return_value=SimpleNamespace( - id=1, username=user_in.username, is_active=True - ), + return_value=SimpleNamespace(id=1, username=user_in.username, is_active=True), ) as mock_create, patch( "trading_journal.security.hash_password", @@ -381,9 +363,7 @@ def test_register_user_exists_raises() -> None: FakeDBFactory().get_session_ctx_manager() as db, patch( "trading_journal.crud.get_user_by_username", - return_value=SimpleNamespace( - id=1, username=user_in.username, is_active=True - ), + return_value=SimpleNamespace(id=1, username=user_in.username, is_active=True), ) as mock_get, ): with pytest.raises(service.UserAlreadyExistsError) as exc_info: @@ -394,13 +374,10 @@ def test_register_user_exists_raises() -> None: def test_authenticate_user_success() -> None: user_in = dto.UserLogin(username="validuser", password="validpassword") - stored_user = SimpleNamespace( - id=1, username=user_in.username, is_active=True, password_hash="hashedpassword" - ) + stored_user = SimpleNamespace(id=1, username=user_in.username, is_active=True, password_hash="hashedpassword") expected_login_session = dto.SessionsCreate( user_id=stored_user.id, - expires_at=datetime.now(timezone.utc) - + timedelta(seconds=settings.session_expiry_seconds), + expires_at=datetime.now(timezone.utc) + timedelta(seconds=settings.session_expiry_seconds), ) with ( FakeDBFactory().get_session_ctx_manager() as db, @@ -408,9 +385,7 @@ def test_authenticate_user_success() -> None: "trading_journal.crud.get_user_by_username", return_value=stored_user, ) as mock_get, - patch( - "trading_journal.security.verify_password", return_value=True - ) as mock_verify, + patch("trading_journal.security.verify_password", return_value=True) as mock_verify, patch( "trading_journal.security.generate_session_token", return_value="newsessiontoken", @@ -421,9 +396,7 @@ def test_authenticate_user_success() -> None: ) as mock_hash_session_token, patch( "trading_journal.crud.create_login_session", - return_value=SimpleNamespace( - user_id=stored_user.id, expires_at=expected_login_session.expires_at - ), + return_value=SimpleNamespace(user_id=stored_user.id, expires_at=expected_login_session.expires_at), ) as mock_create_session, ): user_out = service.authenticate_user_service(db, user_in) @@ -432,14 +405,7 @@ def test_authenticate_user_success() -> None: # assert fields instead of direct equality to avoid pydantic/model issues assert getattr(login_session, "user_id", None) == stored_user.id assert isinstance(getattr(login_session, "expires_at", None), datetime) - assert ( - abs( - ( - login_session.expires_at - expected_login_session.expires_at - ).total_seconds() - ) - < 2 - ) + assert abs((login_session.expires_at - expected_login_session.expires_at).total_seconds()) < 2 assert token == "newsessiontoken" assert login_session.user_id == stored_user.id mock_get.assert_called_once_with(db, user_in.username) @@ -470,18 +436,14 @@ def test_authenticate_user_not_found_returns_none() -> None: def test_authenticate_user_invalid_password_returns_none() -> None: user_in = dto.UserLogin(username="validuser", password="invalidpassword") - stored_user = SimpleNamespace( - id=1, username=user_in.username, is_active=True, password_hash="hashedpassword" - ) + stored_user = SimpleNamespace(id=1, username=user_in.username, is_active=True, password_hash="hashedpassword") with ( FakeDBFactory().get_session_ctx_manager() as db, patch( "trading_journal.crud.get_user_by_username", return_value=stored_user, ) as mock_get, - patch( - "trading_journal.security.verify_password", return_value=False - ) as mock_verify, + patch("trading_journal.security.verify_password", return_value=False) as mock_verify, ): user_out = service.authenticate_user_service(db, user_in) assert user_out is None @@ -496,9 +458,7 @@ def test_create_exchange_duplicate_raises() -> None: FakeDBFactory().get_session_ctx_manager() as db, patch( "trading_journal.crud.get_exchange_by_name_and_user_id", - return_value=SimpleNamespace( - id=1, user_id=1, name=exchange_in.name, notes="Existing exchange" - ), + return_value=SimpleNamespace(id=1, user_id=1, name=exchange_in.name, notes="Existing exchange"), ) as mock_get, ): with pytest.raises(service.ExchangeAlreadyExistsError) as exc_info: @@ -508,10 +468,7 @@ def test_create_exchange_duplicate_raises() -> None: name=exchange_in.name, notes=exchange_in.notes, ) - assert ( - str(exc_info.value) - == "Exchange with the same name already exists for this user" - ) + assert str(exc_info.value) == "Exchange with the same name already exists for this user" mock_get.assert_called_once_with(db, exchange_in.name, exchange_in.user_id) @@ -552,9 +509,7 @@ def test_get_exchanges_by_user_id() -> None: "trading_journal.crud.get_all_exchanges_by_user_id", return_value=[ SimpleNamespace(id=1, user_id=1, name="NYSE", notes="First exchange"), - SimpleNamespace( - id=2, user_id=1, name="NASDAQ", notes="Second exchange" - ), + SimpleNamespace(id=2, user_id=1, name="NASDAQ", notes="Second exchange"), ], ) as mock_get, ): @@ -601,9 +556,7 @@ def test_update_exchange_not_found() -> None: def test_update_exchange_owner_mismatch_raises() -> None: exchange_update = dto.ExchangesBase(name="UpdatedName", notes="Updated notes") - existing_exchange = SimpleNamespace( - id=1, user_id=2, name="OldName", notes="Old notes" - ) + existing_exchange = SimpleNamespace(id=1, user_id=2, name="OldName", notes="Old notes") with ( FakeDBFactory().get_session_ctx_manager() as db, patch( @@ -625,9 +578,7 @@ def test_update_exchange_owner_mismatch_raises() -> None: def test_update_exchange_duplication() -> None: exchange_update = dto.ExchangesBase(name="DuplicateName", notes="Updated notes") - existing_exchange = SimpleNamespace( - id=1, user_id=1, name="OldName", notes="Old notes" - ) + existing_exchange = SimpleNamespace(id=1, user_id=1, name="OldName", notes="Old notes") with ( FakeDBFactory().get_session_ctx_manager() as db, patch( @@ -636,9 +587,7 @@ def test_update_exchange_duplication() -> None: ) as mock_get, patch( "trading_journal.crud.get_exchange_by_name_and_user_id", - return_value=SimpleNamespace( - id=2, user_id=1, name="DuplicateName", notes="Another exchange" - ), + return_value=SimpleNamespace(id=2, user_id=1, name="DuplicateName", notes="Another exchange"), ) as mock_get_by_name, ): with pytest.raises(service.ExchangeAlreadyExistsError) as exc_info: @@ -649,19 +598,14 @@ def test_update_exchange_duplication() -> None: name=exchange_update.name, notes=exchange_update.notes, ) - assert ( - str(exc_info.value) - == "Another exchange with the same name already exists for this user" - ) + assert str(exc_info.value) == "Another exchange with the same name already exists for this user" mock_get.assert_called_once_with(db, 1) mock_get_by_name.assert_called_once_with(db, "DuplicateName", 1) def test_update_exchange_success() -> None: exchange_update = dto.ExchangesBase(name="UpdatedName", notes="Updated notes") - existing_exchange = SimpleNamespace( - id=1, user_id=1, name="OldName", notes="Old notes" - ) + existing_exchange = SimpleNamespace(id=1, user_id=1, name="OldName", notes="Old notes") with ( FakeDBFactory().get_session_ctx_manager() as db, patch( @@ -674,9 +618,7 @@ def test_update_exchange_success() -> None: ) as mock_get_by_name, patch( "trading_journal.crud.update_exchange", - return_value=SimpleNamespace( - id=1, user_id=1, name=exchange_update.name, notes=exchange_update.notes - ), + return_value=SimpleNamespace(id=1, user_id=1, name=exchange_update.name, notes=exchange_update.notes), ) as mock_update, ): exchange_out = service.update_exchanges_service( @@ -741,9 +683,7 @@ def test_get_cycle_by_id_success() -> None: FakeDBFactory().get_session_ctx_manager() as db, patch("trading_journal.crud.get_cycle_by_id", return_value=cycle) as mock_get, ): - cycle_out = service.get_cycle_by_id_service( - db, user_id=user_id, cycle_id=cycle_id - ) + cycle_out = service.get_cycle_by_id_service(db, user_id=user_id, cycle_id=cycle_id) assert cycle_out.id == cycle_id assert cycle_out.user_id == user_id assert cycle_out.friendly_name == "Test Cycle" @@ -760,9 +700,7 @@ def test_get_cycles_by_user_no_cycles() -> None: user_id = 1 with ( FakeDBFactory().get_session_ctx_manager() as db, - patch( - "trading_journal.crud.get_cycles_by_user_id", return_value=[] - ) as mock_get, + patch("trading_journal.crud.get_cycles_by_user_id", return_value=[]) as mock_get, ): cycles = service.get_cycles_by_user_service(db, user_id=user_id) assert isinstance(cycles, list) @@ -800,9 +738,7 @@ def test_get_cycles_by_user_with_cycles() -> None: ) with ( FakeDBFactory().get_session_ctx_manager() as db, - patch( - "trading_journal.crud.get_cycles_by_user_id", return_value=[cycle1, cycle2] - ) as mock_get, + patch("trading_journal.crud.get_cycles_by_user_id", return_value=[cycle1, cycle2]) as mock_get, ): cycles = service.get_cycles_by_user_service(db, user_id=user_id) assert isinstance(cycles, list) @@ -840,9 +776,7 @@ def test_update_cycle_open_status_mismatch_raises() -> None: def test_update_cycle_invalid_capital_exposure_raises() -> None: - cycle_data = dto.CycleUpdate( - id=1, friendly_name="Updated Cycle", status="OPEN", capital_exposure_cents=-100 - ) + cycle_data = dto.CycleUpdate(id=1, friendly_name="Updated Cycle", status="OPEN", capital_exposure_cents=-100) with ( FakeDBFactory().get_session_ctx_manager() as db, ): @@ -864,10 +798,7 @@ def test_update_cycle_no_cash_no_loan_raises() -> None: ): with pytest.raises(service.InvalidCycleDataError) as exc_info: service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) - assert ( - str(exc_info.value) - == "loan_amount_cents and loan_interest_rate_tenth_bps are required when funding_source is not CASH" - ) + assert str(exc_info.value) == "loan_amount_cents and loan_interest_rate_tenth_bps are required when funding_source is not CASH" def test_update_cycle_loan_missing_interest_raises() -> None: @@ -883,10 +814,7 @@ def test_update_cycle_loan_missing_interest_raises() -> None: ): with pytest.raises(service.InvalidCycleDataError) as exc_info: service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) - assert ( - str(exc_info.value) - == "loan_amount_cents and loan_interest_rate_tenth_bps are required when funding_source is not CASH" - ) + assert str(exc_info.value) == "loan_amount_cents and loan_interest_rate_tenth_bps are required when funding_source is not CASH" def test_update_cycle_loan_negative_loan_raises() -> None: @@ -920,9 +848,7 @@ def test_update_cycle_loan_negative_interest_raises() -> None: ): with pytest.raises(service.InvalidCycleDataError) as exc_info: service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) - assert ( - str(exc_info.value) == "loan_interest_rate_tenth_bps must be non-negative" - ) + assert str(exc_info.value) == "loan_interest_rate_tenth_bps must be non-negative" def test_update_cycle_not_found_raises() -> None: @@ -942,9 +868,7 @@ def test_update_cycle_owner_mismatch_raises() -> None: existing_cycle = SimpleNamespace(id=1, user_id=2) # Owned by different user with ( FakeDBFactory().get_session_ctx_manager() as db, - patch( - "trading_journal.crud.get_cycle_by_id", return_value=existing_cycle - ) as mock_get, + patch("trading_journal.crud.get_cycle_by_id", return_value=existing_cycle) as mock_get, ): with pytest.raises(service.CycleNotFoundError) as exc_info: service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) @@ -993,16 +917,13 @@ def test_update_cycle_success() -> None: ) with ( FakeDBFactory().get_session_ctx_manager() as db, - patch( - "trading_journal.crud.get_cycle_by_id", return_value=existing_cycle - ) as mock_get, - patch( - "trading_journal.crud.update_cycle", return_value=updated_cycle - ) as mock_update, + patch("trading_journal.crud.get_cycle_by_id", return_value=existing_cycle) as mock_get, + patch("trading_journal.crud.update_cycle", return_value=updated_cycle) as mock_update, patch( "trading_journal.crud.get_loan_event_by_cycle_id_and_effective_date", return_value=None, ) as mock_get_loan_event, + patch("trading_journal.crud.create_cycle_loan_event"), ): cycle_out = service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) assert cycle_out.id == updated_cycle.id @@ -1011,10 +932,7 @@ def test_update_cycle_success() -> None: assert cycle_out.funding_source == updated_cycle.funding_source assert cycle_out.capital_exposure_cents == updated_cycle.capital_exposure_cents assert cycle_out.loan_amount_cents == updated_cycle.loan_amount_cents - assert ( - cycle_out.loan_interest_rate_tenth_bps - == updated_cycle.loan_interest_rate_tenth_bps - ) + assert cycle_out.loan_interest_rate_tenth_bps == updated_cycle.loan_interest_rate_tenth_bps mock_get.assert_called_once_with(db, cycle_data.id) update_cycle_base = dto.CycleBase( friendly_name=cycle_data.friendly_name, @@ -1022,17 +940,49 @@ def test_update_cycle_success() -> None: funding_source=cycle_data.funding_source, capital_exposure_cents=cycle_data.capital_exposure_cents, loan_amount_cents=getattr(cycle_data, "loan_amount_cents", None), - loan_interest_rate_tenth_bps=getattr( - cycle_data, "loan_interest_rate_tenth_bps", None - ), + loan_interest_rate_tenth_bps=getattr(cycle_data, "loan_interest_rate_tenth_bps", None), end_date=getattr(cycle_data, "end_date", None), ) - mock_update.assert_called_once_with( - db, cycle_data.id, update_data=update_cycle_base - ) + mock_update.assert_called_once_with(db, cycle_data.id, update_data=update_cycle_base) mock_get_loan_event.assert_called_once_with(db, cycle_data.id, today) +def test_accrual_interest_service_cycle_not_found() -> None: + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_cycle_by_id", return_value=None) as mock_get_cycle, + ): + with pytest.raises(service.CycleNotFoundError) as exc_info: + service.accrual_interest_service(db, cycle_id=1) + assert str(exc_info.value) == "Cycle not found" + mock_get_cycle.assert_called_once_with(db, 1) + + +def test_accrual_interest_service_success() -> None: + today = datetime.now(timezone.utc).date() + cycle = SimpleNamespace( + id=1, + user_id=1, + friendly_name="Test Cycle", + status="OPEN", + funding_source="LOAN", + loan_amount_cents=200000, + loan_interest_rate_tenth_bps=500, # 0.5% + start_date=today - timedelta(days=10), + ) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_cycle_by_id", return_value=cycle) as mock_get_cycle, + patch( + "trading_journal.crud.create_cycle_daily_accrual", + return_value=None, + ) as mock_accrual_interest, + ): + service.accrual_interest_service(db, cycle_id=1) + mock_get_cycle.assert_called_once_with(db, 1) + mock_accrual_interest.assert_called_once_with(db, cycle_id=1, accrual_date=today, accrual_amount_cents=3) + + # --- Trade services ---------------------------------------------------------- def test_create_trade_short_option_no_strike() -> None: trade_data = dto.TradeCreate( @@ -1058,10 +1008,7 @@ def test_create_trade_short_option_no_strike() -> None: ): with pytest.raises(service.InvalidTradeDataError) as exc_info: service.create_trade_service(db, 1, trade_data) - assert ( - str(exc_info.value) - == "Invalid trade data: expiry_date and strike_price_cents are required for SELL_PUT and SELL_CALL trades" - ) + assert str(exc_info.value) == "Invalid trade data: expiry_date and strike_price_cents are required for SELL_PUT and SELL_CALL trades" def test_create_trade_success() -> None: @@ -1104,9 +1051,7 @@ def test_create_trade_success() -> None: ) with ( FakeDBFactory().get_session_ctx_manager() as db, - patch( - "trading_journal.crud.create_trade", return_value=created_trade - ) as mock_create_trade, + patch("trading_journal.crud.create_trade", return_value=created_trade) as mock_create_trade, ): trade_out = service.create_trade_service(db, user_id=1, trade_data=trade_data) assert trade_out.id == created_trade.id @@ -1115,18 +1060,10 @@ def test_create_trade_success() -> None: assert trade_out.trade_type == created_trade.trade_type mock_create_trade.assert_called_once() _, kwargs = mock_create_trade.call_args - passed_trade = kwargs.get("trade_data") or ( - mock_create_trade.call_args[0][1] - if len(mock_create_trade.call_args[0]) > 1 - else None - ) + passed_trade = kwargs.get("trade_data") or (mock_create_trade.call_args[0][1] if len(mock_create_trade.call_args[0]) > 1 else None) assert passed_trade is not None # expected for SELL_PUT: gross = quantity * price * quantity_multiplier (positive), net = gross - commission - expected_gross = ( - trade_data.quantity - * trade_data.price_cents - * (trade_data.quantity_multiplier or 1) - ) + expected_gross = trade_data.quantity * trade_data.price_cents * (trade_data.quantity_multiplier or 1) expected_net = expected_gross - trade_data.commission_cents assert getattr(passed_trade, "gross_cash_flow_cents", None) == expected_gross assert getattr(passed_trade, "net_cash_flow_cents", None) == expected_net @@ -1147,9 +1084,7 @@ def test_get_trade_by_id_not_found_owner_mismatch() -> None: existing_trade = SimpleNamespace(id=2, user_id=2) with ( FakeDBFactory().get_session_ctx_manager() as db, - patch( - "trading_journal.crud.get_trade_by_id", return_value=existing_trade - ) as mock_get, + patch("trading_journal.crud.get_trade_by_id", return_value=existing_trade) as mock_get, ): with pytest.raises(service.TradeNotFoundError) as exc_info: service.get_trade_by_id_service(db, user_id=1, trade_id=2) @@ -1186,9 +1121,7 @@ def test_get_trade_by_id_success() -> None: ) with ( FakeDBFactory().get_session_ctx_manager() as db, - patch( - "trading_journal.crud.get_trade_by_id", return_value=trade_obj - ) as mock_get, + patch("trading_journal.crud.get_trade_by_id", return_value=trade_obj) as mock_get, ): res = service.get_trade_by_id_service(db, user_id=1, trade_id=10) assert res.id == trade_obj.id @@ -1204,9 +1137,7 @@ def test_update_trade_friendly_name_not_found() -> None: patch("trading_journal.crud.get_trade_by_id", return_value=None) as mock_get, ): with pytest.raises(service.TradeNotFoundError) as exc_info: - service.update_trade_friendly_name_service( - db, user_id=1, trade_id=10, friendly_name="New Name" - ) + service.update_trade_friendly_name_service(db, user_id=1, trade_id=10, friendly_name="New Name") assert str(exc_info.value) == "Trade not found" mock_get.assert_called_once_with(db, 10) @@ -1215,14 +1146,10 @@ def test_update_trade_friendly_name_owner_mismatch_raises() -> None: existing_trade = SimpleNamespace(id=10, user_id=2) # owned by another user with ( FakeDBFactory().get_session_ctx_manager() as db, - patch( - "trading_journal.crud.get_trade_by_id", return_value=existing_trade - ) as mock_get, + patch("trading_journal.crud.get_trade_by_id", return_value=existing_trade) as mock_get, ): with pytest.raises(service.TradeNotFoundError) as exc_info: - service.update_trade_friendly_name_service( - db, user_id=1, trade_id=10, friendly_name="New Name" - ) + service.update_trade_friendly_name_service(db, user_id=1, trade_id=10, friendly_name="New Name") assert str(exc_info.value) == "Trade not found" mock_get.assert_called_once_with(db, 10) @@ -1253,22 +1180,16 @@ def test_update_trade_friendly_name_success() -> None: notes="ok", cycle_id=None, ) - updated_trade = SimpleNamespace( - **{**existing_trade.__dict__, "friendly_name": "New Friendly"} - ) + updated_trade = SimpleNamespace(**{**existing_trade.__dict__, "friendly_name": "New Friendly"}) with ( FakeDBFactory().get_session_ctx_manager() as db, - patch( - "trading_journal.crud.get_trade_by_id", return_value=existing_trade - ) as mock_get, + patch("trading_journal.crud.get_trade_by_id", return_value=existing_trade) as mock_get, patch( "trading_journal.crud.update_trade_friendly_name", return_value=updated_trade, ) as mock_update, ): - res = service.update_trade_friendly_name_service( - db, user_id=1, trade_id=10, friendly_name="New Friendly" - ) + res = service.update_trade_friendly_name_service(db, user_id=1, trade_id=10, friendly_name="New Friendly") assert res.friendly_name == "New Friendly" mock_get.assert_called_once_with(db, 10) mock_update.assert_called_once_with(db, 10, "New Friendly") @@ -1289,9 +1210,7 @@ def test_update_trade_note_owner_mismatch_raises() -> None: existing_trade = SimpleNamespace(id=20, user_id=2) with ( FakeDBFactory().get_session_ctx_manager() as db, - patch( - "trading_journal.crud.get_trade_by_id", return_value=existing_trade - ) as mock_get, + patch("trading_journal.crud.get_trade_by_id", return_value=existing_trade) as mock_get, ): with pytest.raises(service.TradeNotFoundError) as exc_info: service.update_trade_note_service(db, user_id=1, trade_id=20, note="x") @@ -1328,12 +1247,8 @@ def test_update_trade_note_success_and_none_becomes_empty() -> None: updated_trade = SimpleNamespace(**{**existing_trade.__dict__, "notes": ""}) with ( FakeDBFactory().get_session_ctx_manager() as db, - patch( - "trading_journal.crud.get_trade_by_id", return_value=existing_trade - ) as mock_get, - patch( - "trading_journal.crud.update_trade_note", return_value=updated_trade - ) as mock_update, + patch("trading_journal.crud.get_trade_by_id", return_value=existing_trade) as mock_get, + patch("trading_journal.crud.update_trade_note", return_value=updated_trade) as mock_update, ): res = service.update_trade_note_service(db, user_id=1, trade_id=20, note=None) assert res.notes == "" diff --git a/backend/trading_journal/crud.py b/backend/trading_journal/crud.py index b771433..8e3baee 100644 --- a/backend/trading_journal/crud.py +++ b/backend/trading_journal/crud.py @@ -52,9 +52,7 @@ def _data_to_dict(data: AnyModel) -> dict[str, AnyModel]: # Trades -def create_trade( - session: Session, trade_data: Mapping[str, Any] | BaseModel -) -> models.Trades: +def create_trade(session: Session, trade_data: Mapping[str, Any] | BaseModel) -> models.Trades: data = _data_to_dict(trade_data) allowed = _allowed_columns(models.Trades) payload = {k: v for k, v in data.items() if k in allowed} @@ -74,19 +72,13 @@ def create_trade( raise ValueError("exchange.user_id does not match trade.user_id") if "underlying_currency" not in payload: raise ValueError("underlying_currency is required") - payload["underlying_currency"] = _check_enum( - models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency" - ) + payload["underlying_currency"] = _check_enum(models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency") if "trade_type" not in payload: raise ValueError("trade_type is required") - payload["trade_type"] = _check_enum( - models.TradeType, payload["trade_type"], "trade_type" - ) + payload["trade_type"] = _check_enum(models.TradeType, payload["trade_type"], "trade_type") if "trade_strategy" not in payload: raise ValueError("trade_strategy is required") - payload["trade_strategy"] = _check_enum( - models.TradeStrategy, payload["trade_strategy"], "trade_strategy" - ) + payload["trade_strategy"] = _check_enum(models.TradeStrategy, payload["trade_strategy"], "trade_strategy") # trade_time_utc is the creation moment: always set to now (caller shouldn't provide) now = datetime.now(timezone.utc) payload.pop("trade_time_utc", None) @@ -113,8 +105,7 @@ def create_trade( "symbol": payload["symbol"], "exchange_id": payload["exchange_id"], "underlying_currency": payload["underlying_currency"], - "friendly_name": "Auto-created Cycle by trade " - + payload.get("friendly_name", ""), + "friendly_name": "Auto-created Cycle by trade " + payload.get("friendly_name", ""), "status": models.CycleStatus.OPEN, "start_date": payload["trade_date"], } @@ -128,9 +119,7 @@ def create_trade( if cycle is None: raise ValueError("cycle_id does not exist") - payload.pop( - "exchange_id", None - ) # ignore exchange_id if provided; use cycle's exchange_id + payload.pop("exchange_id", None) # ignore exchange_id if provided; use cycle's exchange_id payload["exchange_id"] = cycle.exchange_id if cycle.user_id != user_id: raise ValueError("cycle.user_id does not match trade.user_id") @@ -158,9 +147,7 @@ def get_trade_by_id(session: Session, trade_id: int) -> models.Trades | None: return session.get(models.Trades, trade_id) -def get_trade_by_user_id_and_friendly_name( - session: Session, user_id: int, friendly_name: str -) -> models.Trades | None: +def get_trade_by_user_id_and_friendly_name(session: Session, user_id: int, friendly_name: str) -> models.Trades | None: statement = select(models.Trades).where( models.Trades.user_id == user_id, models.Trades.friendly_name == friendly_name, @@ -175,9 +162,7 @@ def get_trades_by_user_id(session: Session, user_id: int) -> list[models.Trades] return list(session.exec(statement).all()) -def update_trade_friendly_name( - session: Session, trade_id: int, friendly_name: str -) -> models.Trades: +def update_trade_friendly_name(session: Session, trade_id: int, friendly_name: str) -> models.Trades: trade: models.Trades | None = session.get(models.Trades, trade_id) if trade is None: raise ValueError("trade_id does not exist") @@ -225,9 +210,7 @@ def invalidate_trade(session: Session, trade_id: int) -> models.Trades: return trade -def replace_trade( - session: Session, old_trade_id: int, new_trade_data: Mapping[str, Any] | BaseModel -) -> models.Trades: +def replace_trade(session: Session, old_trade_id: int, new_trade_data: Mapping[str, Any] | BaseModel) -> models.Trades: invalidate_trade(session, old_trade_id) data = _data_to_dict(new_trade_data) data["replaced_by_trade_id"] = old_trade_id @@ -235,9 +218,7 @@ def replace_trade( # Cycles -def create_cycle( - session: Session, cycle_data: Mapping[str, Any] | BaseModel -) -> models.Cycles: +def create_cycle(session: Session, cycle_data: Mapping[str, Any] | BaseModel) -> models.Cycles: data = _data_to_dict(cycle_data) allowed = _allowed_columns(models.Cycles) payload = {k: v for k, v in data.items() if k in allowed} @@ -255,9 +236,7 @@ def create_cycle( raise ValueError("exchange.user_id does not match cycle.user_id") if "underlying_currency" not in payload: raise ValueError("underlying_currency is required") - payload["underlying_currency"] = _check_enum( - models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency" - ) + payload["underlying_currency"] = _check_enum(models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency") if "status" not in payload: raise ValueError("status is required") payload["status"] = _check_enum(models.CycleStatus, payload["status"], "status") @@ -289,9 +268,7 @@ def get_cycles_by_user_id(session: Session, user_id: int) -> list[models.Cycles] return list(session.exec(statement).all()) -def update_cycle( - session: Session, cycle_id: int, update_data: Mapping[str, Any] | BaseModel -) -> models.Cycles: +def update_cycle(session: Session, cycle_id: int, update_data: Mapping[str, Any] | BaseModel) -> models.Cycles: cycle: models.Cycles | None = session.get(models.Cycles, cycle_id) if cycle is None: raise ValueError("cycle_id does not exist") @@ -327,9 +304,7 @@ def update_cycle( # Cycle loan and interest -def create_cycle_loan_event( - session: Session, loan_data: Mapping[str, Any] | BaseModel -) -> models.CycleLoanChangeEvents: +def create_cycle_loan_event(session: Session, loan_data: Mapping[str, Any] | BaseModel) -> models.CycleLoanChangeEvents: data = _data_to_dict(loan_data) allowed = _allowed_columns(models.CycleLoanChangeEvents) payload = {k: v for k, v in data.items() if k in allowed} @@ -339,9 +314,7 @@ def create_cycle_loan_event( if cycle is None: raise ValueError("cycle_id does not exist") - payload["effective_date"] = ( - payload.get("effective_date") or datetime.now(timezone.utc).date() - ) + payload["effective_date"] = payload.get("effective_date") or datetime.now(timezone.utc).date() payload["created_at"] = datetime.now(timezone.utc) cle = models.CycleLoanChangeEvents(**payload) session.add(cle) @@ -354,9 +327,7 @@ def create_cycle_loan_event( return cle -def get_loan_events_by_cycle_id( - session: Session, cycle_id: int -) -> list[models.CycleLoanChangeEvents]: +def get_loan_events_by_cycle_id(session: Session, cycle_id: int) -> list[models.CycleLoanChangeEvents]: eff_col = cast("ColumnElement", models.CycleLoanChangeEvents.effective_date) id_col = cast("ColumnElement", models.CycleLoanChangeEvents.id) statement = ( @@ -369,9 +340,7 @@ def get_loan_events_by_cycle_id( return list(session.exec(statement).all()) -def get_loan_event_by_cycle_id_and_effective_date( - session: Session, cycle_id: int, effective_date: date -) -> models.CycleLoanChangeEvents | None: +def get_loan_event_by_cycle_id_and_effective_date(session: Session, cycle_id: int, effective_date: date) -> models.CycleLoanChangeEvents | None: statement = select(models.CycleLoanChangeEvents).where( models.CycleLoanChangeEvents.cycle_id == cycle_id, models.CycleLoanChangeEvents.effective_date == effective_date, @@ -379,12 +348,8 @@ def get_loan_event_by_cycle_id_and_effective_date( return session.exec(statement).first() -def update_cycle_loan_event( - session: Session, event_id: int, update_data: Mapping[str, Any] | BaseModel -) -> models.CycleLoanChangeEvents: - event: models.CycleLoanChangeEvents | None = session.get( - models.CycleLoanChangeEvents, event_id - ) +def update_cycle_loan_event(session: Session, event_id: int, update_data: Mapping[str, Any] | BaseModel) -> models.CycleLoanChangeEvents: + event: models.CycleLoanChangeEvents | None = session.get(models.CycleLoanChangeEvents, event_id) if event is None: raise ValueError("event_id does not exist") data = _data_to_dict(update_data) @@ -406,9 +371,7 @@ def update_cycle_loan_event( return event -def create_cycle_daily_accrual( - session: Session, cycle_id: int, accrual_date: date, accrual_amount_cents: int -) -> models.CycleDailyAccrual: +def create_cycle_daily_accrual(session: Session, cycle_id: int, accrual_date: date, accrual_amount_cents: int) -> models.CycleDailyAccrual: cycle = session.get(models.Cycles, cycle_id) if cycle is None: raise ValueError("cycle_id does not exist") @@ -438,9 +401,7 @@ def create_cycle_daily_accrual( return row -def get_cycle_daily_accruals_by_cycle_id( - session: Session, cycle_id: int -) -> list[models.CycleDailyAccrual]: +def get_cycle_daily_accruals_by_cycle_id(session: Session, cycle_id: int) -> list[models.CycleDailyAccrual]: date_col = cast("ColumnElement", models.CycleDailyAccrual.accrual_date) statement = ( select(models.CycleDailyAccrual) @@ -452,9 +413,7 @@ def get_cycle_daily_accruals_by_cycle_id( return list(session.exec(statement).all()) -def get_cycle_daily_accrual_by_cycle_id_and_date( - session: Session, cycle_id: int, accrual_date: date -) -> models.CycleDailyAccrual | None: +def get_cycle_daily_accrual_by_cycle_id_and_date(session: Session, cycle_id: int, accrual_date: date) -> models.CycleDailyAccrual | None: statement = select(models.CycleDailyAccrual).where( models.CycleDailyAccrual.cycle_id == cycle_id, models.CycleDailyAccrual.accrual_date == accrual_date, @@ -466,9 +425,7 @@ def get_cycle_daily_accrual_by_cycle_id_and_date( IMMUTABLE_EXCHANGE_FIELDS = {"id"} -def create_exchange( - session: Session, exchange_data: Mapping[str, Any] | BaseModel -) -> models.Exchanges: +def create_exchange(session: Session, exchange_data: Mapping[str, Any] | BaseModel) -> models.Exchanges: data = _data_to_dict(exchange_data) allowed = _allowed_columns(models.Exchanges) payload = {k: v for k, v in data.items() if k in allowed} @@ -490,9 +447,7 @@ def get_exchange_by_id(session: Session, exchange_id: int) -> models.Exchanges | return session.get(models.Exchanges, exchange_id) -def get_exchange_by_name_and_user_id( - session: Session, name: str, user_id: int -) -> models.Exchanges | None: +def get_exchange_by_name_and_user_id(session: Session, name: str, user_id: int) -> models.Exchanges | None: statement = select(models.Exchanges).where( models.Exchanges.name == name, models.Exchanges.user_id == user_id, @@ -505,18 +460,14 @@ def get_all_exchanges(session: Session) -> list[models.Exchanges]: return list(session.exec(statement).all()) -def get_all_exchanges_by_user_id( - session: Session, user_id: int -) -> list[models.Exchanges]: +def get_all_exchanges_by_user_id(session: Session, user_id: int) -> list[models.Exchanges]: statement = select(models.Exchanges).where( models.Exchanges.user_id == user_id, ) return list(session.exec(statement).all()) -def update_exchange( - session: Session, exchange_id: int, update_data: Mapping[str, Any] | BaseModel -) -> models.Exchanges: +def update_exchange(session: Session, exchange_id: int, update_data: Mapping[str, Any] | BaseModel) -> models.Exchanges: exchange: models.Exchanges | None = session.get(models.Exchanges, exchange_id) if exchange is None: raise ValueError("exchange_id does not exist") @@ -553,9 +504,7 @@ def delete_exchange(session: Session, exchange_id: int) -> None: IMMUTABLE_USER_FIELDS = {"id", "username", "created_at"} -def create_user( - session: Session, user_data: Mapping[str, Any] | BaseModel -) -> models.Users: +def create_user(session: Session, user_data: Mapping[str, Any] | BaseModel) -> models.Users: data = _data_to_dict(user_data) allowed = _allowed_columns(models.Users) payload = {k: v for k, v in data.items() if k in allowed} @@ -586,9 +535,7 @@ def get_user_by_username(session: Session, username: str) -> models.Users | None return session.exec(statement).first() -def update_user( - session: Session, user_id: int, update_data: Mapping[str, Any] | BaseModel -) -> models.Users: +def update_user(session: Session, user_id: int, update_data: Mapping[str, Any] | BaseModel) -> models.Users: user: models.Users | None = session.get(models.Users, user_id) if user is None: raise ValueError("user_id does not exist") @@ -645,9 +592,7 @@ def create_login_session( return s -def get_login_session_by_token_hash_and_user_id( - session: Session, session_token_hash: str, user_id: int -) -> models.Sessions | None: +def get_login_session_by_token_hash_and_user_id(session: Session, session_token_hash: str, user_id: int) -> models.Sessions | None: statement = select(models.Sessions).where( models.Sessions.session_token_hash == session_token_hash, models.Sessions.user_id == user_id, @@ -657,9 +602,7 @@ def get_login_session_by_token_hash_and_user_id( return session.exec(statement).first() -def get_login_session_by_token_hash( - session: Session, session_token_hash: str -) -> models.Sessions | None: +def get_login_session_by_token_hash(session: Session, session_token_hash: str) -> models.Sessions | None: statement = select(models.Sessions).where( models.Sessions.session_token_hash == session_token_hash, models.Sessions.expires_at > datetime.now(timezone.utc), diff --git a/backend/trading_journal/service.py b/backend/trading_journal/service.py index c83bfb7..52e5aea 100644 --- a/backend/trading_journal/service.py +++ b/backend/trading_journal/service.py @@ -56,9 +56,7 @@ logger = logging.getLogger(__name__) class AuthMiddleWare(BaseHTTPMiddleware): - async def dispatch( - self, request: Request, call_next: RequestResponseEndpoint - ) -> Response: # noqa: PLR0911 + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: # noqa: PLR0911 if request.url.path in EXCEPT_PATHS: return await call_next(request) @@ -84,21 +82,15 @@ class AuthMiddleWare(BaseHTTPMiddleware): with db_factory.get_session_ctx_manager() as request_session: hashed_token = security.hash_session_token_sha256(token) request.state.db_session = request_session - login_session: Sessions | None = crud.get_login_session_by_token_hash( - request_session, hashed_token - ) + login_session: Sessions | None = crud.get_login_session_by_token_hash(request_session, hashed_token) if not login_session: return JSONResponse( status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"}, ) - session_expires_utc = login_session.expires_at.replace( - tzinfo=timezone.utc - ) + session_expires_utc = login_session.expires_at.replace(tzinfo=timezone.utc) if session_expires_utc < datetime.now(timezone.utc): - crud.delete_login_session( - request_session, login_session.session_token_hash - ) + crud.delete_login_session(request_session, login_session.session_token_hash) return JSONResponse( status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"}, @@ -108,12 +100,8 @@ class AuthMiddleWare(BaseHTTPMiddleware): status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"}, ) - if session_expires_utc - datetime.now(timezone.utc) < timedelta( - seconds=3600 - ): - updated_expiry = datetime.now(timezone.utc) + timedelta( - seconds=settings.settings.session_expiry_seconds - ) + if session_expires_utc - datetime.now(timezone.utc) < timedelta(seconds=3600): + updated_expiry = datetime.now(timezone.utc) + timedelta(seconds=settings.settings.session_expiry_seconds) else: updated_expiry = session_expires_utc updated_session: SessionsUpdate = SessionsUpdate( @@ -124,9 +112,7 @@ class AuthMiddleWare(BaseHTTPMiddleware): ) user_id = login_session.user_id request.state.user_id = user_id - crud.update_login_session( - request_session, hashed_token, update_session=updated_session - ) + crud.update_login_session(request_session, hashed_token, update_session=updated_session) except Exception: logger.exception("Failed to authenticate user: \n") return JSONResponse( @@ -160,9 +146,7 @@ def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead: return user -def authenticate_user_service( - db_session: Session, user_in: UserLogin -) -> tuple[SessionsCreate, str] | None: +def authenticate_user_service(db_session: Session, user_in: UserLogin) -> tuple[SessionsCreate, str] | None: user = crud.get_user_by_username(db_session, user_in.username) if not user: return None @@ -187,14 +171,10 @@ def authenticate_user_service( # Exchanges service -def create_exchange_service( - db_session: Session, user_id: int, name: str, notes: str | None -) -> ExchangesCreate: +def create_exchange_service(db_session: Session, user_id: int, name: str, notes: str | None) -> ExchangesCreate: existing_exchange = crud.get_exchange_by_name_and_user_id(db_session, name, user_id) if existing_exchange: - raise ExchangeAlreadyExistsError( - "Exchange with the same name already exists for this user" - ) + raise ExchangeAlreadyExistsError("Exchange with the same name already exists for this user") exchange_data = ExchangesCreate( user_id=user_id, name=name, @@ -213,9 +193,7 @@ def create_exchange_service( return exchange_dto -def get_exchanges_by_user_service( - db_session: Session, user_id: int -) -> list[ExchangesRead]: +def get_exchanges_by_user_service(db_session: Session, user_id: int) -> list[ExchangesRead]: exchanges = crud.get_all_exchanges_by_user_id(db_session, user_id) return [ExchangesRead.model_validate(exchange) for exchange in exchanges] @@ -234,22 +212,16 @@ def update_exchanges_service( raise ExchangeNotFoundError("Exchange not found") if name: - other_exchange = crud.get_exchange_by_name_and_user_id( - db_session, name, user_id - ) + other_exchange = crud.get_exchange_by_name_and_user_id(db_session, name, user_id) if other_exchange and other_exchange.id != existing_exchange.id: - raise ExchangeAlreadyExistsError( - "Another exchange with the same name already exists for this user" - ) + raise ExchangeAlreadyExistsError("Another exchange with the same name already exists for this user") exchange_data = ExchangesBase( name=name or existing_exchange.name, notes=notes or existing_exchange.notes, ) try: - exchange = crud.update_exchange( - db_session, cast("int", existing_exchange.id), update_data=exchange_data - ) + exchange = crud.update_exchange(db_session, cast("int", existing_exchange.id), update_data=exchange_data) except Exception as e: logger.exception("Failed to update exchange: \n") raise ServiceError("Failed to update exchange") from e @@ -257,9 +229,7 @@ def update_exchanges_service( # Cycle Service -def create_cycle_service( - db_session: Session, user_id: int, cycle_data: CycleBase -) -> CycleRead: +def create_cycle_service(db_session: Session, user_id: int, cycle_data: CycleBase) -> CycleRead: raise NotImplementedError("Cycle creation not implemented") cycle_data_dict = cycle_data.model_dump() cycle_data_dict["user_id"] = user_id @@ -268,9 +238,7 @@ def create_cycle_service( return CycleRead.model_validate(created_cycle) -def get_cycle_by_id_service( - db_session: Session, user_id: int, cycle_id: int -) -> CycleRead: +def get_cycle_by_id_service(db_session: Session, user_id: int, cycle_id: int) -> CycleRead: cycle = crud.get_cycle_by_id(db_session, cycle_id) if not cycle: raise CycleNotFoundError("Cycle not found") @@ -289,18 +257,12 @@ def _validate_cycle_update_data(cycle_data: CycleUpdate) -> tuple[bool, str]: # return False, "end_date is required when status is CLOSED" if cycle_data.status == "OPEN" and cycle_data.end_date is not None: return False, "end_date must be empty when status is OPEN" - if ( - cycle_data.capital_exposure_cents is not None - and cycle_data.capital_exposure_cents < 0 - ): + if cycle_data.capital_exposure_cents is not None and cycle_data.capital_exposure_cents < 0: return False, "capital_exposure_cents must be non-negative" if ( cycle_data.funding_source is not None and cycle_data.funding_source != "CASH" - and ( - cycle_data.loan_amount_cents is None - or cycle_data.loan_interest_rate_tenth_bps is None - ) + and (cycle_data.loan_amount_cents is None or cycle_data.loan_interest_rate_tenth_bps is None) ): return ( False, @@ -308,10 +270,7 @@ def _validate_cycle_update_data(cycle_data: CycleUpdate) -> tuple[bool, str]: # ) if cycle_data.loan_amount_cents is not None and cycle_data.loan_amount_cents < 0: return False, "loan_amount_cents must be non-negative" - if ( - cycle_data.loan_interest_rate_tenth_bps is not None - and cycle_data.loan_interest_rate_tenth_bps < 0 - ): + if cycle_data.loan_interest_rate_tenth_bps is not None and cycle_data.loan_interest_rate_tenth_bps < 0: return False, "loan_interest_rate_tenth_bps must be non-negative" return True, "" @@ -324,13 +283,9 @@ def _create_cycle_loan_event( ) -> None: now = datetime.now(timezone.utc) today = now.date() - existing_loan_event = crud.get_loan_event_by_cycle_id_and_effective_date( - db_session, cycle_id, today - ) + existing_loan_event = crud.get_loan_event_by_cycle_id_and_effective_date(db_session, cycle_id, today) if existing_loan_event: - raise CycleLoanEventExistsError( - "A loan event with the same effective_date already exists for this cycle." - ) + raise CycleLoanEventExistsError("A loan event with the same effective_date already exists for this cycle.") loan_event_data = CycleLoanChangeEventBase( cycle_id=cycle_id, effective_date=today, @@ -345,9 +300,7 @@ def _create_cycle_loan_event( raise ServiceError("Failed to create cycle loan event") from e -def update_cycle_service( - db_session: Session, user_id: int, cycle_data: CycleUpdate -) -> CycleRead: +def update_cycle_service(db_session: Session, user_id: int, cycle_data: CycleUpdate) -> CycleRead: is_valid, err_msg = _validate_cycle_update_data(cycle_data) if not is_valid: raise InvalidCycleDataError(err_msg) @@ -357,10 +310,7 @@ def update_cycle_service( raise CycleNotFoundError("Cycle not found") if existing_cycle.user_id != user_id: raise CycleNotFoundError("Cycle not found") - if ( - cycle_data.loan_amount_cents is not None - or cycle_data.loan_interest_rate_tenth_bps is not None - ): + if cycle_data.loan_amount_cents is not None or cycle_data.loan_interest_rate_tenth_bps is not None: _create_cycle_loan_event( db_session, cycle_id, @@ -372,16 +322,14 @@ def update_cycle_service( cycle_data_with_user_id: CycleBase = CycleBase.model_validate(provided_data_dict) try: - updated_cycle = crud.update_cycle( - db_session, cycle_id, update_data=cycle_data_with_user_id - ) + updated_cycle = crud.update_cycle(db_session, cycle_id, update_data=cycle_data_with_user_id) except Exception as e: logger.exception("Failed to update cycle: \n") raise ServiceError("Failed to update cycle") from e return CycleRead.model_validate(updated_cycle) -def accural_interest_service(db_session: Session, cycle_id: int) -> None: +def accrual_interest_service(db_session: Session, cycle_id: int) -> None: cycle = crud.get_cycle_by_id(db_session, cycle_id) if not cycle: logger.exception("Cycle not found for interest accrual") @@ -390,9 +338,7 @@ def accural_interest_service(db_session: Session, cycle_id: int) -> None: logger.info("Cycle has no loan, skipping interest accrual") return today = datetime.now(timezone.utc).date() - amount_cents = round( - cycle.loan_amount_cents * cycle.loan_interest_rate_tenth_bps / 100000 / 365 - ) + amount_cents = round(cycle.loan_amount_cents * cycle.loan_interest_rate_tenth_bps / 100000 / 365) try: crud.create_cycle_daily_accrual( db_session, @@ -432,18 +378,13 @@ def _append_cashflows(trade_data: TradeCreate) -> TradeCreate: def _validate_trade_data(trade_data: TradeCreate) -> bool: return not ( - trade_data.trade_type in ("SELL_PUT", "SELL_CALL") - and (trade_data.expiry_date is None or trade_data.strike_price_cents is None) + trade_data.trade_type in ("SELL_PUT", "SELL_CALL") and (trade_data.expiry_date is None or trade_data.strike_price_cents is None) ) -def create_trade_service( - db_session: Session, user_id: int, trade_data: TradeCreate -) -> TradeRead: +def create_trade_service(db_session: Session, user_id: int, trade_data: TradeCreate) -> TradeRead: if not _validate_trade_data(trade_data): - raise InvalidTradeDataError( - "Invalid trade data: expiry_date and strike_price_cents are required for SELL_PUT and SELL_CALL trades" - ) + raise InvalidTradeDataError("Invalid trade data: expiry_date and strike_price_cents are required for SELL_PUT and SELL_CALL trades") trade_data_dict = trade_data.model_dump() trade_data_dict["user_id"] = user_id trade_data_with_user_id: TradeCreate = TradeCreate.model_validate(trade_data_dict) @@ -452,9 +393,7 @@ def create_trade_service( return TradeRead.model_validate(created_trade) -def get_trade_by_id_service( - db_session: Session, user_id: int, trade_id: int -) -> TradeRead: +def get_trade_by_id_service(db_session: Session, user_id: int, trade_id: int) -> TradeRead: trade = crud.get_trade_by_id(db_session, trade_id) if not trade: raise TradeNotFoundError("Trade not found") @@ -463,27 +402,21 @@ def get_trade_by_id_service( return TradeRead.model_validate(trade) -def update_trade_friendly_name_service( - db_session: Session, user_id: int, trade_id: int, friendly_name: str -) -> TradeRead: +def update_trade_friendly_name_service(db_session: Session, user_id: int, trade_id: int, friendly_name: str) -> TradeRead: existing_trade = crud.get_trade_by_id(db_session, trade_id) if not existing_trade: raise TradeNotFoundError("Trade not found") if existing_trade.user_id != user_id: raise TradeNotFoundError("Trade not found") try: - updated_trade = crud.update_trade_friendly_name( - db_session, trade_id, friendly_name - ) + updated_trade = crud.update_trade_friendly_name(db_session, trade_id, friendly_name) except Exception as e: logger.exception("Failed to update trade friendly name: \n") raise ServiceError("Failed to update trade friendly name") from e return TradeRead.model_validate(updated_trade) -def update_trade_note_service( - db_session: Session, user_id: int, trade_id: int, note: str | None -) -> TradeRead: +def update_trade_note_service(db_session: Session, user_id: int, trade_id: int, note: str | None) -> TradeRead: existing_trade = crud.get_trade_by_id(db_session, trade_id) if not existing_trade: raise TradeNotFoundError("Trade not found")