From 442da655c0e17dd082872d8a523b70ead2c520b9 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Fri, 19 Sep 2025 15:30:41 +0200 Subject: [PATCH 01/18] Fix linting error and linting config --- backend/ruff.toml | 10 +- backend/tests/test_crud.py | 124 ++++++++++++------------ backend/tests/test_db.py | 26 ++--- backend/tests/test_db_migration.py | 31 ++---- backend/tests/test_settings.py | 10 +- backend/trading_journal/crud.py | 68 +++++-------- backend/trading_journal/db.py | 16 +-- backend/trading_journal/db_migration.py | 2 +- backend/trading_journal/models.py | 62 +++--------- backend/trading_journal/models_v1.py | 62 +++--------- 10 files changed, 151 insertions(+), 260 deletions(-) diff --git a/backend/ruff.toml b/backend/ruff.toml index c5a0394..7571904 100644 --- a/backend/ruff.toml +++ b/backend/ruff.toml @@ -13,8 +13,14 @@ ignore = [ "TRY003", "EM101", "EM102", - "PLC0405", + "SIM108", + "C901", + "PLR0912", + "PLR0915", + "PLR0913", + "PLC0415", ] [lint.extend-per-file-ignores] -"test*.py" = ["S101"] +"test*.py" = ["S101", "S105", "S106", "PT011", "PLR2004"] +"models*.py" = ["FA102"] diff --git a/backend/tests/test_crud.py b/backend/tests/test_crud.py index b542d90..e951665 100644 --- a/backend/tests/test_crud.py +++ b/backend/tests/test_crud.py @@ -1,15 +1,19 @@ -from collections.abc import Generator +from __future__ import annotations + from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING import pytest from sqlalchemy import create_engine -from sqlalchemy.engine import Engine from sqlalchemy.pool import StaticPool from sqlmodel import Session, SQLModel from trading_journal import crud, models -# TODO: If needed, add failing flow tests, but now only add happy flow. +if TYPE_CHECKING: + from collections.abc import Generator + + from sqlalchemy.engine import Engine @pytest.fixture @@ -41,16 +45,14 @@ def make_user(session: Session, username: str = "testuser") -> int: return user.id -def make_cycle( - session: Session, user_id: int, friendly_name: str = "Test Cycle" -) -> int: +def make_cycle(session: Session, user_id: int, friendly_name: str = "Test Cycle") -> int: cycle = models.Cycles( user_id=user_id, friendly_name=friendly_name, symbol="AAPL", underlying_currency=models.UnderlyingCurrency.USD, status=models.CycleStatus.OPEN, - start_date=datetime.now().date(), + start_date=datetime.now(timezone.utc).date(), ) session.add(cycle) session.commit() @@ -58,9 +60,7 @@ def make_cycle( return cycle.id -def make_trade( - session: Session, user_id: int, cycle_id: int, friendly_name: str = "Test Trade" -) -> int: +def make_trade(session: Session, user_id: int, cycle_id: int, friendly_name: str = "Test Trade") -> int: trade = models.Trades( user_id=user_id, friendly_name=friendly_name, @@ -68,8 +68,8 @@ def make_trade( underlying_currency=models.UnderlyingCurrency.USD, trade_type=models.TradeType.LONG_SPOT, trade_strategy=models.TradeStrategy.SPOT, - trade_date=datetime.now().date(), - trade_time_utc=datetime.now(), + trade_date=datetime.now(timezone.utc).date(), + trade_time_utc=datetime.now(timezone.utc), quantity=10, price_cents=15000, gross_cash_flow_cents=-150000, @@ -113,7 +113,15 @@ def make_login_session(session: Session, created_at: datetime) -> models.Session return login_session -def test_create_trade_success_with_cycle(session: Session): +def _ensure_utc_aware(dt: datetime) -> datetime | None: + if dt is None: + return None + if dt.tzinfo is None: + return dt.replace(tzinfo=timezone.utc) + return dt.astimezone(timezone.utc) + + +def test_create_trade_success_with_cycle(session: Session) -> None: user_id = make_user(session) cycle_id = make_cycle(session, user_id) @@ -124,7 +132,7 @@ def test_create_trade_success_with_cycle(session: Session): "underlying_currency": models.UnderlyingCurrency.USD, "trade_type": models.TradeType.LONG_SPOT, "trade_strategy": models.TradeStrategy.SPOT, - "trade_time_utc": datetime.now(), + "trade_time_utc": datetime.now(timezone.utc), "quantity": 10, "price_cents": 15000, "gross_cash_flow_cents": -150000, @@ -154,7 +162,7 @@ def test_create_trade_success_with_cycle(session: Session): assert actual_trade.cycle_id == trade_data["cycle_id"] -def test_create_trade_with_auto_created_cycle(session: Session): +def test_create_trade_with_auto_created_cycle(session: Session) -> None: user_id = make_user(session) trade_data = { @@ -164,7 +172,7 @@ def test_create_trade_with_auto_created_cycle(session: Session): "underlying_currency": models.UnderlyingCurrency.USD, "trade_type": models.TradeType.LONG_SPOT, "trade_strategy": models.TradeStrategy.SPOT, - "trade_time_utc": datetime.now(), + "trade_time_utc": datetime.now(timezone.utc), "quantity": 5, "price_cents": 15500, } @@ -196,7 +204,7 @@ def test_create_trade_with_auto_created_cycle(session: Session): assert auto_cycle.friendly_name.startswith("Auto-created Cycle by trade") -def test_create_trade_missing_required_fields(session: Session): +def test_create_trade_missing_required_fields(session: Session) -> None: user_id = make_user(session) base_trade_data = { @@ -206,7 +214,7 @@ def test_create_trade_missing_required_fields(session: Session): "underlying_currency": models.UnderlyingCurrency.USD, "trade_type": models.TradeType.LONG_SPOT, "trade_strategy": models.TradeStrategy.SPOT, - "trade_time_utc": datetime.now(), + "trade_time_utc": datetime.now(timezone.utc), "quantity": 10, "price_cents": 15000, } @@ -254,7 +262,7 @@ def test_create_trade_missing_required_fields(session: Session): assert "price_cents is required" in str(excinfo.value) -def test_get_trade_by_id(session: Session): +def test_get_trade_by_id(session: Session) -> None: user_id = make_user(session) cycle_id = make_cycle(session, user_id) trade_data = { @@ -264,8 +272,8 @@ def test_get_trade_by_id(session: Session): "underlying_currency": models.UnderlyingCurrency.USD, "trade_type": models.TradeType.LONG_SPOT, "trade_strategy": models.TradeStrategy.SPOT, - "trade_date": datetime.now().date(), - "trade_time_utc": datetime.now(), + "trade_date": datetime.now(timezone.utc).date(), + "trade_time_utc": datetime.now(timezone.utc), "quantity": 10, "price_cents": 15000, "gross_cash_flow_cents": -150000, @@ -291,7 +299,7 @@ def test_get_trade_by_id(session: Session): assert trade.trade_date == trade_data["trade_date"] -def test_get_trade_by_user_id_and_friendly_name(session: Session): +def test_get_trade_by_user_id_and_friendly_name(session: Session) -> None: user_id = make_user(session) cycle_id = make_cycle(session, user_id) friendly_name = "Unique Trade Name" @@ -302,8 +310,8 @@ def test_get_trade_by_user_id_and_friendly_name(session: Session): "underlying_currency": models.UnderlyingCurrency.USD, "trade_type": models.TradeType.LONG_SPOT, "trade_strategy": models.TradeStrategy.SPOT, - "trade_date": datetime.now().date(), - "trade_time_utc": datetime.now(), + "trade_date": datetime.now(timezone.utc).date(), + "trade_time_utc": datetime.now(timezone.utc), "quantity": 10, "price_cents": 15000, "gross_cash_flow_cents": -150000, @@ -318,7 +326,7 @@ def test_get_trade_by_user_id_and_friendly_name(session: Session): assert trade.user_id == user_id -def test_get_trades_by_user_id(session: Session): +def test_get_trades_by_user_id(session: Session) -> None: user_id = make_user(session) cycle_id = make_cycle(session, user_id) trade_data_1 = { @@ -328,8 +336,8 @@ def test_get_trades_by_user_id(session: Session): "underlying_currency": models.UnderlyingCurrency.USD, "trade_type": models.TradeType.LONG_SPOT, "trade_strategy": models.TradeStrategy.SPOT, - "trade_date": datetime.now().date(), - "trade_time_utc": datetime.now(), + "trade_date": datetime.now(timezone.utc).date(), + "trade_time_utc": datetime.now(timezone.utc), "quantity": 10, "price_cents": 15000, "gross_cash_flow_cents": -150000, @@ -344,8 +352,8 @@ def test_get_trades_by_user_id(session: Session): "underlying_currency": models.UnderlyingCurrency.USD, "trade_type": models.TradeType.SHORT_SPOT, "trade_strategy": models.TradeStrategy.SPOT, - "trade_date": datetime.now().date(), - "trade_time_utc": datetime.now(), + "trade_date": datetime.now(timezone.utc).date(), + "trade_time_utc": datetime.now(timezone.utc), "quantity": 5, "price_cents": 280000, "gross_cash_flow_cents": 1400000, @@ -362,7 +370,7 @@ def test_get_trades_by_user_id(session: Session): assert friendly_names == {"Trade One", "Trade Two"} -def test_update_trade_note(session: Session): +def test_update_trade_note(session: Session) -> None: user_id = make_user(session) cycle_id = make_cycle(session, user_id) trade_id = make_trade(session, user_id, cycle_id) @@ -379,7 +387,7 @@ def test_update_trade_note(session: Session): assert actual_trade.notes == new_note -def test_invalidate_trade(session: Session): +def test_invalidate_trade(session: Session) -> None: user_id = make_user(session) cycle_id = make_cycle(session, user_id) trade_id = make_trade(session, user_id, cycle_id) @@ -395,7 +403,7 @@ def test_invalidate_trade(session: Session): assert actual_trade.is_invalidated is True -def test_replace_trade(session: Session): +def test_replace_trade(session: Session) -> None: user_id = make_user(session) cycle_id = make_cycle(session, user_id) old_trade_id = make_trade(session, user_id, cycle_id) @@ -407,7 +415,7 @@ def test_replace_trade(session: Session): "underlying_currency": models.UnderlyingCurrency.USD, "trade_type": models.TradeType.LONG_SPOT, "trade_strategy": models.TradeStrategy.SPOT, - "trade_time_utc": datetime.now(), + "trade_time_utc": datetime.now(timezone.utc), "quantity": 20, "price_cents": 25000, } @@ -438,7 +446,7 @@ def test_replace_trade(session: Session): assert actual_new_trade.replaced_by_trade_id == old_trade_id -def test_create_cycle(session: Session): +def test_create_cycle(session: Session) -> None: user_id = make_user(session) cycle_data = { "user_id": user_id, @@ -446,7 +454,7 @@ def test_create_cycle(session: Session): "symbol": "GOOGL", "underlying_currency": models.UnderlyingCurrency.USD, "status": models.CycleStatus.OPEN, - "start_date": datetime.now().date(), + "start_date": datetime.now(timezone.utc).date(), } cycle = crud.create_cycle(session, cycle_data) assert cycle.id is not None @@ -467,7 +475,7 @@ def test_create_cycle(session: Session): assert actual_cycle.start_date == cycle_data["start_date"] -def test_update_cycle(session: Session): +def test_update_cycle(session: Session) -> None: user_id = make_user(session) cycle_id = make_cycle(session, user_id, friendly_name="Initial Cycle Name") @@ -488,7 +496,7 @@ def test_update_cycle(session: Session): assert actual_cycle.status == update_data["status"] -def test_update_cycle_immutable_fields(session: Session): +def test_update_cycle_immutable_fields(session: Session) -> None: user_id = make_user(session) cycle_id = make_cycle(session, user_id, friendly_name="Initial Cycle Name") @@ -496,8 +504,8 @@ def test_update_cycle_immutable_fields(session: Session): update_data = { "id": cycle_id + 1, # Trying to change the ID "user_id": user_id + 1, # Trying to change the user_id - "start_date": datetime(2020, 1, 1).date(), # Trying to change start_date - "created_at": datetime(2020, 1, 1), # Trying to change created_at + "start_date": datetime(2020, 1, 1, tzinfo=timezone.utc).date(), # Trying to change start_date + "created_at": datetime(2020, 1, 1, tzinfo=timezone.utc), # Trying to change created_at "friendly_name": "Valid Update", # Valid field to update } @@ -511,7 +519,7 @@ def test_update_cycle_immutable_fields(session: Session): ) -def test_create_user(session: Session): +def test_create_user(session: Session) -> None: user_data = { "username": "newuser", "password_hash": "newhashedpassword", @@ -528,7 +536,7 @@ def test_create_user(session: Session): assert actual_user.password_hash == user_data["password_hash"] -def test_update_user(session: Session): +def test_update_user(session: Session) -> None: user_id = make_user(session, username="updatableuser") update_data = { @@ -545,14 +553,14 @@ def test_update_user(session: Session): assert actual_user.password_hash == update_data["password_hash"] -def test_update_user_immutable_fields(session: Session): +def test_update_user_immutable_fields(session: Session) -> None: user_id = make_user(session, username="immutableuser") # Attempt to update immutable fields update_data = { "id": user_id + 1, # Trying to change the ID "username": "newusername", # Trying to change the username - "created_at": datetime(2020, 1, 1), # Trying to change created_at + "created_at": datetime(2020, 1, 1, tzinfo=timezone.utc), # Trying to change created_at "password_hash": "validupdate", # Valid field to update } @@ -566,7 +574,7 @@ def test_update_user_immutable_fields(session: Session): # login sessions -def test_create_login_session(session: Session): +def test_create_login_session(session: Session) -> None: user_id = make_user(session, username="testuser") session_token_hash = "sessiontokenhashed" login_session = crud.create_login_session(session, user_id, session_token_hash) @@ -575,7 +583,7 @@ def test_create_login_session(session: Session): assert login_session.session_token_hash == session_token_hash -def test_create_login_session_with_invalid_user(session: Session): +def test_create_login_session_with_invalid_user(session: Session) -> None: invalid_user_id = 9999 # Assuming this user ID does not exist session_token_hash = "sessiontokenhashed" with pytest.raises(ValueError) as excinfo: @@ -583,40 +591,34 @@ def test_create_login_session_with_invalid_user(session: Session): assert "user_id does not exist" in str(excinfo.value) -def test_get_login_session_by_token_and_user_id(session: Session): - now = datetime.now() +def test_get_login_session_by_token_and_user_id(session: Session) -> None: + now = datetime.now(timezone.utc) created_session = make_login_session(session, now) - fetched_session = crud.get_login_session_by_token_hash_and_user_id( - session, created_session.session_token_hash, created_session.user_id - ) + fetched_session = crud.get_login_session_by_token_hash_and_user_id(session, created_session.session_token_hash, created_session.user_id) assert fetched_session is not None assert fetched_session.id == created_session.id assert fetched_session.user_id == created_session.user_id assert fetched_session.session_token_hash == created_session.session_token_hash -def test_update_login_session(session: Session): - now = datetime.now() +def test_update_login_session(session: Session) -> None: + now = datetime.now(timezone.utc) created_session = make_login_session(session, now) update_data = { "last_seen_at": now + timedelta(hours=1), "last_used_ip": "192.168.1.1", } - updated_session = crud.update_login_session( - session, created_session.session_token_hash, update_data - ) + updated_session = crud.update_login_session(session, created_session.session_token_hash, update_data) assert updated_session is not None - assert updated_session.last_seen_at == update_data["last_seen_at"] + assert _ensure_utc_aware(updated_session.last_seen_at) == update_data["last_seen_at"] assert updated_session.last_used_ip == update_data["last_used_ip"] -def test_delete_login_session(session: Session): - now = datetime.now() +def test_delete_login_session(session: Session) -> None: + now = datetime.now(timezone.utc) created_session = make_login_session(session, now) crud.delete_login_session(session, created_session.session_token_hash) - deleted_session = crud.get_login_session_by_token_hash_and_user_id( - session, created_session.session_token_hash, created_session.user_id - ) + deleted_session = crud.get_login_session_by_token_hash_and_user_id(session, created_session.session_token_hash, created_session.user_id) assert deleted_session is None diff --git a/backend/tests/test_db.py b/backend/tests/test_db.py index 4275b1c..e177600 100644 --- a/backend/tests/test_db.py +++ b/backend/tests/test_db.py @@ -46,9 +46,8 @@ def database_ctx(db: Database) -> Generator[Database, None, None]: def test_select_one_executes() -> None: db = create_database(None) # in-memory by default - with database_ctx(db): - with session_ctx(db) as session: - val = session.exec(text("SELECT 1")).scalar_one() + with database_ctx(db), session_ctx(db) as session: + val = session.exec(text("SELECT 1")).scalar_one() assert int(val) == 1 @@ -56,9 +55,7 @@ def test_in_memory_persists_across_sessions_when_using_staticpool() -> None: db = create_database(None) # in-memory with StaticPool with database_ctx(db): with session_ctx(db) as s1: - s1.exec( - text("CREATE TABLE IF NOT EXISTS t (id INTEGER PRIMARY KEY, val TEXT);") - ) + s1.exec(text("CREATE TABLE IF NOT EXISTS t (id INTEGER PRIMARY KEY, val TEXT);")) s1.exec(text("INSERT INTO t (val) VALUES (:v)").bindparams(v="hello")) with session_ctx(db) as s2: got = s2.exec(text("SELECT val FROM t")).scalar_one() @@ -67,10 +64,9 @@ def test_in_memory_persists_across_sessions_when_using_staticpool() -> None: def test_sqlite_pragmas_applied() -> None: db = create_database(None) - with database_ctx(db): + with database_ctx(db), session_ctx(db) as session: # PRAGMA returns integer 1 when foreign_keys ON - with session_ctx(db) as session: - fk = session.exec(text("PRAGMA foreign_keys")).scalar_one() + fk = session.exec(text("PRAGMA foreign_keys")).scalar_one() assert int(fk) == 1 @@ -82,16 +78,8 @@ def test_rollback_on_exception() -> None: # Create table then insert and raise inside the same session to force rollback with pytest.raises(RuntimeError): # noqa: PT012, SIM117 with session_ctx(db) as s: - s.exec( - text( - "CREATE TABLE IF NOT EXISTS t_rb (id INTEGER PRIMARY KEY, val TEXT);" - ) - ) - s.exec( - text("INSERT INTO t_rb (val) VALUES (:v)").bindparams( - v="will_rollback" - ) - ) + s.exec(text("CREATE TABLE IF NOT EXISTS t_rb (id INTEGER PRIMARY KEY, val TEXT);")) + s.exec(text("INSERT INTO t_rb (val) VALUES (:v)").bindparams(v="will_rollback")) # simulate handler error -> should trigger rollback in get_session raise RuntimeError("simulated failure") diff --git a/backend/tests/test_db_migration.py b/backend/tests/test_db_migration.py index d274e6a..655cf54 100644 --- a/backend/tests/test_db_migration.py +++ b/backend/tests/test_db_migration.py @@ -89,12 +89,10 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None: with engine.connect() as conn: # check tables exist rows = conn.execute( - text("SELECT name FROM sqlite_master WHERE type='table'") + text("SELECT name FROM sqlite_master WHERE type='table'"), ).fetchall() found_tables = {r[0] for r in rows} - assert set(expected_schema.keys()).issubset(found_tables), ( - f"missing tables: {set(expected_schema.keys()) - found_tables}" - ) + assert set(expected_schema.keys()).issubset(found_tables), f"missing tables: {set(expected_schema.keys()) - found_tables}" # check user_version uv = conn.execute(text("PRAGMA user_version")).fetchone() @@ -103,14 +101,9 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None: # validate each table columns for tbl_name, cols in expected_schema.items(): - info_rows = conn.execute( - text(f"PRAGMA table_info({tbl_name})") - ).fetchall() + info_rows = conn.execute(text(f"PRAGMA table_info({tbl_name})")).fetchall() # map: name -> (type, notnull, pk) - actual = { - r[1]: ((r[2] or "").upper(), int(r[3]), int(r[5])) - for r in info_rows - } + actual = {r[1]: ((r[2] or "").upper(), int(r[3]), int(r[5])) for r in info_rows} for colname, (exp_type, exp_notnull, exp_pk) in cols.items(): assert colname in actual, f"{tbl_name}: missing column {colname}" act_type, act_notnull, act_pk = actual[colname] @@ -122,20 +115,12 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None: assert exp_type in act_base or act_base in exp_type, ( f"type mismatch {tbl_name}.{colname}: expected {exp_type}, got {act_base}" ) - assert act_notnull == exp_notnull, ( - f"notnull mismatch {tbl_name}.{colname}: expected {exp_notnull}, got {act_notnull}" - ) - assert act_pk == exp_pk, ( - f"pk mismatch {tbl_name}.{colname}: expected {exp_pk}, got {act_pk}" - ) + assert act_notnull == exp_notnull, f"notnull mismatch {tbl_name}.{colname}: expected {exp_notnull}, got {act_notnull}" + assert act_pk == exp_pk, f"pk mismatch {tbl_name}.{colname}: expected {exp_pk}, got {act_pk}" for tbl_name, fks in expected_fks.items(): - fk_rows = conn.execute( - text(f"PRAGMA foreign_key_list('{tbl_name}')") - ).fetchall() + fk_rows = conn.execute(text(f"PRAGMA foreign_key_list('{tbl_name}')")).fetchall() # fk_rows columns: (id, seq, table, from, to, on_update, on_delete, match) - actual_fk_list = [ - {"table": r[2], "from": r[3], "to": r[4]} for r in fk_rows - ] + actual_fk_list = [{"table": r[2], "from": r[3], "to": r[4]} for r in fk_rows] for efk in fks: assert efk in actual_fk_list, f"missing FK on {tbl_name}: {efk}" finally: diff --git a/backend/tests/test_settings.py b/backend/tests/test_settings.py index 499b703..4866ea1 100644 --- a/backend/tests/test_settings.py +++ b/backend/tests/test_settings.py @@ -12,7 +12,7 @@ def test_default_settings(monkeypatch: pytest.MonkeyPatch) -> None: s = load_settings() assert s.host == "0.0.0.0" # noqa: S104 - assert s.port == 8000 # noqa: PLR2004 + assert s.port == 8000 assert s.workers == 1 assert s.log_level == "info" @@ -26,8 +26,8 @@ def test_env_overrides(monkeypatch: pytest.MonkeyPatch) -> None: s = load_settings() assert s.host == "127.0.0.1" - assert s.port == 9000 # noqa: PLR2004 - assert s.workers == 3 # noqa: PLR2004 + assert s.port == 9000 + assert s.workers == 3 assert s.log_level == "debug" @@ -40,6 +40,6 @@ def test_yaml_config_file(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> No s = load_settings() assert s.host == "10.0.0.5" - assert s.port == 8088 # noqa: PLR2004 - assert s.workers == 5 # noqa: PLR2004 + assert s.port == 8088 + assert s.workers == 5 assert s.log_level == "debug" diff --git a/backend/trading_journal/crud.py b/backend/trading_journal/crud.py index 5f051d8..0e83bed 100644 --- a/backend/trading_journal/crud.py +++ b/backend/trading_journal/crud.py @@ -1,13 +1,18 @@ +from __future__ import annotations + from datetime import datetime, timedelta, timezone -from typing import Mapping +from typing import TYPE_CHECKING from sqlalchemy.exc import IntegrityError from sqlmodel import Session, select from trading_journal import models +if TYPE_CHECKING: + from collections.abc import Mapping -def _check_enum(enum_cls, value, field_name: str): + +def _check_enum(enum_cls: any, value: any, field_name: str) -> any: if value is None: raise ValueError(f"{field_name} is required") # already an enum member @@ -34,19 +39,13 @@ def create_trade(session: Session, trade_data: Mapping) -> models.Trades: raise ValueError("symbol is required") if "underlying_currency" not in payload: raise ValueError("underlying_currency is required") - payload["underlying_currency"] = _check_enum( - models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency" - ) + payload["underlying_currency"] = _check_enum(models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency") if "trade_type" not in payload: raise ValueError("trade_type is required") - payload["trade_type"] = _check_enum( - models.TradeType, payload["trade_type"], "trade_type" - ) + payload["trade_type"] = _check_enum(models.TradeType, payload["trade_type"], "trade_type") if "trade_strategy" not in payload: raise ValueError("trade_strategy is required") - payload["trade_strategy"] = _check_enum( - models.TradeStrategy, payload["trade_strategy"], "trade_strategy" - ) + payload["trade_strategy"] = _check_enum(models.TradeStrategy, payload["trade_strategy"], "trade_strategy") # trade_time_utc is the creation moment: always set to now (caller shouldn't provide) now = datetime.now(timezone.utc) payload.pop("trade_time_utc", None) @@ -67,9 +66,7 @@ def create_trade(session: Session, trade_data: Mapping) -> models.Trades: if "gross_cash_flow_cents" not in payload: payload["gross_cash_flow_cents"] = -quantity * price_cents if "net_cash_flow_cents" not in payload: - payload["net_cash_flow_cents"] = ( - payload["gross_cash_flow_cents"] - commission_cents - ) + payload["net_cash_flow_cents"] = payload["gross_cash_flow_cents"] - commission_cents # If no cycle_id provided, create Cycle instance but don't call create_cycle() created_cycle = None @@ -78,8 +75,7 @@ def create_trade(session: Session, trade_data: Mapping) -> models.Trades: "user_id": user_id, "symbol": payload["symbol"], "underlying_currency": payload["underlying_currency"], - "friendly_name": "Auto-created Cycle by trade " - + payload.get("friendly_name", ""), + "friendly_name": "Auto-created Cycle by trade " + payload.get("friendly_name", ""), "status": models.CycleStatus.OPEN, "start_date": payload["trade_date"], } @@ -92,9 +88,8 @@ def create_trade(session: Session, trade_data: Mapping) -> models.Trades: cycle = session.get(models.Cycles, cycle_id) if cycle is None: raise ValueError("cycle_id does not exist") - else: - if cycle.user_id != user_id: - raise ValueError("cycle.user_id does not match trade.user_id") + if cycle.user_id != user_id: + raise ValueError("cycle.user_id does not match trade.user_id") # Build trade instance; if we created a Cycle instance, link via relationship so a single flush will persist both and populate ids t_payload = dict(payload) @@ -119,9 +114,7 @@ def get_trade_by_id(session: Session, trade_id: int) -> models.Trades | None: return session.get(models.Trades, trade_id) -def get_trade_by_user_id_and_friendly_name( - session: Session, user_id: int, friendly_name: str -) -> models.Trades | None: +def get_trade_by_user_id_and_friendly_name(session: Session, user_id: int, friendly_name: str) -> models.Trades | None: statement = select(models.Trades).where( models.Trades.user_id == user_id, models.Trades.friendly_name == friendly_name, @@ -169,17 +162,14 @@ def invalidate_trade(session: Session, trade_id: int) -> models.Trades: return trade -def replace_trade( - session: Session, old_trade_id: int, new_trade_data: Mapping -) -> models.Trades: +def replace_trade(session: Session, old_trade_id: int, new_trade_data: Mapping) -> models.Trades: invalidate_trade(session, old_trade_id) if hasattr(new_trade_data, "dict"): data = new_trade_data.dict(exclude_unset=True) else: data = dict(new_trade_data) data["replaced_by_trade_id"] = old_trade_id - new_trade = create_trade(session, data) - return new_trade + return create_trade(session, data) # Cycles @@ -196,9 +186,7 @@ def create_cycle(session: Session, cycle_data: Mapping) -> models.Cycles: raise ValueError("symbol is required") if "underlying_currency" not in payload: raise ValueError("underlying_currency is required") - payload["underlying_currency"] = _check_enum( - models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency" - ) + payload["underlying_currency"] = _check_enum(models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency") if "status" not in payload: raise ValueError("status is required") payload["status"] = _check_enum(models.CycleStatus, payload["status"], "status") @@ -219,9 +207,7 @@ def create_cycle(session: Session, cycle_data: Mapping) -> models.Cycles: IMMUTABLE_CYCLE_FIELDS = {"id", "user_id", "start_date", "created_at"} -def update_cycle( - session: Session, cycle_id: int, update_data: Mapping -) -> models.Cycles: +def update_cycle(session: Session, cycle_id: int, update_data: Mapping) -> models.Cycles: cycle: models.Cycles | None = session.get(models.Cycles, cycle_id) if cycle is None: raise ValueError("cycle_id does not exist") @@ -237,9 +223,9 @@ def update_cycle( if k not in allowed: continue if k == "underlying_currency": - v = _check_enum(models.UnderlyingCurrency, v, "underlying_currency") + v = _check_enum(models.UnderlyingCurrency, v, "underlying_currency") # noqa: PLW2901 if k == "status": - v = _check_enum(models.CycleStatus, v, "status") + v = _check_enum(models.CycleStatus, v, "status") # noqa: PLW2901 setattr(cycle, k, v) session.add(cycle) try: @@ -337,9 +323,7 @@ def create_login_session( return s -def get_login_session_by_token_hash_and_user_id( - session: Session, session_token_hash: str, user_id: int -) -> models.Sessions | None: +def get_login_session_by_token_hash_and_user_id(session: Session, session_token_hash: str, user_id: int) -> models.Sessions | None: statement = select(models.Sessions).where( models.Sessions.session_token_hash == session_token_hash, models.Sessions.user_id == user_id, @@ -352,14 +336,12 @@ def get_login_session_by_token_hash_and_user_id( IMMUTABLE_SESSION_FIELDS = {"id", "user_id", "session_token_hash", "created_at"} -def update_login_session( - session: Session, session_token_hashed: str, update_session: Mapping -) -> models.Sessions | None: +def update_login_session(session: Session, session_token_hashed: str, update_session: Mapping) -> models.Sessions | None: login_session: models.Sessions | None = session.exec( select(models.Sessions).where( models.Sessions.session_token_hash == session_token_hashed, models.Sessions.expires_at > datetime.now(timezone.utc), - ) + ), ).first() if login_session is None: return None @@ -385,7 +367,7 @@ def delete_login_session(session: Session, session_token_hash: str) -> None: login_session: models.Sessions | None = session.exec( select(models.Sessions).where( models.Sessions.session_token_hash == session_token_hash, - ) + ), ).first() if login_session is None: return diff --git a/backend/trading_journal/db.py b/backend/trading_journal/db.py index d09a53d..d952a88 100644 --- a/backend/trading_journal/db.py +++ b/backend/trading_journal/db.py @@ -24,17 +24,13 @@ class Database: ) -> None: self._database_url = database_url or "sqlite:///:memory:" - default_connect = ( - {"check_same_thread": False, "timeout": 30} - if self._database_url.startswith("sqlite") - else {} - ) + default_connect = {"check_same_thread": False, "timeout": 30} if self._database_url.startswith("sqlite") else {} merged_connect = {**default_connect, **(connect_args or {})} if self._database_url == "sqlite:///:memory:": logger = logging.getLogger(__name__) logger.warning( - "Using in-memory SQLite database; all data will be lost when the application stops." + "Using in-memory SQLite database; all data will be lost when the application stops.", ) self._engine = create_engine( self._database_url, @@ -43,15 +39,11 @@ class Database: poolclass=StaticPool, ) else: - self._engine = create_engine( - self._database_url, echo=echo, connect_args=merged_connect - ) + self._engine = create_engine(self._database_url, echo=echo, connect_args=merged_connect) if self._database_url.startswith("sqlite"): - def _enable_sqlite_pragmas( - dbapi_conn: DBAPIConnection, _connection_record: object - ) -> None: + def _enable_sqlite_pragmas(dbapi_conn: DBAPIConnection, _connection_record: object) -> None: try: cur = dbapi_conn.cursor() cur.execute("PRAGMA journal_mode=WAL;") diff --git a/backend/trading_journal/db_migration.py b/backend/trading_journal/db_migration.py index d55a6a0..2a57464 100644 --- a/backend/trading_journal/db_migration.py +++ b/backend/trading_journal/db_migration.py @@ -60,7 +60,7 @@ def run_migrations(engine: Engine, target_version: int | None = None) -> int: fn = MIGRATIONS.get(cur_version) if fn is None: raise RuntimeError( - f"No migration from {cur_version} -> {cur_version + 1}" + f"No migration from {cur_version} -> {cur_version + 1}", ) # call migration with Engine (fn should use transactions) fn(engine) diff --git a/backend/trading_journal/models.py b/backend/trading_journal/models.py index 17de397..259ea3d 100644 --- a/backend/trading_journal/models.py +++ b/backend/trading_journal/models.py @@ -1,4 +1,4 @@ -from datetime import date, datetime # noqa: TC003 +from datetime import date, datetime from enum import Enum from sqlmodel import ( @@ -65,28 +65,18 @@ class FundingSource(str, Enum): class Trades(SQLModel, table=True): __tablename__ = "trades" - __table_args__ = ( - UniqueConstraint( - "user_id", "friendly_name", name="uq_trades_user_friendly_name" - ), - ) + __table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_trades_user_friendly_name"),) id: int | None = Field(default=None, primary_key=True) user_id: int = Field(foreign_key="users.id", nullable=False, index=True) # allow null while user may omit friendly_name; uniqueness enforced per-user by constraint - friendly_name: str | None = Field( - default=None, sa_column=Column(Text, nullable=True) - ) + friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) symbol: str = Field(sa_column=Column(Text, nullable=False)) - underlying_currency: UnderlyingCurrency = Field( - sa_column=Column(Text, nullable=False) - ) + underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False)) trade_type: TradeType = Field(sa_column=Column(Text, nullable=False)) trade_strategy: TradeStrategy = Field(sa_column=Column(Text, nullable=False)) trade_date: date = Field(sa_column=Column(Date, nullable=False)) - trade_time_utc: datetime = Field( - sa_column=Column(DateTime(timezone=True), nullable=False) - ) + trade_time_utc: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False)) expiry_date: date | None = Field(default=None, nullable=True) strike_price_cents: int | None = Field(default=None, nullable=True) quantity: int = Field(sa_column=Column(Integer, nullable=False)) @@ -95,36 +85,22 @@ class Trades(SQLModel, table=True): commission_cents: int = Field(sa_column=Column(Integer, nullable=False)) net_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False)) is_invalidated: bool = Field(default=False, nullable=False) - invalidated_at: datetime | None = Field( - default=None, sa_column=Column(DateTime(timezone=True), nullable=True) - ) - replaced_by_trade_id: int | None = Field( - default=None, foreign_key="trades.id", nullable=True - ) + invalidated_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True), nullable=True)) + replaced_by_trade_id: int | None = Field(default=None, foreign_key="trades.id", nullable=True) notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) - cycle_id: int | None = Field( - default=None, foreign_key="cycles.id", nullable=True, index=True - ) + cycle_id: int | None = Field(default=None, foreign_key="cycles.id", nullable=True, index=True) cycle: "Cycles" = Relationship(back_populates="trades") class Cycles(SQLModel, table=True): __tablename__ = "cycles" - __table_args__ = ( - UniqueConstraint( - "user_id", "friendly_name", name="uq_cycles_user_friendly_name" - ), - ) + __table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_cycles_user_friendly_name"),) id: int | None = Field(default=None, primary_key=True) user_id: int = Field(foreign_key="users.id", nullable=False, index=True) - friendly_name: str | None = Field( - default=None, sa_column=Column(Text, nullable=True) - ) + friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) symbol: str = Field(sa_column=Column(Text, nullable=False)) - underlying_currency: UnderlyingCurrency = Field( - sa_column=Column(Text, nullable=False) - ) + underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False)) status: CycleStatus = Field(sa_column=Column(Text, nullable=False)) funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True)) capital_exposure_cents: int | None = Field(default=None, nullable=True) @@ -149,17 +125,9 @@ class Sessions(SQLModel, table=True): id: int | None = Field(default=None, primary_key=True) user_id: int = Field(foreign_key="users.id", nullable=False, index=True) session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True)) - created_at: datetime = Field( - sa_column=Column(DateTime(timezone=True), nullable=False) - ) - expires_at: datetime = Field( - sa_column=Column(DateTime(timezone=True), nullable=False, index=True) - ) - last_seen_at: datetime | None = Field( - sa_column=Column(DateTime(timezone=True), nullable=True) - ) - last_used_ip: str | None = Field( - default=None, sa_column=Column(Text, nullable=True) - ) + created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False)) + expires_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False, index=True)) + last_seen_at: datetime | None = Field(sa_column=Column(DateTime(timezone=True), nullable=True)) + last_used_ip: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) diff --git a/backend/trading_journal/models_v1.py b/backend/trading_journal/models_v1.py index 17de397..259ea3d 100644 --- a/backend/trading_journal/models_v1.py +++ b/backend/trading_journal/models_v1.py @@ -1,4 +1,4 @@ -from datetime import date, datetime # noqa: TC003 +from datetime import date, datetime from enum import Enum from sqlmodel import ( @@ -65,28 +65,18 @@ class FundingSource(str, Enum): class Trades(SQLModel, table=True): __tablename__ = "trades" - __table_args__ = ( - UniqueConstraint( - "user_id", "friendly_name", name="uq_trades_user_friendly_name" - ), - ) + __table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_trades_user_friendly_name"),) id: int | None = Field(default=None, primary_key=True) user_id: int = Field(foreign_key="users.id", nullable=False, index=True) # allow null while user may omit friendly_name; uniqueness enforced per-user by constraint - friendly_name: str | None = Field( - default=None, sa_column=Column(Text, nullable=True) - ) + friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) symbol: str = Field(sa_column=Column(Text, nullable=False)) - underlying_currency: UnderlyingCurrency = Field( - sa_column=Column(Text, nullable=False) - ) + underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False)) trade_type: TradeType = Field(sa_column=Column(Text, nullable=False)) trade_strategy: TradeStrategy = Field(sa_column=Column(Text, nullable=False)) trade_date: date = Field(sa_column=Column(Date, nullable=False)) - trade_time_utc: datetime = Field( - sa_column=Column(DateTime(timezone=True), nullable=False) - ) + trade_time_utc: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False)) expiry_date: date | None = Field(default=None, nullable=True) strike_price_cents: int | None = Field(default=None, nullable=True) quantity: int = Field(sa_column=Column(Integer, nullable=False)) @@ -95,36 +85,22 @@ class Trades(SQLModel, table=True): commission_cents: int = Field(sa_column=Column(Integer, nullable=False)) net_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False)) is_invalidated: bool = Field(default=False, nullable=False) - invalidated_at: datetime | None = Field( - default=None, sa_column=Column(DateTime(timezone=True), nullable=True) - ) - replaced_by_trade_id: int | None = Field( - default=None, foreign_key="trades.id", nullable=True - ) + invalidated_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True), nullable=True)) + replaced_by_trade_id: int | None = Field(default=None, foreign_key="trades.id", nullable=True) notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) - cycle_id: int | None = Field( - default=None, foreign_key="cycles.id", nullable=True, index=True - ) + cycle_id: int | None = Field(default=None, foreign_key="cycles.id", nullable=True, index=True) cycle: "Cycles" = Relationship(back_populates="trades") class Cycles(SQLModel, table=True): __tablename__ = "cycles" - __table_args__ = ( - UniqueConstraint( - "user_id", "friendly_name", name="uq_cycles_user_friendly_name" - ), - ) + __table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_cycles_user_friendly_name"),) id: int | None = Field(default=None, primary_key=True) user_id: int = Field(foreign_key="users.id", nullable=False, index=True) - friendly_name: str | None = Field( - default=None, sa_column=Column(Text, nullable=True) - ) + friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) symbol: str = Field(sa_column=Column(Text, nullable=False)) - underlying_currency: UnderlyingCurrency = Field( - sa_column=Column(Text, nullable=False) - ) + underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False)) status: CycleStatus = Field(sa_column=Column(Text, nullable=False)) funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True)) capital_exposure_cents: int | None = Field(default=None, nullable=True) @@ -149,17 +125,9 @@ class Sessions(SQLModel, table=True): id: int | None = Field(default=None, primary_key=True) user_id: int = Field(foreign_key="users.id", nullable=False, index=True) session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True)) - created_at: datetime = Field( - sa_column=Column(DateTime(timezone=True), nullable=False) - ) - expires_at: datetime = Field( - sa_column=Column(DateTime(timezone=True), nullable=False, index=True) - ) - last_seen_at: datetime | None = Field( - sa_column=Column(DateTime(timezone=True), nullable=True) - ) - last_used_ip: str | None = Field( - default=None, sa_column=Column(Text, nullable=True) - ) + created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False)) + expires_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False, index=True)) + last_seen_at: datetime | None = Field(sa_column=Column(DateTime(timezone=True), nullable=True)) + last_used_ip: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) -- 2.49.1 From 76cc967c42c38ce8ec852e8ca4c927427066cfd8 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Fri, 19 Sep 2025 23:04:17 +0200 Subject: [PATCH 02/18] cycle and trade add exchange field --- backend/tests/test_crud.py | 18 ++++++++++++++++++ backend/trading_journal/crud.py | 5 +++++ backend/trading_journal/dto.py | 21 +++++++++++++++++++++ backend/trading_journal/models.py | 2 ++ backend/trading_journal/models_v1.py | 2 ++ 5 files changed, 48 insertions(+) create mode 100644 backend/trading_journal/dto.py diff --git a/backend/tests/test_crud.py b/backend/tests/test_crud.py index e951665..2ba27f7 100644 --- a/backend/tests/test_crud.py +++ b/backend/tests/test_crud.py @@ -50,6 +50,7 @@ def make_cycle(session: Session, user_id: int, friendly_name: str = "Test Cycle" user_id=user_id, friendly_name=friendly_name, symbol="AAPL", + exchange="NASDAQ", underlying_currency=models.UnderlyingCurrency.USD, status=models.CycleStatus.OPEN, start_date=datetime.now(timezone.utc).date(), @@ -65,6 +66,7 @@ def make_trade(session: Session, user_id: int, cycle_id: int, friendly_name: str user_id=user_id, friendly_name=friendly_name, symbol="AAPL", + exchange="NASDAQ", underlying_currency=models.UnderlyingCurrency.USD, trade_type=models.TradeType.LONG_SPOT, trade_strategy=models.TradeStrategy.SPOT, @@ -129,6 +131,7 @@ def test_create_trade_success_with_cycle(session: Session) -> None: "user_id": user_id, "friendly_name": "Test Trade", "symbol": "AAPL", + "exchange": "NASDAQ", "underlying_currency": models.UnderlyingCurrency.USD, "trade_type": models.TradeType.LONG_SPOT, "trade_strategy": models.TradeStrategy.SPOT, @@ -169,6 +172,7 @@ def test_create_trade_with_auto_created_cycle(session: Session) -> None: "user_id": user_id, "friendly_name": "Test Trade with Auto Cycle", "symbol": "AAPL", + "exchange": "NASDAQ", "underlying_currency": models.UnderlyingCurrency.USD, "trade_type": models.TradeType.LONG_SPOT, "trade_strategy": models.TradeStrategy.SPOT, @@ -211,6 +215,7 @@ def test_create_trade_missing_required_fields(session: Session) -> None: "user_id": user_id, "friendly_name": "Incomplete Trade", "symbol": "AAPL", + "exchange": "NASDAQ", "underlying_currency": models.UnderlyingCurrency.USD, "trade_type": models.TradeType.LONG_SPOT, "trade_strategy": models.TradeStrategy.SPOT, @@ -226,6 +231,13 @@ def test_create_trade_missing_required_fields(session: Session) -> None: crud.create_trade(session, trade_data) assert "symbol is required" in str(excinfo.value) + # Missing exchange + trade_data = base_trade_data.copy() + trade_data.pop("exchange", None) + with pytest.raises(ValueError) as excinfo: + crud.create_trade(session, trade_data) + assert "exchange is required" in str(excinfo.value) + # Missing underlying_currency trade_data = base_trade_data.copy() trade_data.pop("underlying_currency", None) @@ -269,6 +281,7 @@ def test_get_trade_by_id(session: Session) -> None: "user_id": user_id, "friendly_name": "Test Trade for Get", "symbol": "AAPL", + "exchange": "NASDAQ", "underlying_currency": models.UnderlyingCurrency.USD, "trade_type": models.TradeType.LONG_SPOT, "trade_strategy": models.TradeStrategy.SPOT, @@ -307,6 +320,7 @@ def test_get_trade_by_user_id_and_friendly_name(session: Session) -> None: "user_id": user_id, "friendly_name": friendly_name, "symbol": "AAPL", + "exchange": "NASDAQ", "underlying_currency": models.UnderlyingCurrency.USD, "trade_type": models.TradeType.LONG_SPOT, "trade_strategy": models.TradeStrategy.SPOT, @@ -333,6 +347,7 @@ def test_get_trades_by_user_id(session: Session) -> None: "user_id": user_id, "friendly_name": "Trade One", "symbol": "AAPL", + "exchange": "NASDAQ", "underlying_currency": models.UnderlyingCurrency.USD, "trade_type": models.TradeType.LONG_SPOT, "trade_strategy": models.TradeStrategy.SPOT, @@ -349,6 +364,7 @@ def test_get_trades_by_user_id(session: Session) -> None: "user_id": user_id, "friendly_name": "Trade Two", "symbol": "GOOGL", + "exchange": "NASDAQ", "underlying_currency": models.UnderlyingCurrency.USD, "trade_type": models.TradeType.SHORT_SPOT, "trade_strategy": models.TradeStrategy.SPOT, @@ -412,6 +428,7 @@ def test_replace_trade(session: Session) -> None: "user_id": user_id, "friendly_name": "Replaced Trade", "symbol": "MSFT", + "exchange": "NASDAQ", "underlying_currency": models.UnderlyingCurrency.USD, "trade_type": models.TradeType.LONG_SPOT, "trade_strategy": models.TradeStrategy.SPOT, @@ -452,6 +469,7 @@ def test_create_cycle(session: Session) -> None: "user_id": user_id, "friendly_name": "My First Cycle", "symbol": "GOOGL", + "exchange": "NASDAQ", "underlying_currency": models.UnderlyingCurrency.USD, "status": models.CycleStatus.OPEN, "start_date": datetime.now(timezone.utc).date(), diff --git a/backend/trading_journal/crud.py b/backend/trading_journal/crud.py index 0e83bed..37973ce 100644 --- a/backend/trading_journal/crud.py +++ b/backend/trading_journal/crud.py @@ -37,6 +37,8 @@ def create_trade(session: Session, trade_data: Mapping) -> models.Trades: payload = {k: v for k, v in data.items() if k in allowed} if "symbol" not in payload: raise ValueError("symbol is required") + if "exchange" not in payload: + raise ValueError("exchange is required") if "underlying_currency" not in payload: raise ValueError("underlying_currency is required") payload["underlying_currency"] = _check_enum(models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency") @@ -74,6 +76,7 @@ def create_trade(session: Session, trade_data: Mapping) -> models.Trades: c_payload = { "user_id": user_id, "symbol": payload["symbol"], + "exchange": payload["exchange"], "underlying_currency": payload["underlying_currency"], "friendly_name": "Auto-created Cycle by trade " + payload.get("friendly_name", ""), "status": models.CycleStatus.OPEN, @@ -184,6 +187,8 @@ def create_cycle(session: Session, cycle_data: Mapping) -> models.Cycles: raise ValueError("user_id is required") if "symbol" not in payload: raise ValueError("symbol is required") + if "exchange" not in payload: + raise ValueError("exchange is required") if "underlying_currency" not in payload: raise ValueError("underlying_currency is required") payload["underlying_currency"] = _check_enum(models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency") diff --git a/backend/trading_journal/dto.py b/backend/trading_journal/dto.py new file mode 100644 index 0000000..b372474 --- /dev/null +++ b/backend/trading_journal/dto.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from sqlmodel import SQLModel + +if TYPE_CHECKING: + from datetime import date, datetime + + from trading_journal.models import TradeStrategy, TradeType, UnderlyingCurrency + + +class TradeBase(SQLModel): + user_id: int + friendly_name: str | None + symbol: str + underlying_currency: UnderlyingCurrency + trade_type: TradeType + trade_strategy: TradeStrategy + trade_date: date + trade_time_utc: datetime diff --git a/backend/trading_journal/models.py b/backend/trading_journal/models.py index 259ea3d..6659361 100644 --- a/backend/trading_journal/models.py +++ b/backend/trading_journal/models.py @@ -72,6 +72,7 @@ class Trades(SQLModel, table=True): # allow null while user may omit friendly_name; uniqueness enforced per-user by constraint friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) symbol: str = Field(sa_column=Column(Text, nullable=False)) + exchange: str = Field(sa_column=Column(Text, nullable=False)) underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False)) trade_type: TradeType = Field(sa_column=Column(Text, nullable=False)) trade_strategy: TradeStrategy = Field(sa_column=Column(Text, nullable=False)) @@ -100,6 +101,7 @@ class Cycles(SQLModel, table=True): user_id: int = Field(foreign_key="users.id", nullable=False, index=True) friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) symbol: str = Field(sa_column=Column(Text, nullable=False)) + exchange: str = Field(sa_column=Column(Text, nullable=False)) underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False)) status: CycleStatus = Field(sa_column=Column(Text, nullable=False)) funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True)) diff --git a/backend/trading_journal/models_v1.py b/backend/trading_journal/models_v1.py index 259ea3d..6659361 100644 --- a/backend/trading_journal/models_v1.py +++ b/backend/trading_journal/models_v1.py @@ -72,6 +72,7 @@ class Trades(SQLModel, table=True): # allow null while user may omit friendly_name; uniqueness enforced per-user by constraint friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) symbol: str = Field(sa_column=Column(Text, nullable=False)) + exchange: str = Field(sa_column=Column(Text, nullable=False)) underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False)) trade_type: TradeType = Field(sa_column=Column(Text, nullable=False)) trade_strategy: TradeStrategy = Field(sa_column=Column(Text, nullable=False)) @@ -100,6 +101,7 @@ class Cycles(SQLModel, table=True): user_id: int = Field(foreign_key="users.id", nullable=False, index=True) friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) symbol: str = Field(sa_column=Column(Text, nullable=False)) + exchange: str = Field(sa_column=Column(Text, nullable=False)) underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False)) status: CycleStatus = Field(sa_column=Column(Text, nullable=False)) funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True)) -- 2.49.1 From 1fbc93353d2cf3175989666614dd6029bea69ce3 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Mon, 22 Sep 2025 14:33:32 +0200 Subject: [PATCH 03/18] add exchange table --- backend/app.py | 47 +++++----- backend/dev-requirements.txt | 126 +++++++++++++++++++++++++++ backend/models.py | 7 -- backend/requirements.in | 3 +- backend/requirements.txt | 126 +++++++++++++++++++++++++++ backend/settings.py | 1 + backend/tests/test_app.py | 18 ++++ backend/tests/test_crud.py | 76 ++++++++++------ backend/tests/test_db_migration.py | 14 +++ backend/tests/test_main.py | 22 ----- backend/tests/test_security.py | 4 + backend/trading_journal/crud.py | 15 ++-- backend/trading_journal/db.py | 3 +- backend/trading_journal/dto.py | 35 ++++++++ backend/trading_journal/models.py | 15 +++- backend/trading_journal/models_v1.py | 15 +++- backend/trading_journal/security.py | 11 +++ 17 files changed, 446 insertions(+), 92 deletions(-) delete mode 100644 backend/models.py create mode 100644 backend/tests/test_app.py delete mode 100644 backend/tests/test_main.py create mode 100644 backend/tests/test_security.py create mode 100644 backend/trading_journal/security.py diff --git a/backend/app.py b/backend/app.py index 8c0ae4d..812f896 100644 --- a/backend/app.py +++ b/backend/app.py @@ -1,33 +1,30 @@ -from fastapi import FastAPI +import asyncio +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager -from models import MsgPayload +from fastapi import FastAPI, status -app = FastAPI() -messages_list: dict[int, MsgPayload] = {} +import settings +from trading_journal import db +from trading_journal.dto import TradeCreate, TradeRead + +API_BASE = "/api/v1" + +_db = db.create_database(settings.settings.database_url) -@app.get("/") -def root() -> dict[str, str]: - return {"message": "Hello"} +@asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001 + await asyncio.to_thread(_db.init_db) + try: + yield + finally: + await asyncio.to_thread(_db.dispose) -# About page route -@app.get("/about") -def about() -> dict[str, str]: - return {"message": "This is the about page."} +app = FastAPI(lifespan=lifespan) -# Route to add a message -@app.post("/messages/{msg_name}/") -def add_msg(msg_name: str) -> dict[str, MsgPayload]: - # Generate an ID for the item based on the highest ID in the messages_list - msg_id = max(messages_list.keys()) + 1 if messages_list else 0 - messages_list[msg_id] = MsgPayload(msg_id=msg_id, msg_name=msg_name) - - return {"message": messages_list[msg_id]} - - -# Route to list all messages -@app.get("/messages") -def message_items() -> dict[str, dict[int, MsgPayload]]: - return {"messages:": messages_list} +@app.get(f"{API_BASE}/status") +async def get_status() -> dict[str, str]: + return {"status": "ok"} diff --git a/backend/dev-requirements.txt b/backend/dev-requirements.txt index 0cab2f5..aebd637 100644 --- a/backend/dev-requirements.txt +++ b/backend/dev-requirements.txt @@ -14,12 +14,130 @@ anyio==4.10.0 \ # via # httpx # starlette +argon2-cffi==25.1.0 \ + --hash=sha256:694ae5cc8a42f4c4e2bf2ca0e64e51e23a040c6a517a85074683d3959e1346c1 \ + --hash=sha256:fdc8b074db390fccb6eb4a3604ae7231f219aa669a2652e0f20e16ba513d5741 + # via passlib +argon2-cffi-bindings==25.1.0 \ + --hash=sha256:1db89609c06afa1a214a69a462ea741cf735b29a57530478c06eb81dd403de99 \ + --hash=sha256:1e021e87faa76ae0d413b619fe2b65ab9a037f24c60a1e6cc43457ae20de6dc6 \ + --hash=sha256:21378b40e1b8d1655dd5310c84a40fc19a9aa5e6366e835ceb8576bf0fea716d \ + --hash=sha256:2630b6240b495dfab90aebe159ff784d08ea999aa4b0d17efa734055a07d2f44 \ + --hash=sha256:3c6702abc36bf3ccba3f802b799505def420a1b7039862014a65db3205967f5a \ + --hash=sha256:3d3f05610594151994ca9ccb3c771115bdb4daef161976a266f0dd8aa9996b8f \ + --hash=sha256:473bcb5f82924b1becbb637b63303ec8d10e84c8d241119419897a26116515d2 \ + --hash=sha256:5acb4e41090d53f17ca1110c3427f0a130f944b896fc8c83973219c97f57b690 \ + --hash=sha256:5d588dec224e2a83edbdc785a5e6f3c6cd736f46bfd4b441bbb5aa1f5085e584 \ + --hash=sha256:6dca33a9859abf613e22733131fc9194091c1fa7cb3e131c143056b4856aa47e \ + --hash=sha256:7aef0c91e2c0fbca6fc68e7555aa60ef7008a739cbe045541e438373bc54d2b0 \ + --hash=sha256:84a461d4d84ae1295871329b346a97f68eade8c53b6ed9a7ca2d7467f3c8ff6f \ + --hash=sha256:87c33a52407e4c41f3b70a9c2d3f6056d88b10dad7695be708c5021673f55623 \ + --hash=sha256:8b8efee945193e667a396cbc7b4fb7d357297d6234d30a489905d96caabde56b \ + --hash=sha256:a1c70058c6ab1e352304ac7e3b52554daadacd8d453c1752e547c76e9c99ac44 \ + --hash=sha256:a98cd7d17e9f7ce244c0803cad3c23a7d379c301ba618a5fa76a67d116618b98 \ + --hash=sha256:aecba1723ae35330a008418a91ea6cfcedf6d31e5fbaa056a166462ff066d500 \ + --hash=sha256:b0fdbcf513833809c882823f98dc2f931cf659d9a1429616ac3adebb49f5db94 \ + --hash=sha256:b55aec3565b65f56455eebc9b9f34130440404f27fe21c3b375bf1ea4d8fbae6 \ + --hash=sha256:b957f3e6ea4d55d820e40ff76f450952807013d361a65d7f28acc0acbf29229d \ + --hash=sha256:ba92837e4a9aa6a508c8d2d7883ed5a8f6c308c89a4790e1e447a220deb79a85 \ + --hash=sha256:c4f9665de60b1b0e99bcd6be4f17d90339698ce954cfd8d9cf4f91c995165a92 \ + --hash=sha256:c87b72589133f0346a1cb8d5ecca4b933e3c9b64656c9d175270a000e73b288d \ + --hash=sha256:d3e924cfc503018a714f94a49a149fdc0b644eaead5d1f089330399134fa028a \ + --hash=sha256:da0c79c23a63723aa5d782250fbf51b768abca630285262fb5144ba5ae01e520 \ + --hash=sha256:e2fd3bfbff3c5d74fef31a722f729bf93500910db650c925c2d6ef879a7e51cb + # via argon2-cffi certifi==2025.8.3 \ --hash=sha256:e564105f78ded564e3ae7c923924435e1daa7463faeab5bb932bc53ffae63407 \ --hash=sha256:f6c12493cfb1b06ba2ff328595af9350c65d6644968e5d3a2ffd78699af217a5 # via # httpcore # httpx +cffi==2.0.0 \ + --hash=sha256:00bdf7acc5f795150faa6957054fbbca2439db2f775ce831222b66f192f03beb \ + --hash=sha256:07b271772c100085dd28b74fa0cd81c8fb1a3ba18b21e03d7c27f3436a10606b \ + --hash=sha256:087067fa8953339c723661eda6b54bc98c5625757ea62e95eb4898ad5e776e9f \ + --hash=sha256:0a1527a803f0a659de1af2e1fd700213caba79377e27e4693648c2923da066f9 \ + --hash=sha256:0cf2d91ecc3fcc0625c2c530fe004f82c110405f101548512cce44322fa8ac44 \ + --hash=sha256:0f6084a0ea23d05d20c3edcda20c3d006f9b6f3fefeac38f59262e10cef47ee2 \ + --hash=sha256:12873ca6cb9b0f0d3a0da705d6086fe911591737a59f28b7936bdfed27c0d47c \ + --hash=sha256:19f705ada2530c1167abacb171925dd886168931e0a7b78f5bffcae5c6b5be75 \ + --hash=sha256:1cd13c99ce269b3ed80b417dcd591415d3372bcac067009b6e0f59c7d4015e65 \ + --hash=sha256:1e3a615586f05fc4065a8b22b8152f0c1b00cdbc60596d187c2a74f9e3036e4e \ + --hash=sha256:1f72fb8906754ac8a2cc3f9f5aaa298070652a0ffae577e0ea9bd480dc3c931a \ + --hash=sha256:1fc9ea04857caf665289b7a75923f2c6ed559b8298a1b8c49e59f7dd95c8481e \ + --hash=sha256:203a48d1fb583fc7d78a4c6655692963b860a417c0528492a6bc21f1aaefab25 \ + --hash=sha256:2081580ebb843f759b9f617314a24ed5738c51d2aee65d31e02f6f7a2b97707a \ + --hash=sha256:21d1152871b019407d8ac3985f6775c079416c282e431a4da6afe7aefd2bccbe \ + --hash=sha256:24b6f81f1983e6df8db3adc38562c83f7d4a0c36162885ec7f7b77c7dcbec97b \ + --hash=sha256:256f80b80ca3853f90c21b23ee78cd008713787b1b1e93eae9f3d6a7134abd91 \ + --hash=sha256:28a3a209b96630bca57cce802da70c266eb08c6e97e5afd61a75611ee6c64592 \ + --hash=sha256:2c8f814d84194c9ea681642fd164267891702542f028a15fc97d4674b6206187 \ + --hash=sha256:2de9a304e27f7596cd03d16f1b7c72219bd944e99cc52b84d0145aefb07cbd3c \ + --hash=sha256:38100abb9d1b1435bc4cc340bb4489635dc2f0da7456590877030c9b3d40b0c1 \ + --hash=sha256:3925dd22fa2b7699ed2617149842d2e6adde22b262fcbfada50e3d195e4b3a94 \ + --hash=sha256:3e17ed538242334bf70832644a32a7aae3d83b57567f9fd60a26257e992b79ba \ + --hash=sha256:3e837e369566884707ddaf85fc1744b47575005c0a229de3327f8f9a20f4efeb \ + --hash=sha256:3f4d46d8b35698056ec29bca21546e1551a205058ae1a181d871e278b0b28165 \ + --hash=sha256:44d1b5909021139fe36001ae048dbdde8214afa20200eda0f64c068cac5d5529 \ + --hash=sha256:45d5e886156860dc35862657e1494b9bae8dfa63bf56796f2fb56e1679fc0bca \ + --hash=sha256:4647afc2f90d1ddd33441e5b0e85b16b12ddec4fca55f0d9671fef036ecca27c \ + --hash=sha256:4671d9dd5ec934cb9a73e7ee9676f9362aba54f7f34910956b84d727b0d73fb6 \ + --hash=sha256:53f77cbe57044e88bbd5ed26ac1d0514d2acf0591dd6bb02a3ae37f76811b80c \ + --hash=sha256:5eda85d6d1879e692d546a078b44251cdd08dd1cfb98dfb77b670c97cee49ea0 \ + --hash=sha256:5fed36fccc0612a53f1d4d9a816b50a36702c28a2aa880cb8a122b3466638743 \ + --hash=sha256:61d028e90346df14fedc3d1e5441df818d095f3b87d286825dfcbd6459b7ef63 \ + --hash=sha256:66f011380d0e49ed280c789fbd08ff0d40968ee7b665575489afa95c98196ab5 \ + --hash=sha256:6824f87845e3396029f3820c206e459ccc91760e8fa24422f8b0c3d1731cbec5 \ + --hash=sha256:6c6c373cfc5c83a975506110d17457138c8c63016b563cc9ed6e056a82f13ce4 \ + --hash=sha256:6d02d6655b0e54f54c4ef0b94eb6be0607b70853c45ce98bd278dc7de718be5d \ + --hash=sha256:6d50360be4546678fc1b79ffe7a66265e28667840010348dd69a314145807a1b \ + --hash=sha256:730cacb21e1bdff3ce90babf007d0a0917cc3e6492f336c2f0134101e0944f93 \ + --hash=sha256:737fe7d37e1a1bffe70bd5754ea763a62a066dc5913ca57e957824b72a85e205 \ + --hash=sha256:74a03b9698e198d47562765773b4a8309919089150a0bb17d829ad7b44b60d27 \ + --hash=sha256:7553fb2090d71822f02c629afe6042c299edf91ba1bf94951165613553984512 \ + --hash=sha256:7a66c7204d8869299919db4d5069a82f1561581af12b11b3c9f48c584eb8743d \ + --hash=sha256:7cc09976e8b56f8cebd752f7113ad07752461f48a58cbba644139015ac24954c \ + --hash=sha256:81afed14892743bbe14dacb9e36d9e0e504cd204e0b165062c488942b9718037 \ + --hash=sha256:8941aaadaf67246224cee8c3803777eed332a19d909b47e29c9842ef1e79ac26 \ + --hash=sha256:89472c9762729b5ae1ad974b777416bfda4ac5642423fa93bd57a09204712322 \ + --hash=sha256:8ea985900c5c95ce9db1745f7933eeef5d314f0565b27625d9a10ec9881e1bfb \ + --hash=sha256:8eca2a813c1cb7ad4fb74d368c2ffbbb4789d377ee5bb8df98373c2cc0dee76c \ + --hash=sha256:92b68146a71df78564e4ef48af17551a5ddd142e5190cdf2c5624d0c3ff5b2e8 \ + --hash=sha256:9332088d75dc3241c702d852d4671613136d90fa6881da7d770a483fd05248b4 \ + --hash=sha256:94698a9c5f91f9d138526b48fe26a199609544591f859c870d477351dc7b2414 \ + --hash=sha256:9a67fc9e8eb39039280526379fb3a70023d77caec1852002b4da7e8b270c4dd9 \ + --hash=sha256:9de40a7b0323d889cf8d23d1ef214f565ab154443c42737dfe52ff82cf857664 \ + --hash=sha256:a05d0c237b3349096d3981b727493e22147f934b20f6f125a3eba8f994bec4a9 \ + --hash=sha256:afb8db5439b81cf9c9d0c80404b60c3cc9c3add93e114dcae767f1477cb53775 \ + --hash=sha256:b18a3ed7d5b3bd8d9ef7a8cb226502c6bf8308df1525e1cc676c3680e7176739 \ + --hash=sha256:b1e74d11748e7e98e2f426ab176d4ed720a64412b6a15054378afdb71e0f37dc \ + --hash=sha256:b21e08af67b8a103c71a250401c78d5e0893beff75e28c53c98f4de42f774062 \ + --hash=sha256:b4c854ef3adc177950a8dfc81a86f5115d2abd545751a304c5bcf2c2c7283cfe \ + --hash=sha256:b882b3df248017dba09d6b16defe9b5c407fe32fc7c65a9c69798e6175601be9 \ + --hash=sha256:baf5215e0ab74c16e2dd324e8ec067ef59e41125d3eade2b863d294fd5035c92 \ + --hash=sha256:c649e3a33450ec82378822b3dad03cc228b8f5963c0c12fc3b1e0ab940f768a5 \ + --hash=sha256:c654de545946e0db659b3400168c9ad31b5d29593291482c43e3564effbcee13 \ + --hash=sha256:c6638687455baf640e37344fe26d37c404db8b80d037c3d29f58fe8d1c3b194d \ + --hash=sha256:c8d3b5532fc71b7a77c09192b4a5a200ea992702734a2e9279a37f2478236f26 \ + --hash=sha256:cb527a79772e5ef98fb1d700678fe031e353e765d1ca2d409c92263c6d43e09f \ + --hash=sha256:cf364028c016c03078a23b503f02058f1814320a56ad535686f90565636a9495 \ + --hash=sha256:d48a880098c96020b02d5a1f7d9251308510ce8858940e6fa99ece33f610838b \ + --hash=sha256:d68b6cef7827e8641e8ef16f4494edda8b36104d79773a334beaa1e3521430f6 \ + --hash=sha256:d9b29c1f0ae438d5ee9acb31cadee00a58c46cc9c0b2f9038c6b0b3470877a8c \ + --hash=sha256:d9b97165e8aed9272a6bb17c01e3cc5871a594a446ebedc996e2397a1c1ea8ef \ + --hash=sha256:da68248800ad6320861f129cd9c1bf96ca849a2771a59e0344e88681905916f5 \ + --hash=sha256:da902562c3e9c550df360bfa53c035b2f241fed6d9aef119048073680ace4a18 \ + --hash=sha256:dbd5c7a25a7cb98f5ca55d258b103a2054f859a46ae11aaf23134f9cc0d356ad \ + --hash=sha256:dd4f05f54a52fb558f1ba9f528228066954fee3ebe629fc1660d874d040ae5a3 \ + --hash=sha256:de8dad4425a6ca6e4e5e297b27b5c824ecc7581910bf9aee86cb6835e6812aa7 \ + --hash=sha256:e11e82b744887154b182fd3e7e8512418446501191994dbf9c9fc1f32cc8efd5 \ + --hash=sha256:e6e73b9e02893c764e7e8d5bb5ce277f1a009cd5243f8228f75f842bf937c534 \ + --hash=sha256:f73b96c41e3b2adedc34a7356e64c8eb96e03a3782b535e043a986276ce12a49 \ + --hash=sha256:f93fd8e5c8c0a4aa1f424d6173f14a892044054871c771f8566e4008eaa359d2 \ + --hash=sha256:fc33c5141b55ed366cfaad382df24fe7dcbc686de5be719b207bb248e3053dc5 \ + --hash=sha256:fc7de24befaeae77ba923797c7c87834c73648a05a4bde34b3b7e5588973a453 \ + --hash=sha256:fe562eb1a64e67dd297ccc4f5addea2501664954f2692b69a76449ec7913ecbf + # via argon2-cffi-bindings click==8.2.1 \ --hash=sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202 \ --hash=sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b @@ -112,10 +230,18 @@ packaging==25.0 \ --hash=sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484 \ --hash=sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f # via pytest +passlib[argon2]==1.7.4 \ + --hash=sha256:aa6bca462b8d8bda89c70b382f0c298a20b5560af6cbfa2dce410c0a2fb669f1 \ + --hash=sha256:defd50f72b65c5402ab2c573830a6978e5f202ad0d984793c8dde2c4152ebe04 + # via -r requirements.in pluggy==1.6.0 \ --hash=sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3 \ --hash=sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746 # via pytest +pycparser==2.23 \ + --hash=sha256:78816d4f24add8f10a06d6f05b4d424ad9e96cfebf68a4ddc99c65c0720d00c2 \ + --hash=sha256:e5c6e8d3fbad53479cab09ac03729e0a9faf2bee3db8208a550daf5af81a5934 + # via cffi pydantic==2.11.7 \ --hash=sha256:d989c3c6cb79469287b1569f7447a17848c998458d49ebe294e975b9baf0f0db \ --hash=sha256:dde5df002701f6de26248661f6835bbe296a47bf73990135c7d07ce741b9623b diff --git a/backend/models.py b/backend/models.py deleted file mode 100644 index ca1a882..0000000 --- a/backend/models.py +++ /dev/null @@ -1,7 +0,0 @@ -from typing import Optional -from pydantic import BaseModel - - -class MsgPayload(BaseModel): - msg_id: Optional[int] - msg_name: str diff --git a/backend/requirements.in b/backend/requirements.in index 76cec58..dc105e4 100644 --- a/backend/requirements.in +++ b/backend/requirements.in @@ -3,4 +3,5 @@ uvicorn httpx pyyaml pydantic-settings -sqlmodel \ No newline at end of file +sqlmodel +passlib[argon2] \ No newline at end of file diff --git a/backend/requirements.txt b/backend/requirements.txt index 18a2164..6131b3f 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -14,12 +14,130 @@ anyio==4.10.0 \ # via # httpx # starlette +argon2-cffi==25.1.0 \ + --hash=sha256:694ae5cc8a42f4c4e2bf2ca0e64e51e23a040c6a517a85074683d3959e1346c1 \ + --hash=sha256:fdc8b074db390fccb6eb4a3604ae7231f219aa669a2652e0f20e16ba513d5741 + # via passlib +argon2-cffi-bindings==25.1.0 \ + --hash=sha256:1db89609c06afa1a214a69a462ea741cf735b29a57530478c06eb81dd403de99 \ + --hash=sha256:1e021e87faa76ae0d413b619fe2b65ab9a037f24c60a1e6cc43457ae20de6dc6 \ + --hash=sha256:21378b40e1b8d1655dd5310c84a40fc19a9aa5e6366e835ceb8576bf0fea716d \ + --hash=sha256:2630b6240b495dfab90aebe159ff784d08ea999aa4b0d17efa734055a07d2f44 \ + --hash=sha256:3c6702abc36bf3ccba3f802b799505def420a1b7039862014a65db3205967f5a \ + --hash=sha256:3d3f05610594151994ca9ccb3c771115bdb4daef161976a266f0dd8aa9996b8f \ + --hash=sha256:473bcb5f82924b1becbb637b63303ec8d10e84c8d241119419897a26116515d2 \ + --hash=sha256:5acb4e41090d53f17ca1110c3427f0a130f944b896fc8c83973219c97f57b690 \ + --hash=sha256:5d588dec224e2a83edbdc785a5e6f3c6cd736f46bfd4b441bbb5aa1f5085e584 \ + --hash=sha256:6dca33a9859abf613e22733131fc9194091c1fa7cb3e131c143056b4856aa47e \ + --hash=sha256:7aef0c91e2c0fbca6fc68e7555aa60ef7008a739cbe045541e438373bc54d2b0 \ + --hash=sha256:84a461d4d84ae1295871329b346a97f68eade8c53b6ed9a7ca2d7467f3c8ff6f \ + --hash=sha256:87c33a52407e4c41f3b70a9c2d3f6056d88b10dad7695be708c5021673f55623 \ + --hash=sha256:8b8efee945193e667a396cbc7b4fb7d357297d6234d30a489905d96caabde56b \ + --hash=sha256:a1c70058c6ab1e352304ac7e3b52554daadacd8d453c1752e547c76e9c99ac44 \ + --hash=sha256:a98cd7d17e9f7ce244c0803cad3c23a7d379c301ba618a5fa76a67d116618b98 \ + --hash=sha256:aecba1723ae35330a008418a91ea6cfcedf6d31e5fbaa056a166462ff066d500 \ + --hash=sha256:b0fdbcf513833809c882823f98dc2f931cf659d9a1429616ac3adebb49f5db94 \ + --hash=sha256:b55aec3565b65f56455eebc9b9f34130440404f27fe21c3b375bf1ea4d8fbae6 \ + --hash=sha256:b957f3e6ea4d55d820e40ff76f450952807013d361a65d7f28acc0acbf29229d \ + --hash=sha256:ba92837e4a9aa6a508c8d2d7883ed5a8f6c308c89a4790e1e447a220deb79a85 \ + --hash=sha256:c4f9665de60b1b0e99bcd6be4f17d90339698ce954cfd8d9cf4f91c995165a92 \ + --hash=sha256:c87b72589133f0346a1cb8d5ecca4b933e3c9b64656c9d175270a000e73b288d \ + --hash=sha256:d3e924cfc503018a714f94a49a149fdc0b644eaead5d1f089330399134fa028a \ + --hash=sha256:da0c79c23a63723aa5d782250fbf51b768abca630285262fb5144ba5ae01e520 \ + --hash=sha256:e2fd3bfbff3c5d74fef31a722f729bf93500910db650c925c2d6ef879a7e51cb + # via argon2-cffi certifi==2025.8.3 \ --hash=sha256:e564105f78ded564e3ae7c923924435e1daa7463faeab5bb932bc53ffae63407 \ --hash=sha256:f6c12493cfb1b06ba2ff328595af9350c65d6644968e5d3a2ffd78699af217a5 # via # httpcore # httpx +cffi==2.0.0 \ + --hash=sha256:00bdf7acc5f795150faa6957054fbbca2439db2f775ce831222b66f192f03beb \ + --hash=sha256:07b271772c100085dd28b74fa0cd81c8fb1a3ba18b21e03d7c27f3436a10606b \ + --hash=sha256:087067fa8953339c723661eda6b54bc98c5625757ea62e95eb4898ad5e776e9f \ + --hash=sha256:0a1527a803f0a659de1af2e1fd700213caba79377e27e4693648c2923da066f9 \ + --hash=sha256:0cf2d91ecc3fcc0625c2c530fe004f82c110405f101548512cce44322fa8ac44 \ + --hash=sha256:0f6084a0ea23d05d20c3edcda20c3d006f9b6f3fefeac38f59262e10cef47ee2 \ + --hash=sha256:12873ca6cb9b0f0d3a0da705d6086fe911591737a59f28b7936bdfed27c0d47c \ + --hash=sha256:19f705ada2530c1167abacb171925dd886168931e0a7b78f5bffcae5c6b5be75 \ + --hash=sha256:1cd13c99ce269b3ed80b417dcd591415d3372bcac067009b6e0f59c7d4015e65 \ + --hash=sha256:1e3a615586f05fc4065a8b22b8152f0c1b00cdbc60596d187c2a74f9e3036e4e \ + --hash=sha256:1f72fb8906754ac8a2cc3f9f5aaa298070652a0ffae577e0ea9bd480dc3c931a \ + --hash=sha256:1fc9ea04857caf665289b7a75923f2c6ed559b8298a1b8c49e59f7dd95c8481e \ + --hash=sha256:203a48d1fb583fc7d78a4c6655692963b860a417c0528492a6bc21f1aaefab25 \ + --hash=sha256:2081580ebb843f759b9f617314a24ed5738c51d2aee65d31e02f6f7a2b97707a \ + --hash=sha256:21d1152871b019407d8ac3985f6775c079416c282e431a4da6afe7aefd2bccbe \ + --hash=sha256:24b6f81f1983e6df8db3adc38562c83f7d4a0c36162885ec7f7b77c7dcbec97b \ + --hash=sha256:256f80b80ca3853f90c21b23ee78cd008713787b1b1e93eae9f3d6a7134abd91 \ + --hash=sha256:28a3a209b96630bca57cce802da70c266eb08c6e97e5afd61a75611ee6c64592 \ + --hash=sha256:2c8f814d84194c9ea681642fd164267891702542f028a15fc97d4674b6206187 \ + --hash=sha256:2de9a304e27f7596cd03d16f1b7c72219bd944e99cc52b84d0145aefb07cbd3c \ + --hash=sha256:38100abb9d1b1435bc4cc340bb4489635dc2f0da7456590877030c9b3d40b0c1 \ + --hash=sha256:3925dd22fa2b7699ed2617149842d2e6adde22b262fcbfada50e3d195e4b3a94 \ + --hash=sha256:3e17ed538242334bf70832644a32a7aae3d83b57567f9fd60a26257e992b79ba \ + --hash=sha256:3e837e369566884707ddaf85fc1744b47575005c0a229de3327f8f9a20f4efeb \ + --hash=sha256:3f4d46d8b35698056ec29bca21546e1551a205058ae1a181d871e278b0b28165 \ + --hash=sha256:44d1b5909021139fe36001ae048dbdde8214afa20200eda0f64c068cac5d5529 \ + --hash=sha256:45d5e886156860dc35862657e1494b9bae8dfa63bf56796f2fb56e1679fc0bca \ + --hash=sha256:4647afc2f90d1ddd33441e5b0e85b16b12ddec4fca55f0d9671fef036ecca27c \ + --hash=sha256:4671d9dd5ec934cb9a73e7ee9676f9362aba54f7f34910956b84d727b0d73fb6 \ + --hash=sha256:53f77cbe57044e88bbd5ed26ac1d0514d2acf0591dd6bb02a3ae37f76811b80c \ + --hash=sha256:5eda85d6d1879e692d546a078b44251cdd08dd1cfb98dfb77b670c97cee49ea0 \ + --hash=sha256:5fed36fccc0612a53f1d4d9a816b50a36702c28a2aa880cb8a122b3466638743 \ + --hash=sha256:61d028e90346df14fedc3d1e5441df818d095f3b87d286825dfcbd6459b7ef63 \ + --hash=sha256:66f011380d0e49ed280c789fbd08ff0d40968ee7b665575489afa95c98196ab5 \ + --hash=sha256:6824f87845e3396029f3820c206e459ccc91760e8fa24422f8b0c3d1731cbec5 \ + --hash=sha256:6c6c373cfc5c83a975506110d17457138c8c63016b563cc9ed6e056a82f13ce4 \ + --hash=sha256:6d02d6655b0e54f54c4ef0b94eb6be0607b70853c45ce98bd278dc7de718be5d \ + --hash=sha256:6d50360be4546678fc1b79ffe7a66265e28667840010348dd69a314145807a1b \ + --hash=sha256:730cacb21e1bdff3ce90babf007d0a0917cc3e6492f336c2f0134101e0944f93 \ + --hash=sha256:737fe7d37e1a1bffe70bd5754ea763a62a066dc5913ca57e957824b72a85e205 \ + --hash=sha256:74a03b9698e198d47562765773b4a8309919089150a0bb17d829ad7b44b60d27 \ + --hash=sha256:7553fb2090d71822f02c629afe6042c299edf91ba1bf94951165613553984512 \ + --hash=sha256:7a66c7204d8869299919db4d5069a82f1561581af12b11b3c9f48c584eb8743d \ + --hash=sha256:7cc09976e8b56f8cebd752f7113ad07752461f48a58cbba644139015ac24954c \ + --hash=sha256:81afed14892743bbe14dacb9e36d9e0e504cd204e0b165062c488942b9718037 \ + --hash=sha256:8941aaadaf67246224cee8c3803777eed332a19d909b47e29c9842ef1e79ac26 \ + --hash=sha256:89472c9762729b5ae1ad974b777416bfda4ac5642423fa93bd57a09204712322 \ + --hash=sha256:8ea985900c5c95ce9db1745f7933eeef5d314f0565b27625d9a10ec9881e1bfb \ + --hash=sha256:8eca2a813c1cb7ad4fb74d368c2ffbbb4789d377ee5bb8df98373c2cc0dee76c \ + --hash=sha256:92b68146a71df78564e4ef48af17551a5ddd142e5190cdf2c5624d0c3ff5b2e8 \ + --hash=sha256:9332088d75dc3241c702d852d4671613136d90fa6881da7d770a483fd05248b4 \ + --hash=sha256:94698a9c5f91f9d138526b48fe26a199609544591f859c870d477351dc7b2414 \ + --hash=sha256:9a67fc9e8eb39039280526379fb3a70023d77caec1852002b4da7e8b270c4dd9 \ + --hash=sha256:9de40a7b0323d889cf8d23d1ef214f565ab154443c42737dfe52ff82cf857664 \ + --hash=sha256:a05d0c237b3349096d3981b727493e22147f934b20f6f125a3eba8f994bec4a9 \ + --hash=sha256:afb8db5439b81cf9c9d0c80404b60c3cc9c3add93e114dcae767f1477cb53775 \ + --hash=sha256:b18a3ed7d5b3bd8d9ef7a8cb226502c6bf8308df1525e1cc676c3680e7176739 \ + --hash=sha256:b1e74d11748e7e98e2f426ab176d4ed720a64412b6a15054378afdb71e0f37dc \ + --hash=sha256:b21e08af67b8a103c71a250401c78d5e0893beff75e28c53c98f4de42f774062 \ + --hash=sha256:b4c854ef3adc177950a8dfc81a86f5115d2abd545751a304c5bcf2c2c7283cfe \ + --hash=sha256:b882b3df248017dba09d6b16defe9b5c407fe32fc7c65a9c69798e6175601be9 \ + --hash=sha256:baf5215e0ab74c16e2dd324e8ec067ef59e41125d3eade2b863d294fd5035c92 \ + --hash=sha256:c649e3a33450ec82378822b3dad03cc228b8f5963c0c12fc3b1e0ab940f768a5 \ + --hash=sha256:c654de545946e0db659b3400168c9ad31b5d29593291482c43e3564effbcee13 \ + --hash=sha256:c6638687455baf640e37344fe26d37c404db8b80d037c3d29f58fe8d1c3b194d \ + --hash=sha256:c8d3b5532fc71b7a77c09192b4a5a200ea992702734a2e9279a37f2478236f26 \ + --hash=sha256:cb527a79772e5ef98fb1d700678fe031e353e765d1ca2d409c92263c6d43e09f \ + --hash=sha256:cf364028c016c03078a23b503f02058f1814320a56ad535686f90565636a9495 \ + --hash=sha256:d48a880098c96020b02d5a1f7d9251308510ce8858940e6fa99ece33f610838b \ + --hash=sha256:d68b6cef7827e8641e8ef16f4494edda8b36104d79773a334beaa1e3521430f6 \ + --hash=sha256:d9b29c1f0ae438d5ee9acb31cadee00a58c46cc9c0b2f9038c6b0b3470877a8c \ + --hash=sha256:d9b97165e8aed9272a6bb17c01e3cc5871a594a446ebedc996e2397a1c1ea8ef \ + --hash=sha256:da68248800ad6320861f129cd9c1bf96ca849a2771a59e0344e88681905916f5 \ + --hash=sha256:da902562c3e9c550df360bfa53c035b2f241fed6d9aef119048073680ace4a18 \ + --hash=sha256:dbd5c7a25a7cb98f5ca55d258b103a2054f859a46ae11aaf23134f9cc0d356ad \ + --hash=sha256:dd4f05f54a52fb558f1ba9f528228066954fee3ebe629fc1660d874d040ae5a3 \ + --hash=sha256:de8dad4425a6ca6e4e5e297b27b5c824ecc7581910bf9aee86cb6835e6812aa7 \ + --hash=sha256:e11e82b744887154b182fd3e7e8512418446501191994dbf9c9fc1f32cc8efd5 \ + --hash=sha256:e6e73b9e02893c764e7e8d5bb5ce277f1a009cd5243f8228f75f842bf937c534 \ + --hash=sha256:f73b96c41e3b2adedc34a7356e64c8eb96e03a3782b535e043a986276ce12a49 \ + --hash=sha256:f93fd8e5c8c0a4aa1f424d6173f14a892044054871c771f8566e4008eaa359d2 \ + --hash=sha256:fc33c5141b55ed366cfaad382df24fe7dcbc686de5be719b207bb248e3053dc5 \ + --hash=sha256:fc7de24befaeae77ba923797c7c87834c73648a05a4bde34b3b7e5588973a453 \ + --hash=sha256:fe562eb1a64e67dd297ccc4f5addea2501664954f2692b69a76449ec7913ecbf + # via argon2-cffi-bindings click==8.2.1 \ --hash=sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202 \ --hash=sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b @@ -104,6 +222,14 @@ idna==3.10 \ # via # anyio # httpx +passlib[argon2]==1.7.4 \ + --hash=sha256:aa6bca462b8d8bda89c70b382f0c298a20b5560af6cbfa2dce410c0a2fb669f1 \ + --hash=sha256:defd50f72b65c5402ab2c573830a6978e5f202ad0d984793c8dde2c4152ebe04 + # via -r requirements.in +pycparser==2.23 \ + --hash=sha256:78816d4f24add8f10a06d6f05b4d424ad9e96cfebf68a4ddc99c65c0720d00c2 \ + --hash=sha256:e5c6e8d3fbad53479cab09ac03729e0a9faf2bee3db8208a550daf5af81a5934 + # via cffi pydantic==2.11.7 \ --hash=sha256:d989c3c6cb79469287b1569f7447a17848c998458d49ebe294e975b9baf0f0db \ --hash=sha256:dde5df002701f6de26248661f6835bbe296a47bf73990135c7d07ce741b9623b diff --git a/backend/settings.py b/backend/settings.py index 2096af8..25ad0dc 100644 --- a/backend/settings.py +++ b/backend/settings.py @@ -12,6 +12,7 @@ class Settings(BaseSettings): port: int = 8000 workers: int = 1 log_level: str = "info" + database_url: str = "sqlite:///:memory:" model_config = ConfigDict(env_file=".env", env_file_encoding="utf-8") diff --git a/backend/tests/test_app.py b/backend/tests/test_app.py new file mode 100644 index 0000000..d6123a6 --- /dev/null +++ b/backend/tests/test_app.py @@ -0,0 +1,18 @@ +from collections.abc import Generator + +import pytest +from fastapi.testclient import TestClient + +from app import API_BASE, app + + +@pytest.fixture +def client() -> Generator[TestClient, None, None]: + with TestClient(app) as client: + yield client + + +def test_get_status(client: TestClient) -> None: + response = client.get(f"{API_BASE}/status") + assert response.status_code == 200 + assert response.json() == {"status": "ok"} diff --git a/backend/tests/test_crud.py b/backend/tests/test_crud.py index 2ba27f7..5eac231 100644 --- a/backend/tests/test_crud.py +++ b/backend/tests/test_crud.py @@ -33,8 +33,11 @@ def engine() -> Generator[Engine, None, None]: @pytest.fixture def session(engine: Engine) -> Generator[Session, None, None]: - with Session(engine) as s: - yield s + session = Session(engine) + try: + yield session + finally: + session.close() def make_user(session: Session, username: str = "testuser") -> int: @@ -45,12 +48,20 @@ def make_user(session: Session, username: str = "testuser") -> int: return user.id -def make_cycle(session: Session, user_id: int, friendly_name: str = "Test Cycle") -> int: +def make_exchange(session: Session, name: str = "NASDAQ") -> int: + exchange = models.Exchanges(name=name, notes="Test exchange") + session.add(exchange) + session.commit() + session.refresh(exchange) + return exchange.id + + +def make_cycle(session: Session, user_id: int, exchange_id: int, friendly_name: str = "Test Cycle") -> int: cycle = models.Cycles( user_id=user_id, friendly_name=friendly_name, symbol="AAPL", - exchange="NASDAQ", + exchange_id=exchange_id, underlying_currency=models.UnderlyingCurrency.USD, status=models.CycleStatus.OPEN, start_date=datetime.now(timezone.utc).date(), @@ -62,11 +73,13 @@ def make_cycle(session: Session, user_id: int, friendly_name: str = "Test Cycle" def make_trade(session: Session, user_id: int, cycle_id: int, friendly_name: str = "Test Trade") -> int: + cycle: models.Cycles = session.get(models.Cycles, cycle_id) + exchange_id = cycle.exchange_id trade = models.Trades( user_id=user_id, friendly_name=friendly_name, symbol="AAPL", - exchange="NASDAQ", + exchange_id=exchange_id, underlying_currency=models.UnderlyingCurrency.USD, trade_type=models.TradeType.LONG_SPOT, trade_strategy=models.TradeStrategy.SPOT, @@ -125,13 +138,13 @@ def _ensure_utc_aware(dt: datetime) -> datetime | None: def test_create_trade_success_with_cycle(session: Session) -> None: user_id = make_user(session) - cycle_id = make_cycle(session, user_id) + exchange_id = make_exchange(session) + cycle_id = make_cycle(session, user_id, exchange_id) trade_data = { "user_id": user_id, "friendly_name": "Test Trade", "symbol": "AAPL", - "exchange": "NASDAQ", "underlying_currency": models.UnderlyingCurrency.USD, "trade_type": models.TradeType.LONG_SPOT, "trade_strategy": models.TradeStrategy.SPOT, @@ -167,12 +180,13 @@ def test_create_trade_success_with_cycle(session: Session) -> None: def test_create_trade_with_auto_created_cycle(session: Session) -> None: user_id = make_user(session) + exchange_id = make_exchange(session) trade_data = { "user_id": user_id, "friendly_name": "Test Trade with Auto Cycle", "symbol": "AAPL", - "exchange": "NASDAQ", + "exchange_id": exchange_id, "underlying_currency": models.UnderlyingCurrency.USD, "trade_type": models.TradeType.LONG_SPOT, "trade_strategy": models.TradeStrategy.SPOT, @@ -210,12 +224,13 @@ def test_create_trade_with_auto_created_cycle(session: Session) -> None: def test_create_trade_missing_required_fields(session: Session) -> None: user_id = make_user(session) + exchange_id = make_exchange(session) base_trade_data = { "user_id": user_id, "friendly_name": "Incomplete Trade", "symbol": "AAPL", - "exchange": "NASDAQ", + "exchange_id": exchange_id, "underlying_currency": models.UnderlyingCurrency.USD, "trade_type": models.TradeType.LONG_SPOT, "trade_strategy": models.TradeStrategy.SPOT, @@ -231,12 +246,12 @@ def test_create_trade_missing_required_fields(session: Session) -> None: crud.create_trade(session, trade_data) assert "symbol is required" in str(excinfo.value) - # Missing exchange + # Missing exchange and cycle together trade_data = base_trade_data.copy() - trade_data.pop("exchange", None) + trade_data.pop("exchange_id", None) with pytest.raises(ValueError) as excinfo: crud.create_trade(session, trade_data) - assert "exchange is required" in str(excinfo.value) + assert "exchange_id is required when no cycle is attached" in str(excinfo.value) # Missing underlying_currency trade_data = base_trade_data.copy() @@ -276,12 +291,13 @@ def test_create_trade_missing_required_fields(session: Session) -> None: def test_get_trade_by_id(session: Session) -> None: user_id = make_user(session) - cycle_id = make_cycle(session, user_id) + exchange_id = make_exchange(session) + cycle_id = make_cycle(session, user_id, exchange_id) trade_data = { "user_id": user_id, "friendly_name": "Test Trade for Get", "symbol": "AAPL", - "exchange": "NASDAQ", + "exchange_id": exchange_id, "underlying_currency": models.UnderlyingCurrency.USD, "trade_type": models.TradeType.LONG_SPOT, "trade_strategy": models.TradeStrategy.SPOT, @@ -314,13 +330,14 @@ def test_get_trade_by_id(session: Session) -> None: def test_get_trade_by_user_id_and_friendly_name(session: Session) -> None: user_id = make_user(session) - cycle_id = make_cycle(session, user_id) + exchange_id = make_exchange(session) + cycle_id = make_cycle(session, user_id, exchange_id) friendly_name = "Unique Trade Name" trade_data = { "user_id": user_id, "friendly_name": friendly_name, "symbol": "AAPL", - "exchange": "NASDAQ", + "exchange_id": exchange_id, "underlying_currency": models.UnderlyingCurrency.USD, "trade_type": models.TradeType.LONG_SPOT, "trade_strategy": models.TradeStrategy.SPOT, @@ -342,12 +359,13 @@ def test_get_trade_by_user_id_and_friendly_name(session: Session) -> None: def test_get_trades_by_user_id(session: Session) -> None: user_id = make_user(session) - cycle_id = make_cycle(session, user_id) + exchange_id = make_exchange(session) + cycle_id = make_cycle(session, user_id, exchange_id) trade_data_1 = { "user_id": user_id, "friendly_name": "Trade One", "symbol": "AAPL", - "exchange": "NASDAQ", + "exchange_id": exchange_id, "underlying_currency": models.UnderlyingCurrency.USD, "trade_type": models.TradeType.LONG_SPOT, "trade_strategy": models.TradeStrategy.SPOT, @@ -364,7 +382,7 @@ def test_get_trades_by_user_id(session: Session) -> None: "user_id": user_id, "friendly_name": "Trade Two", "symbol": "GOOGL", - "exchange": "NASDAQ", + "exchange_id": exchange_id, "underlying_currency": models.UnderlyingCurrency.USD, "trade_type": models.TradeType.SHORT_SPOT, "trade_strategy": models.TradeStrategy.SPOT, @@ -388,7 +406,8 @@ def test_get_trades_by_user_id(session: Session) -> None: def test_update_trade_note(session: Session) -> None: user_id = make_user(session) - cycle_id = make_cycle(session, user_id) + exchange_id = make_exchange(session) + cycle_id = make_cycle(session, user_id, exchange_id) trade_id = make_trade(session, user_id, cycle_id) new_note = "This is an updated note." @@ -405,7 +424,8 @@ def test_update_trade_note(session: Session) -> None: def test_invalidate_trade(session: Session) -> None: user_id = make_user(session) - cycle_id = make_cycle(session, user_id) + exchange_id = make_exchange(session) + cycle_id = make_cycle(session, user_id, exchange_id) trade_id = make_trade(session, user_id, cycle_id) invalidated_trade = crud.invalidate_trade(session, trade_id) @@ -421,14 +441,15 @@ def test_invalidate_trade(session: Session) -> None: def test_replace_trade(session: Session) -> None: user_id = make_user(session) - cycle_id = make_cycle(session, user_id) + exchange_id = make_exchange(session) + cycle_id = make_cycle(session, user_id, exchange_id) old_trade_id = make_trade(session, user_id, cycle_id) new_trade_data = { "user_id": user_id, "friendly_name": "Replaced Trade", "symbol": "MSFT", - "exchange": "NASDAQ", + "exchange_id": exchange_id, "underlying_currency": models.UnderlyingCurrency.USD, "trade_type": models.TradeType.LONG_SPOT, "trade_strategy": models.TradeStrategy.SPOT, @@ -465,11 +486,12 @@ def test_replace_trade(session: Session) -> None: def test_create_cycle(session: Session) -> None: user_id = make_user(session) + exchange_id = make_exchange(session) cycle_data = { "user_id": user_id, "friendly_name": "My First Cycle", "symbol": "GOOGL", - "exchange": "NASDAQ", + "exchange_id": exchange_id, "underlying_currency": models.UnderlyingCurrency.USD, "status": models.CycleStatus.OPEN, "start_date": datetime.now(timezone.utc).date(), @@ -495,7 +517,8 @@ def test_create_cycle(session: Session) -> None: def test_update_cycle(session: Session) -> None: user_id = make_user(session) - cycle_id = make_cycle(session, user_id, friendly_name="Initial Cycle Name") + exchange_id = make_exchange(session) + cycle_id = make_cycle(session, user_id, exchange_id, friendly_name="Initial Cycle Name") update_data = { "friendly_name": "Updated Cycle Name", @@ -516,7 +539,8 @@ def test_update_cycle(session: Session) -> None: def test_update_cycle_immutable_fields(session: Session) -> None: user_id = make_user(session) - cycle_id = make_cycle(session, user_id, friendly_name="Initial Cycle Name") + exchange_id = make_exchange(session) + cycle_id = make_cycle(session, user_id, exchange_id, friendly_name="Initial Cycle Name") # Attempt to update immutable fields update_data = { diff --git a/backend/tests/test_db_migration.py b/backend/tests/test_db_migration.py index 655cf54..15c7fba 100644 --- a/backend/tests/test_db_migration.py +++ b/backend/tests/test_db_migration.py @@ -36,6 +36,7 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None: "user_id": ("INTEGER", 1, 0), "friendly_name": ("TEXT", 0, 0), "symbol": ("TEXT", 1, 0), + "exchange_id": ("INTEGER", 1, 0), "underlying_currency": ("TEXT", 1, 0), "status": ("TEXT", 1, 0), "funding_source": ("TEXT", 0, 0), @@ -50,9 +51,11 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None: "user_id": ("INTEGER", 1, 0), "friendly_name": ("TEXT", 0, 0), "symbol": ("TEXT", 1, 0), + "exchange_id": ("INTEGER", 1, 0), "underlying_currency": ("TEXT", 1, 0), "trade_type": ("TEXT", 1, 0), "trade_strategy": ("TEXT", 1, 0), + "trade_date": ("DATE", 1, 0), "trade_time_utc": ("DATETIME", 1, 0), "expiry_date": ("DATE", 0, 0), "strike_price_cents": ("INTEGER", 0, 0), @@ -61,6 +64,10 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None: "gross_cash_flow_cents": ("INTEGER", 1, 0), "commission_cents": ("INTEGER", 1, 0), "net_cash_flow_cents": ("INTEGER", 1, 0), + "is_invalidated": ("BOOLEAN", 1, 0), + "invalidated_at": ("DATETIME", 0, 0), + "replaced_by_trade_id": ("INTEGER", 0, 0), + "notes": ("TEXT", 0, 0), "cycle_id": ("INTEGER", 0, 0), }, "sessions": { @@ -80,10 +87,17 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None: "trades": [ {"table": "cycles", "from": "cycle_id", "to": "id"}, {"table": "users", "from": "user_id", "to": "id"}, + {"table": "exchanges", "from": "exchange_id", "to": "id"}, ], "cycles": [ {"table": "users", "from": "user_id", "to": "id"}, + {"table": "exchanges", "from": "exchange_id", "to": "id"}, ], + "sessions": [ + {"table": "users", "from": "user_id", "to": "id"}, + ], + "users": [], + "exchanges": [], } with engine.connect() as conn: diff --git a/backend/tests/test_main.py b/backend/tests/test_main.py deleted file mode 100644 index b4dd74c..0000000 --- a/backend/tests/test_main.py +++ /dev/null @@ -1,22 +0,0 @@ -import pytest -from fastapi.testclient import TestClient - -from app import app - - -@pytest.fixture -def client(): - with TestClient(app) as client: - yield client - - -def test_home_route(client): - response = client.get("/") - assert response.status_code == 200 - assert response.json() == {"message": "Hello"} - - -def test_about_route(client): - response = client.get("/about") - assert response.status_code == 200 - assert response.json() == {"message": "This is the about page."} diff --git a/backend/tests/test_security.py b/backend/tests/test_security.py new file mode 100644 index 0000000..cab62d7 --- /dev/null +++ b/backend/tests/test_security.py @@ -0,0 +1,4 @@ +from trading_journal import security + +def test_hash_password() -> None: + plain = "password" diff --git a/backend/trading_journal/crud.py b/backend/trading_journal/crud.py index 37973ce..5ce6fb4 100644 --- a/backend/trading_journal/crud.py +++ b/backend/trading_journal/crud.py @@ -35,10 +35,11 @@ def create_trade(session: Session, trade_data: Mapping) -> models.Trades: data = dict(trade_data) allowed = {c.name for c in models.Trades.__table__.columns} payload = {k: v for k, v in data.items() if k in allowed} + cycle_id = payload.get("cycle_id") if "symbol" not in payload: raise ValueError("symbol is required") - if "exchange" not in payload: - raise ValueError("exchange is required") + if "exchange_id" not in payload and cycle_id is None: + raise ValueError("exchange_id is required when no cycle is attached") if "underlying_currency" not in payload: raise ValueError("underlying_currency is required") payload["underlying_currency"] = _check_enum(models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency") @@ -54,7 +55,6 @@ def create_trade(session: Session, trade_data: Mapping) -> models.Trades: payload["trade_time_utc"] = now if "trade_date" not in payload or payload.get("trade_date") is None: payload["trade_date"] = payload["trade_time_utc"].date() - cycle_id = payload.get("cycle_id") user_id = payload.get("user_id") if "quantity" not in payload: raise ValueError("quantity is required") @@ -76,7 +76,7 @@ def create_trade(session: Session, trade_data: Mapping) -> models.Trades: c_payload = { "user_id": user_id, "symbol": payload["symbol"], - "exchange": payload["exchange"], + "exchange_id": payload["exchange_id"], "underlying_currency": payload["underlying_currency"], "friendly_name": "Auto-created Cycle by trade " + payload.get("friendly_name", ""), "status": models.CycleStatus.OPEN, @@ -89,8 +89,11 @@ def create_trade(session: Session, trade_data: Mapping) -> models.Trades: # If cycle_id provided, validate existence and ownership if cycle_id is not None: cycle = session.get(models.Cycles, cycle_id) + if cycle is None: raise ValueError("cycle_id does not exist") + payload.pop("exchange_id", None) # ignore exchange_id if provided; use cycle's exchange_id + payload["exchange_id"] = cycle.exchange_id if cycle.user_id != user_id: raise ValueError("cycle.user_id does not match trade.user_id") @@ -187,8 +190,8 @@ def create_cycle(session: Session, cycle_data: Mapping) -> models.Cycles: raise ValueError("user_id is required") if "symbol" not in payload: raise ValueError("symbol is required") - if "exchange" not in payload: - raise ValueError("exchange is required") + if "exchange_id" not in payload: + raise ValueError("exchange_id is required") if "underlying_currency" not in payload: raise ValueError("underlying_currency is required") payload["underlying_currency"] = _check_enum(models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency") diff --git a/backend/trading_journal/db.py b/backend/trading_journal/db.py index d952a88..039727c 100644 --- a/backend/trading_journal/db.py +++ b/backend/trading_journal/db.py @@ -58,7 +58,8 @@ class Database: event.listen(self._engine, "connect", _enable_sqlite_pragmas) def init_db(self) -> None: - db_migration.run_migrations(self._engine) + # db_migration.run_migrations(self._engine) + pass def get_session(self) -> Generator[Session, None, None]: session = Session(self._engine) diff --git a/backend/trading_journal/dto.py b/backend/trading_journal/dto.py index b372474..1a9d478 100644 --- a/backend/trading_journal/dto.py +++ b/backend/trading_journal/dto.py @@ -14,8 +14,43 @@ class TradeBase(SQLModel): user_id: int friendly_name: str | None symbol: str + exchange: str underlying_currency: UnderlyingCurrency trade_type: TradeType trade_strategy: TradeStrategy trade_date: date trade_time_utc: datetime + quantity: int + price_cents: int + gross_cash_flow_cents: int + commission_cents: int + net_cash_flow_cents: int + notes: str | None + cycle_id: int | None = None + + +class TradeCreate(TradeBase): + expiry_date: date | None = None + strike_price_cents: int | None = None + is_invalidated: bool = False + invalidated_at: datetime | None = None + replaced_by_trade_id: int | None = None + + +class TradeRead(TradeBase): + id: int + is_invalidated: bool + invalidated_at: datetime | None + + +class UserBase(SQLModel): + username: str + is_active: bool = True + + +class UserCreate(UserBase): + password: str + + +class UserRead(UserBase): + id: int diff --git a/backend/trading_journal/models.py b/backend/trading_journal/models.py index 6659361..0238a81 100644 --- a/backend/trading_journal/models.py +++ b/backend/trading_journal/models.py @@ -72,7 +72,8 @@ class Trades(SQLModel, table=True): # allow null while user may omit friendly_name; uniqueness enforced per-user by constraint friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) symbol: str = Field(sa_column=Column(Text, nullable=False)) - exchange: str = Field(sa_column=Column(Text, nullable=False)) + exchange_id: int = Field(foreign_key="exchanges.id", nullable=False, index=True) + exchange: "Exchanges" = Relationship(back_populates="trades") underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False)) trade_type: TradeType = Field(sa_column=Column(Text, nullable=False)) trade_strategy: TradeStrategy = Field(sa_column=Column(Text, nullable=False)) @@ -101,7 +102,8 @@ class Cycles(SQLModel, table=True): user_id: int = Field(foreign_key="users.id", nullable=False, index=True) friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) symbol: str = Field(sa_column=Column(Text, nullable=False)) - exchange: str = Field(sa_column=Column(Text, nullable=False)) + exchange_id: int = Field(foreign_key="exchanges.id", nullable=False, index=True) + exchange: "Exchanges" = Relationship(back_populates="cycles") underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False)) status: CycleStatus = Field(sa_column=Column(Text, nullable=False)) funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True)) @@ -113,6 +115,15 @@ class Cycles(SQLModel, table=True): trades: list["Trades"] = Relationship(back_populates="cycle") +class Exchanges(SQLModel, table=True): + __tablename__ = "exchanges" + id: int | None = Field(default=None, primary_key=True) + name: str = Field(sa_column=Column(Text, nullable=False, unique=True)) + notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) + trades: list["Trades"] = Relationship(back_populates="exchange") + cycles: list["Cycles"] = Relationship(back_populates="exchange") + + class Users(SQLModel, table=True): __tablename__ = "users" id: int | None = Field(default=None, primary_key=True) diff --git a/backend/trading_journal/models_v1.py b/backend/trading_journal/models_v1.py index 6659361..0238a81 100644 --- a/backend/trading_journal/models_v1.py +++ b/backend/trading_journal/models_v1.py @@ -72,7 +72,8 @@ class Trades(SQLModel, table=True): # allow null while user may omit friendly_name; uniqueness enforced per-user by constraint friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) symbol: str = Field(sa_column=Column(Text, nullable=False)) - exchange: str = Field(sa_column=Column(Text, nullable=False)) + exchange_id: int = Field(foreign_key="exchanges.id", nullable=False, index=True) + exchange: "Exchanges" = Relationship(back_populates="trades") underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False)) trade_type: TradeType = Field(sa_column=Column(Text, nullable=False)) trade_strategy: TradeStrategy = Field(sa_column=Column(Text, nullable=False)) @@ -101,7 +102,8 @@ class Cycles(SQLModel, table=True): user_id: int = Field(foreign_key="users.id", nullable=False, index=True) friendly_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) symbol: str = Field(sa_column=Column(Text, nullable=False)) - exchange: str = Field(sa_column=Column(Text, nullable=False)) + exchange_id: int = Field(foreign_key="exchanges.id", nullable=False, index=True) + exchange: "Exchanges" = Relationship(back_populates="cycles") underlying_currency: UnderlyingCurrency = Field(sa_column=Column(Text, nullable=False)) status: CycleStatus = Field(sa_column=Column(Text, nullable=False)) funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True)) @@ -113,6 +115,15 @@ class Cycles(SQLModel, table=True): trades: list["Trades"] = Relationship(back_populates="cycle") +class Exchanges(SQLModel, table=True): + __tablename__ = "exchanges" + id: int | None = Field(default=None, primary_key=True) + name: str = Field(sa_column=Column(Text, nullable=False, unique=True)) + notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) + trades: list["Trades"] = Relationship(back_populates="exchange") + cycles: list["Cycles"] = Relationship(back_populates="exchange") + + class Users(SQLModel, table=True): __tablename__ = "users" id: int | None = Field(default=None, primary_key=True) diff --git a/backend/trading_journal/security.py b/backend/trading_journal/security.py new file mode 100644 index 0000000..edb32ce --- /dev/null +++ b/backend/trading_journal/security.py @@ -0,0 +1,11 @@ +from passlib.context import CryptContext + +pwd_ctx = CryptContext(schemes=["argon2"], deprecated="auto") + + +def hash_password(plain: str) -> str: + return pwd_ctx.hash(plain) + +def verify_password(plain: str, hashed: str) -> bool: + return pwd_ctx.verify(plain, hashed) + -- 2.49.1 From 76ed38e9af6cdf5287d9a4488e46b72457aae3ac Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Mon, 22 Sep 2025 14:39:33 +0200 Subject: [PATCH 04/18] add crud for exchange --- backend/tests/test_crud.py | 73 +++++++++++++++++++++++++++++++ backend/trading_journal/crud.py | 77 +++++++++++++++++++++++++++++++++ 2 files changed, 150 insertions(+) diff --git a/backend/tests/test_crud.py b/backend/tests/test_crud.py index 5eac231..1ae5a55 100644 --- a/backend/tests/test_crud.py +++ b/backend/tests/test_crud.py @@ -561,6 +561,79 @@ def test_update_cycle_immutable_fields(session: Session) -> None: ) +# Exchanges +def test_create_exchange(session: Session) -> None: + exchange_data = { + "name": "NYSE", + "notes": "New York Stock Exchange", + } + exchange = crud.create_exchange(session, exchange_data) + assert exchange.id is not None + assert exchange.name == exchange_data["name"] + assert exchange.notes == exchange_data["notes"] + + session.refresh(exchange) + actual_exchange = session.get(models.Exchanges, exchange.id) + assert actual_exchange is not None + assert actual_exchange.name == exchange_data["name"] + assert actual_exchange.notes == exchange_data["notes"] + + +def test_get_exchange_by_id(session: Session) -> None: + exchange_id = make_exchange(session, name="LSE") + exchange = crud.get_exchange_by_id(session, exchange_id) + assert exchange is not None + assert exchange.id == exchange_id + assert exchange.name == "LSE" + + +def test_get_exchange_by_name(session: Session) -> None: + exchange_name = "TSX" + make_exchange(session, name=exchange_name) + exchange = crud.get_exchange_by_name(session, exchange_name) + assert exchange is not None + assert exchange.name == exchange_name + + +def test_get_all_exchanges(session: Session) -> None: + exchange_names = ["NYSE", "NASDAQ", "LSE"] + for name in exchange_names: + make_exchange(session, name=name) + + exchanges = crud.get_all_exchanges(session) + assert len(exchanges) >= 3 + fetched_names = {ex.name for ex in exchanges} + for name in exchange_names: + assert name in fetched_names + + +def test_update_exchange(session: Session) -> None: + exchange_id = make_exchange(session, name="Initial Exchange") + update_data = { + "name": "Updated Exchange", + "notes": "Updated notes for the exchange", + } + updated_exchange = crud.update_exchange(session, exchange_id, update_data) + assert updated_exchange is not None + assert updated_exchange.id == exchange_id + assert updated_exchange.name == update_data["name"] + assert updated_exchange.notes == update_data["notes"] + + session.refresh(updated_exchange) + actual_exchange = session.get(models.Exchanges, exchange_id) + assert actual_exchange is not None + assert actual_exchange.name == update_data["name"] + assert actual_exchange.notes == update_data["notes"] + + +def test_delete_exchange(session: Session) -> None: + exchange_id = make_exchange(session, name="Deletable Exchange") + crud.delete_exchange(session, exchange_id) + deleted_exchange = session.get(models.Exchanges, exchange_id) + assert deleted_exchange is None + + +# Users def test_create_user(session: Session) -> None: user_data = { "username": "newuser", diff --git a/backend/trading_journal/crud.py b/backend/trading_journal/crud.py index 5ce6fb4..9e998bd 100644 --- a/backend/trading_journal/crud.py +++ b/backend/trading_journal/crud.py @@ -245,6 +245,83 @@ def update_cycle(session: Session, cycle_id: int, update_data: Mapping) -> model return cycle +# Exchanges +IMMUTABLE_EXCHANGE_FIELDS = {"id"} + + +def create_exchange(session: Session, exchange_data: Mapping) -> models.Exchanges: + if hasattr(exchange_data, "dict"): + data = exchange_data.dict(exclude_unset=True) + else: + data = dict(exchange_data) + allowed = {c.name for c in models.Exchanges.__table__.columns} + payload = {k: v for k, v in data.items() if k in allowed} + if "name" not in payload: + raise ValueError("name is required") + + e = models.Exchanges(**payload) + session.add(e) + try: + session.flush() + except IntegrityError as e: + session.rollback() + raise ValueError("create_exchange integrity error") from e + session.refresh(e) + return e + + +def get_exchange_by_id(session: Session, exchange_id: int) -> models.Exchanges | None: + return session.get(models.Exchanges, exchange_id) + + +def get_exchange_by_name(session: Session, name: str) -> models.Exchanges | None: + statement = select(models.Exchanges).where( + models.Exchanges.name == name, + ) + return session.exec(statement).first() + + +def get_all_exchanges(session: Session) -> list[models.Exchanges]: + statement = select(models.Exchanges) + return session.exec(statement).all() + + +def update_exchange(session: Session, exchange_id: int, update_data: Mapping) -> models.Exchanges: + exchange: models.Exchanges | None = session.get(models.Exchanges, exchange_id) + if exchange is None: + raise ValueError("exchange_id does not exist") + if hasattr(update_data, "dict"): + data = update_data.dict(exclude_unset=True) + else: + data = dict(update_data) + allowed = {c.name for c in models.Exchanges.__table__.columns} + for k, v in data.items(): + if k in IMMUTABLE_EXCHANGE_FIELDS: + raise ValueError(f"field {k!r} is immutable") + if k in allowed: + setattr(exchange, k, v) + session.add(exchange) + try: + session.flush() + except IntegrityError as e: + session.rollback() + raise ValueError("update_exchange integrity error") from e + session.refresh(exchange) + return exchange + + +def delete_exchange(session: Session, exchange_id: int) -> None: + exchange: models.Exchanges | None = session.get(models.Exchanges, exchange_id) + if exchange is None: + return + session.delete(exchange) + try: + session.flush() + except IntegrityError as e: + session.rollback() + raise ValueError("delete_exchange integrity error") from e + + # Users IMMUTABLE_USER_FIELDS = {"id", "username", "created_at"} -- 2.49.1 From e70a63e4f96ae92f1634f730084f6225e1227aa9 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Mon, 22 Sep 2025 14:54:29 +0200 Subject: [PATCH 05/18] add security py --- backend/dev-requirements.txt | 6 +--- backend/requirements.in | 2 +- backend/requirements.txt | 6 +--- backend/settings.py | 3 ++ backend/tests/test_security.py | 22 ++++++++++++- backend/trading_journal/security.py | 48 ++++++++++++++++++++++++++--- 6 files changed, 71 insertions(+), 16 deletions(-) diff --git a/backend/dev-requirements.txt b/backend/dev-requirements.txt index aebd637..dd46058 100644 --- a/backend/dev-requirements.txt +++ b/backend/dev-requirements.txt @@ -17,7 +17,7 @@ anyio==4.10.0 \ argon2-cffi==25.1.0 \ --hash=sha256:694ae5cc8a42f4c4e2bf2ca0e64e51e23a040c6a517a85074683d3959e1346c1 \ --hash=sha256:fdc8b074db390fccb6eb4a3604ae7231f219aa669a2652e0f20e16ba513d5741 - # via passlib + # via -r requirements.in argon2-cffi-bindings==25.1.0 \ --hash=sha256:1db89609c06afa1a214a69a462ea741cf735b29a57530478c06eb81dd403de99 \ --hash=sha256:1e021e87faa76ae0d413b619fe2b65ab9a037f24c60a1e6cc43457ae20de6dc6 \ @@ -230,10 +230,6 @@ packaging==25.0 \ --hash=sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484 \ --hash=sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f # via pytest -passlib[argon2]==1.7.4 \ - --hash=sha256:aa6bca462b8d8bda89c70b382f0c298a20b5560af6cbfa2dce410c0a2fb669f1 \ - --hash=sha256:defd50f72b65c5402ab2c573830a6978e5f202ad0d984793c8dde2c4152ebe04 - # via -r requirements.in pluggy==1.6.0 \ --hash=sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3 \ --hash=sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746 diff --git a/backend/requirements.in b/backend/requirements.in index dc105e4..0eafd8e 100644 --- a/backend/requirements.in +++ b/backend/requirements.in @@ -4,4 +4,4 @@ httpx pyyaml pydantic-settings sqlmodel -passlib[argon2] \ No newline at end of file +argon2-cffi \ No newline at end of file diff --git a/backend/requirements.txt b/backend/requirements.txt index 6131b3f..c06f8dc 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -17,7 +17,7 @@ anyio==4.10.0 \ argon2-cffi==25.1.0 \ --hash=sha256:694ae5cc8a42f4c4e2bf2ca0e64e51e23a040c6a517a85074683d3959e1346c1 \ --hash=sha256:fdc8b074db390fccb6eb4a3604ae7231f219aa669a2652e0f20e16ba513d5741 - # via passlib + # via -r requirements.in argon2-cffi-bindings==25.1.0 \ --hash=sha256:1db89609c06afa1a214a69a462ea741cf735b29a57530478c06eb81dd403de99 \ --hash=sha256:1e021e87faa76ae0d413b619fe2b65ab9a037f24c60a1e6cc43457ae20de6dc6 \ @@ -222,10 +222,6 @@ idna==3.10 \ # via # anyio # httpx -passlib[argon2]==1.7.4 \ - --hash=sha256:aa6bca462b8d8bda89c70b382f0c298a20b5560af6cbfa2dce410c0a2fb669f1 \ - --hash=sha256:defd50f72b65c5402ab2c573830a6978e5f202ad0d984793c8dde2c4152ebe04 - # via -r requirements.in pycparser==2.23 \ --hash=sha256:78816d4f24add8f10a06d6f05b4d424ad9e96cfebf68a4ddc99c65c0720d00c2 \ --hash=sha256:e5c6e8d3fbad53479cab09ac03729e0a9faf2bee3db8208a550daf5af81a5934 diff --git a/backend/settings.py b/backend/settings.py index 25ad0dc..62305be 100644 --- a/backend/settings.py +++ b/backend/settings.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os from pathlib import Path from typing import Any @@ -13,6 +15,7 @@ class Settings(BaseSettings): workers: int = 1 log_level: str = "info" database_url: str = "sqlite:///:memory:" + hmac_key: str | None = None model_config = ConfigDict(env_file=".env", env_file_encoding="utf-8") diff --git a/backend/tests/test_security.py b/backend/tests/test_security.py index cab62d7..bfd5e63 100644 --- a/backend/tests/test_security.py +++ b/backend/tests/test_security.py @@ -1,4 +1,24 @@ from trading_journal import security -def test_hash_password() -> None: + +def test_hash_and_verify_password() -> None: plain = "password" + hashed = security.hash_password(plain) + assert hashed != plain + assert security.verify_password(plain, hashed) + + +def test_generate_session_token() -> None: + token1 = security.generate_session_token() + token2 = security.generate_session_token() + assert token1 != token2 + assert len(token1) > 0 + assert len(token2) > 0 + + +def test_hash_and_verify_session_token_sha256() -> None: + token = security.generate_session_token() + token_hash = security.hash_session_token_sha256(token) + assert token_hash != token + assert security.verify_token_sha256(token, token_hash) + assert not security.verify_token_sha256(token + "x", token_hash) diff --git a/backend/trading_journal/security.py b/backend/trading_journal/security.py index edb32ce..23bf03a 100644 --- a/backend/trading_journal/security.py +++ b/backend/trading_journal/security.py @@ -1,11 +1,51 @@ -from passlib.context import CryptContext +import hashlib +import hmac +import secrets -pwd_ctx = CryptContext(schemes=["argon2"], deprecated="auto") +from argon2 import PasswordHasher +from argon2.exceptions import VerifyMismatchError + +import settings + +ph = PasswordHasher() + +# Utility functions for password hashing and verification def hash_password(plain: str) -> str: - return pwd_ctx.hash(plain) + return ph.hash(plain) + def verify_password(plain: str, hashed: str) -> bool: - return pwd_ctx.verify(plain, hashed) + try: + return ph.verify(hashed, plain) + except VerifyMismatchError: + return False + +# Session token hash + + +def generate_session_token(nbytes: int = 32) -> str: + return secrets.token_urlsafe(nbytes) + + +def hash_session_token_sha256(token: str) -> str: + return hashlib.sha256(token.encode("utf-8")).hexdigest() + + +def sign_token_hmac(token: str) -> str: + if not settings.settings.hmac_key: + return token + return hmac.new(settings.settings.hmac_key.encode("utf-8"), token.encode("utf-8"), hashlib.sha256).hexdigest() + + +def verify_token_sha256(token: str, expected_hash: str) -> bool: + return hmac.compare_digest(hash_session_token_sha256(token), expected_hash) + + +def verify_token_hmac(token: str, expected_hmac: str) -> bool: + if not settings.settings.hmac_key: + return verify_token_sha256(token, expected_hmac) + sig = hmac.new(settings.settings.hmac_key.encode("utf-8"), token.encode("utf-8"), hashlib.sha256).hexdigest() + return hmac.compare_digest(sig, expected_hmac) -- 2.49.1 From 466e6ce653c1912d564b45a216f9253957a4f231 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Mon, 22 Sep 2025 17:35:10 +0200 Subject: [PATCH 06/18] wip user reg --- backend/app.py | 34 +++++++++++--- backend/settings.py | 1 + backend/tests/test_app.py | 5 ++- backend/tests/test_crud.py | 26 +++++++++++ backend/trading_journal/crud.py | 20 +++++++++ backend/trading_journal/db.py | 13 ++++++ backend/trading_journal/service.py | 72 ++++++++++++++++++++++++++++++ 7 files changed, 163 insertions(+), 8 deletions(-) create mode 100644 backend/trading_journal/service.py diff --git a/backend/app.py b/backend/app.py index 812f896..f485bb2 100644 --- a/backend/app.py +++ b/backend/app.py @@ -2,13 +2,12 @@ import asyncio from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from fastapi import FastAPI, status +from fastapi import FastAPI, HTTPException, Request import settings -from trading_journal import db -from trading_journal.dto import TradeCreate, TradeRead - -API_BASE = "/api/v1" +from trading_journal import db, service +from trading_journal.db import Database +from trading_journal.dto import UserCreate, UserRead _db = db.create_database(settings.settings.database_url) @@ -23,8 +22,31 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001 app = FastAPI(lifespan=lifespan) +app.add_middleware(service.AuthMiddleWare) +app.state.db_factory = _db -@app.get(f"{API_BASE}/status") +@app.get(f"{settings.settings.api_base}/status") async def get_status() -> dict[str, str]: return {"status": "ok"} + + +@app.post(f"{settings.settings.api_base}/register") +async def register_user(request: Request, user_in: UserCreate) -> UserRead: + db_factory: Database = request.app.state.db_factory + + def sync_work() -> UserRead: + with db_factory.get_session_ctx_manager() as db: + return service.register_user_service(db, user_in) + + try: + return await asyncio.to_thread(sync_work) + except service.UserAlreadyExistsError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail="Internal server error" + str(e)) from e + + +@app.get(f"{settings.settings.api_base}/trades") +async def get_trades() -> dict[str, str]: + return {"trades": []} diff --git a/backend/settings.py b/backend/settings.py index 62305be..1e1e29f 100644 --- a/backend/settings.py +++ b/backend/settings.py @@ -15,6 +15,7 @@ class Settings(BaseSettings): workers: int = 1 log_level: str = "info" database_url: str = "sqlite:///:memory:" + api_base: str = "/api/v1" hmac_key: str | None = None model_config = ConfigDict(env_file=".env", env_file_encoding="utf-8") diff --git a/backend/tests/test_app.py b/backend/tests/test_app.py index d6123a6..78cf8ad 100644 --- a/backend/tests/test_app.py +++ b/backend/tests/test_app.py @@ -3,7 +3,8 @@ from collections.abc import Generator import pytest from fastapi.testclient import TestClient -from app import API_BASE, app +import settings +from app import app @pytest.fixture @@ -13,6 +14,6 @@ def client() -> Generator[TestClient, None, None]: def test_get_status(client: TestClient) -> None: - response = client.get(f"{API_BASE}/status") + response = client.get(f"{settings.settings.api_base}/status") assert response.status_code == 200 assert response.json() == {"status": "ok"} diff --git a/backend/tests/test_crud.py b/backend/tests/test_crud.py index 1ae5a55..9e0fade 100644 --- a/backend/tests/test_crud.py +++ b/backend/tests/test_crud.py @@ -651,6 +651,22 @@ def test_create_user(session: Session) -> None: assert actual_user.password_hash == user_data["password_hash"] +def test_get_user_by_id(session: Session) -> None: + user_id = make_user(session, username="fetchuser") + user = crud.get_user_by_id(session, user_id) + assert user is not None + assert user.id == user_id + assert user.username == "fetchuser" + + +def test_get_user_by_username(session: Session) -> None: + username = "uniqueuser" + make_user(session, username=username) + user = crud.get_user_by_username(session, username) + assert user is not None + assert user.username == username + + def test_update_user(session: Session) -> None: user_id = make_user(session, username="updatableuser") @@ -716,6 +732,16 @@ def test_get_login_session_by_token_and_user_id(session: Session) -> None: assert fetched_session.session_token_hash == created_session.session_token_hash +def test_get_login_session_by_token(session: Session) -> None: + now = datetime.now(timezone.utc) + created_session = make_login_session(session, now) + fetched_session = crud.get_login_session_by_token_hash(session, created_session.session_token_hash) + assert fetched_session is not None + assert fetched_session.id == created_session.id + assert fetched_session.user_id == created_session.user_id + assert fetched_session.session_token_hash == created_session.session_token_hash + + def test_update_login_session(session: Session) -> None: now = datetime.now(timezone.utc) created_session = make_login_session(session, now) diff --git a/backend/trading_journal/crud.py b/backend/trading_journal/crud.py index 9e998bd..d21e157 100644 --- a/backend/trading_journal/crud.py +++ b/backend/trading_journal/crud.py @@ -349,6 +349,17 @@ def create_user(session: Session, user_data: Mapping) -> models.Users: return u +def get_user_by_id(session: Session, user_id: int) -> models.Users | None: + return session.get(models.Users, user_id) + + +def get_user_by_username(session: Session, username: str) -> models.Users | None: + statement = select(models.Users).where( + models.Users.username == username, + ) + return session.exec(statement).first() + + def update_user(session: Session, user_id: int, update_data: Mapping) -> models.Users: user: models.Users | None = session.get(models.Users, user_id) if user is None: @@ -418,6 +429,15 @@ def get_login_session_by_token_hash_and_user_id(session: Session, session_token_ return session.exec(statement).first() +def get_login_session_by_token_hash(session: Session, session_token_hash: str) -> models.Sessions | None: + statement = select(models.Sessions).where( + models.Sessions.session_token_hash == session_token_hash, + models.Sessions.expires_at > datetime.now(timezone.utc), + ) + + return session.exec(statement).first() + + IMMUTABLE_SESSION_FIELDS = {"id", "user_id", "session_token_hash", "created_at"} diff --git a/backend/trading_journal/db.py b/backend/trading_journal/db.py index 039727c..93503f3 100644 --- a/backend/trading_journal/db.py +++ b/backend/trading_journal/db.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +from contextlib import contextmanager from typing import TYPE_CHECKING from sqlalchemy import event @@ -72,6 +73,18 @@ class Database: finally: session.close() + @contextmanager + def get_session_ctx_manager(self) -> Session: + session = Session(self._engine) + try: + yield session + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() + def dispose(self) -> None: self._engine.dispose() diff --git a/backend/trading_journal/service.py b/backend/trading_journal/service.py new file mode 100644 index 0000000..404ffb8 --- /dev/null +++ b/backend/trading_journal/service.py @@ -0,0 +1,72 @@ +from typing import Callable + +from fastapi import Request, Response, status +from fastapi.responses import JSONResponse +from sqlmodel import Session +from starlette.middleware.base import BaseHTTPMiddleware + +import settings +from trading_journal import crud, security +from trading_journal.db import Database +from trading_journal.dto import UserCreate, UserRead +from trading_journal.models import Sessions + +EXCEPT_PATHS = [ + f"{settings.settings.api_base}/status", + f"{settings.settings.api_base}/register", +] + + +class AuthMiddleWare(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next: Callable[[Request], Response]) -> Response: + if request.url.path in EXCEPT_PATHS: + return await call_next(request) + + token = request.cookies.get("session_token") + if not token: + auth_header = request.headers.get("Authorization") + if auth_header and auth_header.startswith("Bearer "): + token = auth_header[len("Bearer ") :] + + if not token: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"detail": "Unauthorized"}, + ) + + db_factory: Database | None = getattr(request.app.state, "db_factory", None) + if db_factory is None: + return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "db factory not configured"}) + try: + with db_factory.get_session_ctx_manager() as request_session: + hashed_token = security.hash_session_token_sha256(token) + request.state.db_session = request_session + login_session: Sessions | None = crud.get_login_session_by_token_hash(request.state.db_session, hashed_token) + except Exception: # noqa: BLE001 + return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "db error"}) + + return None + + +class ServiceError(Exception): + pass + + +class UserAlreadyExistsError(ServiceError): + pass + + +def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead: + if crud.get_user_by_username(db_session, user_in.username): + raise UserAlreadyExistsError("username already exists") + hashed = security.hash_password(user_in.password) + try: + user = crud.create_user(db_session, username=user_in.username, hashed_password=hashed) + try: + # prefer pydantic's from_orm if DTO supports orm_mode + user = UserRead.model_validate(user) + except Exception as e: + raise ServiceError("Failed to convert user to UserRead") from e + except Exception as e: + raise ServiceError("Failed to create user") from e + return user -- 2.49.1 From 1750401278d8df80f2931e5e16bbc5fe31c176ce Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Mon, 22 Sep 2025 22:51:59 +0200 Subject: [PATCH 07/18] several changes: * api calls for auth * exchange now bind to user --- backend/.gitignore | 4 +- backend/app.py | 69 +++++++++++++++++++--- backend/settings.py | 1 + backend/tests/test_crud.py | 66 ++++++++++++++------- backend/tests/test_db_migration.py | 10 +++- backend/trading_journal/crud.py | 10 +++- backend/trading_journal/db.py | 5 +- backend/trading_journal/db_migration.py | 1 + backend/trading_journal/dto.py | 32 +++++++++- backend/trading_journal/models.py | 8 ++- backend/trading_journal/models_v1.py | 8 ++- backend/trading_journal/service.py | 77 +++++++++++++++++++++++-- backend/utils/__init__.py | 0 backend/utils/db_mirgration.py | 13 +++++ 14 files changed, 259 insertions(+), 45 deletions(-) create mode 100644 backend/utils/__init__.py create mode 100644 backend/utils/db_mirgration.py diff --git a/backend/.gitignore b/backend/.gitignore index 6321b92..837cf41 100644 --- a/backend/.gitignore +++ b/backend/.gitignore @@ -14,4 +14,6 @@ __pycache__/ *.db *.db-shm -*.db-wal \ No newline at end of file +*.db-wal + +devsettings.yaml \ No newline at end of file diff --git a/backend/app.py b/backend/app.py index f485bb2..93a9186 100644 --- a/backend/app.py +++ b/backend/app.py @@ -1,16 +1,27 @@ +from __future__ import annotations + import asyncio +import logging from collections.abc import AsyncGenerator from contextlib import asynccontextmanager +from datetime import datetime, timezone -from fastapi import FastAPI, HTTPException, Request +from fastapi import FastAPI, HTTPException, Request, status +from fastapi.responses import JSONResponse import settings from trading_journal import db, service from trading_journal.db import Database -from trading_journal.dto import UserCreate, UserRead +from trading_journal.dto import SessionsBase, SessionsCreate, UserCreate, UserLogin, UserRead _db = db.create_database(settings.settings.database_url) +logging.basicConfig( + level=logging.WARNING, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", +) +logger = logging.getLogger(__name__) + @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001 @@ -40,13 +51,57 @@ async def register_user(request: Request, user_in: UserCreate) -> UserRead: return service.register_user_service(db, user_in) try: - return await asyncio.to_thread(sync_work) + user = await asyncio.to_thread(sync_work) + return JSONResponse(status_code=status.HTTP_201_CREATED, content=user.model_dump()) except service.UserAlreadyExistsError as e: - raise HTTPException(status_code=400, detail=str(e)) from e + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e except Exception as e: - raise HTTPException(status_code=500, detail="Internal server error" + str(e)) from e + logger.exception("Failed to register user: \n") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e + + +@app.post(f"{settings.settings.api_base}/login") +async def login(request: Request, user_in: UserLogin) -> SessionsBase: + db_factory: Database = request.app.state.db_factory + + def sync_work() -> tuple[SessionsCreate, str] | None: + with db_factory.get_session_ctx_manager() as db: + return service.authenticate_user_service(db, user_in) + + try: + result = await asyncio.to_thread(sync_work) + if result is None: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"detail": "Invalid username or password, or user doesn't exist"}, + ) + session, token = result + session_return = SessionsBase(user_id=session.user_id) + response = JSONResponse(status_code=status.HTTP_200_OK, content=session_return.model_dump()) + expires_sec = int((session.expires_at.replace(tzinfo=timezone.utc) - datetime.now(timezone.utc)).total_seconds()) + response.set_cookie( + key="session_token", + value=token, + httponly=True, + secure=True, + samesite="lax", + max_age=expires_sec, + path="/", + ) + except Exception as e: + logger.exception("Failed to login user: \n") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e + else: + return response + + +# Exchange +# @app.post(f"{settings.settings.api_base}/exchanges") +# async def create_exchange(request: Request, name: str, notes: str | None) -> dict: @app.get(f"{settings.settings.api_base}/trades") -async def get_trades() -> dict[str, str]: - return {"trades": []} +async def get_trades(request: Request) -> list: + db_factory: Database = request.app.state.db_factory + with db_factory.get_session_ctx_manager() as db: + return service.get_trades_service(db, request.state.user_id) diff --git a/backend/settings.py b/backend/settings.py index 1e1e29f..eff1071 100644 --- a/backend/settings.py +++ b/backend/settings.py @@ -16,6 +16,7 @@ class Settings(BaseSettings): log_level: str = "info" database_url: str = "sqlite:///:memory:" api_base: str = "/api/v1" + session_expiry_seconds: int = 3600 * 24 * 7 # 7 days hmac_key: str | None = None model_config = ConfigDict(env_file=".env", env_file_encoding="utf-8") diff --git a/backend/tests/test_crud.py b/backend/tests/test_crud.py index 9e0fade..3e02227 100644 --- a/backend/tests/test_crud.py +++ b/backend/tests/test_crud.py @@ -48,8 +48,8 @@ def make_user(session: Session, username: str = "testuser") -> int: return user.id -def make_exchange(session: Session, name: str = "NASDAQ") -> int: - exchange = models.Exchanges(name=name, notes="Test exchange") +def make_exchange(session: Session, user_id: int, name: str = "NASDAQ") -> int: + exchange = models.Exchanges(user_id=user_id, name=name, notes="Test exchange") session.add(exchange) session.commit() session.refresh(exchange) @@ -138,7 +138,7 @@ def _ensure_utc_aware(dt: datetime) -> datetime | None: def test_create_trade_success_with_cycle(session: Session) -> None: user_id = make_user(session) - exchange_id = make_exchange(session) + exchange_id = make_exchange(session, user_id) cycle_id = make_cycle(session, user_id, exchange_id) trade_data = { @@ -180,7 +180,7 @@ def test_create_trade_success_with_cycle(session: Session) -> None: def test_create_trade_with_auto_created_cycle(session: Session) -> None: user_id = make_user(session) - exchange_id = make_exchange(session) + exchange_id = make_exchange(session, user_id) trade_data = { "user_id": user_id, @@ -224,7 +224,7 @@ def test_create_trade_with_auto_created_cycle(session: Session) -> None: def test_create_trade_missing_required_fields(session: Session) -> None: user_id = make_user(session) - exchange_id = make_exchange(session) + exchange_id = make_exchange(session, user_id) base_trade_data = { "user_id": user_id, @@ -291,7 +291,7 @@ def test_create_trade_missing_required_fields(session: Session) -> None: def test_get_trade_by_id(session: Session) -> None: user_id = make_user(session) - exchange_id = make_exchange(session) + exchange_id = make_exchange(session, user_id) cycle_id = make_cycle(session, user_id, exchange_id) trade_data = { "user_id": user_id, @@ -330,7 +330,7 @@ def test_get_trade_by_id(session: Session) -> None: def test_get_trade_by_user_id_and_friendly_name(session: Session) -> None: user_id = make_user(session) - exchange_id = make_exchange(session) + exchange_id = make_exchange(session, user_id) cycle_id = make_cycle(session, user_id, exchange_id) friendly_name = "Unique Trade Name" trade_data = { @@ -359,7 +359,7 @@ def test_get_trade_by_user_id_and_friendly_name(session: Session) -> None: def test_get_trades_by_user_id(session: Session) -> None: user_id = make_user(session) - exchange_id = make_exchange(session) + exchange_id = make_exchange(session, user_id) cycle_id = make_cycle(session, user_id, exchange_id) trade_data_1 = { "user_id": user_id, @@ -406,7 +406,7 @@ def test_get_trades_by_user_id(session: Session) -> None: def test_update_trade_note(session: Session) -> None: user_id = make_user(session) - exchange_id = make_exchange(session) + exchange_id = make_exchange(session, user_id) cycle_id = make_cycle(session, user_id, exchange_id) trade_id = make_trade(session, user_id, cycle_id) @@ -424,7 +424,7 @@ def test_update_trade_note(session: Session) -> None: def test_invalidate_trade(session: Session) -> None: user_id = make_user(session) - exchange_id = make_exchange(session) + exchange_id = make_exchange(session, user_id) cycle_id = make_cycle(session, user_id, exchange_id) trade_id = make_trade(session, user_id, cycle_id) @@ -441,7 +441,7 @@ def test_invalidate_trade(session: Session) -> None: def test_replace_trade(session: Session) -> None: user_id = make_user(session) - exchange_id = make_exchange(session) + exchange_id = make_exchange(session, user_id) cycle_id = make_cycle(session, user_id, exchange_id) old_trade_id = make_trade(session, user_id, cycle_id) @@ -486,7 +486,7 @@ def test_replace_trade(session: Session) -> None: def test_create_cycle(session: Session) -> None: user_id = make_user(session) - exchange_id = make_exchange(session) + exchange_id = make_exchange(session, user_id) cycle_data = { "user_id": user_id, "friendly_name": "My First Cycle", @@ -517,7 +517,7 @@ def test_create_cycle(session: Session) -> None: def test_update_cycle(session: Session) -> None: user_id = make_user(session) - exchange_id = make_exchange(session) + exchange_id = make_exchange(session, user_id) cycle_id = make_cycle(session, user_id, exchange_id, friendly_name="Initial Cycle Name") update_data = { @@ -539,7 +539,7 @@ def test_update_cycle(session: Session) -> None: def test_update_cycle_immutable_fields(session: Session) -> None: user_id = make_user(session) - exchange_id = make_exchange(session) + exchange_id = make_exchange(session, user_id) cycle_id = make_cycle(session, user_id, exchange_id, friendly_name="Initial Cycle Name") # Attempt to update immutable fields @@ -563,42 +563,51 @@ def test_update_cycle_immutable_fields(session: Session) -> None: # Exchanges def test_create_exchange(session: Session) -> None: + user_id = make_user(session) exchange_data = { "name": "NYSE", "notes": "New York Stock Exchange", + "user_id": user_id, } exchange = crud.create_exchange(session, exchange_data) assert exchange.id is not None assert exchange.name == exchange_data["name"] assert exchange.notes == exchange_data["notes"] + assert exchange.user_id == user_id session.refresh(exchange) actual_exchange = session.get(models.Exchanges, exchange.id) assert actual_exchange is not None assert actual_exchange.name == exchange_data["name"] assert actual_exchange.notes == exchange_data["notes"] + assert actual_exchange.user_id == user_id def test_get_exchange_by_id(session: Session) -> None: - exchange_id = make_exchange(session, name="LSE") + user_id = make_user(session) + exchange_id = make_exchange(session, user_id=user_id, name="LSE") exchange = crud.get_exchange_by_id(session, exchange_id) assert exchange is not None assert exchange.id == exchange_id assert exchange.name == "LSE" + assert exchange.user_id == user_id -def test_get_exchange_by_name(session: Session) -> None: +def test_get_exchange_by_name_and_user_id(session: Session) -> None: exchange_name = "TSX" - make_exchange(session, name=exchange_name) - exchange = crud.get_exchange_by_name(session, exchange_name) + user_id = make_user(session) + make_exchange(session, user_id=user_id, name=exchange_name) + exchange = crud.get_exchange_by_name_and_user_id(session, exchange_name, user_id) assert exchange is not None assert exchange.name == exchange_name + assert exchange.user_id == user_id def test_get_all_exchanges(session: Session) -> None: exchange_names = ["NYSE", "NASDAQ", "LSE"] + user_id = make_user(session) for name in exchange_names: - make_exchange(session, name=name) + make_exchange(session, user_id=user_id, name=name) exchanges = crud.get_all_exchanges(session) assert len(exchanges) >= 3 @@ -607,8 +616,22 @@ def test_get_all_exchanges(session: Session) -> None: assert name in fetched_names +def test_get_all_exchanges_by_user_id(session: Session) -> None: + exchange_names = ["NYSE", "NASDAQ"] + user_id = make_user(session) + for name in exchange_names: + make_exchange(session, user_id=user_id, name=name) + + exchanges = crud.get_all_exchanges_by_user_id(session, user_id) + assert len(exchanges) == len(exchange_names) + fetched_names = {ex.name for ex in exchanges} + for name in exchange_names: + assert name in fetched_names + + def test_update_exchange(session: Session) -> None: - exchange_id = make_exchange(session, name="Initial Exchange") + user_id = make_user(session) + exchange_id = make_exchange(session, user_id=user_id, name="Initial Exchange") update_data = { "name": "Updated Exchange", "notes": "Updated notes for the exchange", @@ -627,7 +650,8 @@ def test_update_exchange(session: Session) -> None: def test_delete_exchange(session: Session) -> None: - exchange_id = make_exchange(session, name="Deletable Exchange") + user_id = make_user(session) + exchange_id = make_exchange(session, user_id=user_id, name="Deletable Exchange") crud.delete_exchange(session, exchange_id) deleted_exchange = session.get(models.Exchanges, exchange_id) assert deleted_exchange is None diff --git a/backend/tests/test_db_migration.py b/backend/tests/test_db_migration.py index 15c7fba..343214b 100644 --- a/backend/tests/test_db_migration.py +++ b/backend/tests/test_db_migration.py @@ -70,6 +70,12 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None: "notes": ("TEXT", 0, 0), "cycle_id": ("INTEGER", 0, 0), }, + "exchanges": { + "id": ("INTEGER", 1, 1), + "user_id": ("INTEGER", 1, 0), + "name": ("TEXT", 1, 0), + "notes": ("TEXT", 0, 0), + }, "sessions": { "id": ("INTEGER", 1, 1), "user_id": ("INTEGER", 1, 0), @@ -97,7 +103,9 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None: {"table": "users", "from": "user_id", "to": "id"}, ], "users": [], - "exchanges": [], + "exchanges": [ + {"table": "users", "from": "user_id", "to": "id"}, + ], } with engine.connect() as conn: diff --git a/backend/trading_journal/crud.py b/backend/trading_journal/crud.py index d21e157..75bb8ad 100644 --- a/backend/trading_journal/crud.py +++ b/backend/trading_journal/crud.py @@ -274,9 +274,10 @@ def get_exchange_by_id(session: Session, exchange_id: int) -> models.Exchanges | return session.get(models.Exchanges, exchange_id) -def get_exchange_by_name(session: Session, name: str) -> models.Exchanges | None: +def get_exchange_by_name_and_user_id(session: Session, name: str, user_id: int) -> models.Exchanges | None: statement = select(models.Exchanges).where( models.Exchanges.name == name, + models.Exchanges.user_id == user_id, ) return session.exec(statement).first() @@ -286,6 +287,13 @@ def get_all_exchanges(session: Session) -> list[models.Exchanges]: return session.exec(statement).all() +def get_all_exchanges_by_user_id(session: Session, user_id: int) -> list[models.Exchanges]: + statement = select(models.Exchanges).where( + models.Exchanges.user_id == user_id, + ) + return session.exec(statement).all() + + def update_exchange(session: Session, exchange_id: int, update_data: Mapping) -> models.Exchanges: exchange: models.Exchanges | None = session.get(models.Exchanges, exchange_id) if exchange is None: diff --git a/backend/trading_journal/db.py b/backend/trading_journal/db.py index 93503f3..ddd2dc6 100644 --- a/backend/trading_journal/db.py +++ b/backend/trading_journal/db.py @@ -8,8 +8,6 @@ from sqlalchemy import event from sqlalchemy.pool import StaticPool from sqlmodel import Session, create_engine -from trading_journal import db_migration - if TYPE_CHECKING: from collections.abc import Generator from sqlite3 import Connection as DBAPIConnection @@ -59,7 +57,6 @@ class Database: event.listen(self._engine, "connect", _enable_sqlite_pragmas) def init_db(self) -> None: - # db_migration.run_migrations(self._engine) pass def get_session(self) -> Generator[Session, None, None]: @@ -74,7 +71,7 @@ class Database: session.close() @contextmanager - def get_session_ctx_manager(self) -> Session: + def get_session_ctx_manager(self) -> Generator[Session, None, None]: session = Session(self._engine) try: yield session diff --git a/backend/trading_journal/db_migration.py b/backend/trading_journal/db_migration.py index 2a57464..e3766a2 100644 --- a/backend/trading_journal/db_migration.py +++ b/backend/trading_journal/db_migration.py @@ -27,6 +27,7 @@ def _mig_0_1(engine: Engine) -> None: models_v1.Cycles.__table__, models_v1.Users.__table__, models_v1.Sessions.__table__, + models_v1.Exchanges.__table__, ], ) diff --git a/backend/trading_journal/dto.py b/backend/trading_journal/dto.py index 1a9d478..9c4ce25 100644 --- a/backend/trading_journal/dto.py +++ b/backend/trading_journal/dto.py @@ -1,12 +1,12 @@ from __future__ import annotations +from datetime import date, datetime # noqa: TC003 from typing import TYPE_CHECKING +from pydantic import BaseModel from sqlmodel import SQLModel if TYPE_CHECKING: - from datetime import date, datetime - from trading_journal.models import TradeStrategy, TradeType, UnderlyingCurrency @@ -52,5 +52,33 @@ class UserCreate(UserBase): password: str +class UserLogin(BaseModel): + username: str + password: str + + class UserRead(UserBase): id: int + + +class SessionsBase(SQLModel): + user_id: int + + +class SessionRead(SessionsBase): + id: int + expires_at: datetime + last_seen_at: datetime | None + last_used_ip: str | None + user_agent: str | None + + +class SessionsCreate(SessionsBase): + expires_at: datetime + + +class SessionsUpdate(SQLModel): + expires_at: datetime | None = None + last_seen_at: datetime | None = None + last_used_ip: str | None = None + user_agent: str | None = None diff --git a/backend/trading_journal/models.py b/backend/trading_journal/models.py index 0238a81..9bdc57a 100644 --- a/backend/trading_journal/models.py +++ b/backend/trading_journal/models.py @@ -117,11 +117,14 @@ class Cycles(SQLModel, table=True): class Exchanges(SQLModel, table=True): __tablename__ = "exchanges" + __table_args__ = (UniqueConstraint("user_id", "name", name="uq_exchanges_user_name"),) id: int | None = Field(default=None, primary_key=True) - name: str = Field(sa_column=Column(Text, nullable=False, unique=True)) + user_id: int = Field(foreign_key="users.id", nullable=False, index=True) + name: str = Field(sa_column=Column(Text, nullable=False)) notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) trades: list["Trades"] = Relationship(back_populates="exchange") cycles: list["Cycles"] = Relationship(back_populates="exchange") + user: "Users" = Relationship(back_populates="exchanges") class Users(SQLModel, table=True): @@ -131,6 +134,8 @@ class Users(SQLModel, table=True): username: str = Field(sa_column=Column(Text, nullable=False, unique=True)) password_hash: str = Field(sa_column=Column(Text, nullable=False)) is_active: bool = Field(default=True, nullable=False) + sessions: list["Sessions"] = Relationship(back_populates="user") + exchanges: list["Exchanges"] = Relationship(back_populates="user") class Sessions(SQLModel, table=True): @@ -144,3 +149,4 @@ class Sessions(SQLModel, table=True): last_used_ip: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) + user: "Users" = Relationship(back_populates="sessions") diff --git a/backend/trading_journal/models_v1.py b/backend/trading_journal/models_v1.py index 0238a81..9bdc57a 100644 --- a/backend/trading_journal/models_v1.py +++ b/backend/trading_journal/models_v1.py @@ -117,11 +117,14 @@ class Cycles(SQLModel, table=True): class Exchanges(SQLModel, table=True): __tablename__ = "exchanges" + __table_args__ = (UniqueConstraint("user_id", "name", name="uq_exchanges_user_name"),) id: int | None = Field(default=None, primary_key=True) - name: str = Field(sa_column=Column(Text, nullable=False, unique=True)) + user_id: int = Field(foreign_key="users.id", nullable=False, index=True) + name: str = Field(sa_column=Column(Text, nullable=False)) notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) trades: list["Trades"] = Relationship(back_populates="exchange") cycles: list["Cycles"] = Relationship(back_populates="exchange") + user: "Users" = Relationship(back_populates="exchanges") class Users(SQLModel, table=True): @@ -131,6 +134,8 @@ class Users(SQLModel, table=True): username: str = Field(sa_column=Column(Text, nullable=False, unique=True)) password_hash: str = Field(sa_column=Column(Text, nullable=False)) is_active: bool = Field(default=True, nullable=False) + sessions: list["Sessions"] = Relationship(back_populates="user") + exchanges: list["Exchanges"] = Relationship(back_populates="user") class Sessions(SQLModel, table=True): @@ -144,3 +149,4 @@ class Sessions(SQLModel, table=True): last_used_ip: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) user_agent: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) device_name: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) + user: "Users" = Relationship(back_populates="sessions") diff --git a/backend/trading_journal/service.py b/backend/trading_journal/service.py index 404ffb8..8fd8cd4 100644 --- a/backend/trading_journal/service.py +++ b/backend/trading_journal/service.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +import logging +from datetime import datetime, timedelta, timezone from typing import Callable from fastapi import Request, Response, status @@ -8,17 +12,23 @@ from starlette.middleware.base import BaseHTTPMiddleware import settings from trading_journal import crud, security from trading_journal.db import Database -from trading_journal.dto import UserCreate, UserRead +from trading_journal.dto import SessionsCreate, SessionsUpdate, UserCreate, UserLogin, UserRead from trading_journal.models import Sessions +SessionsCreate.model_rebuild() + + EXCEPT_PATHS = [ f"{settings.settings.api_base}/status", f"{settings.settings.api_base}/register", + f"{settings.settings.api_base}/login", ] +logger = logging.getLogger(__name__) + class AuthMiddleWare(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next: Callable[[Request], Response]) -> Response: + async def dispatch(self, request: Request, call_next: Callable[[Request], Response]) -> Response: # noqa: PLR0911 if request.url.path in EXCEPT_PATHS: return await call_next(request) @@ -42,10 +52,32 @@ class AuthMiddleWare(BaseHTTPMiddleware): hashed_token = security.hash_session_token_sha256(token) request.state.db_session = request_session login_session: Sessions | None = crud.get_login_session_by_token_hash(request.state.db_session, hashed_token) - except Exception: # noqa: BLE001 - return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "db error"}) + if not login_session: + return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"}) + session_expires_utc = login_session.expires_at.replace(tzinfo=timezone.utc) + if session_expires_utc < datetime.now(timezone.utc): + crud.delete_login_session(request.state.db_session, login_session) + return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"}) + if login_session.user.is_active is False: + return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"}) + if session_expires_utc - datetime.now(timezone.utc) < timedelta(seconds=3600): + updated_expiry = datetime.now(timezone.utc) + timedelta(seconds=settings.settings.session_expiry_seconds) + else: + updated_expiry = session_expires_utc + updated_session: SessionsUpdate = SessionsUpdate( + last_seen_at=datetime.now(timezone.utc), + last_used_ip=request.client.host if request.client else None, + user_agent=request.headers.get("User-Agent"), + expires_at=updated_expiry, + ) + user_id = login_session.user_id + request.state.user_id = user_id + crud.update_login_session(request.state.db_session, hashed_token, update_session=updated_session) + except Exception: + logger.exception("Failed to authenticate user: \n") + return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "Internal server error"}) - return None + return await call_next(request) class ServiceError(Exception): @@ -60,13 +92,46 @@ def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead: if crud.get_user_by_username(db_session, user_in.username): raise UserAlreadyExistsError("username already exists") hashed = security.hash_password(user_in.password) + user_data: dict = { + "username": user_in.username, + "password_hash": hashed, + } try: - user = crud.create_user(db_session, username=user_in.username, hashed_password=hashed) + user = crud.create_user(db_session, user_data=user_data) try: # prefer pydantic's from_orm if DTO supports orm_mode user = UserRead.model_validate(user) except Exception as e: + logger.exception("Failed to convert user to UserRead: %s", e) raise ServiceError("Failed to convert user to UserRead") from e except Exception as e: + logger.exception("Failed to create user:") raise ServiceError("Failed to create user") from e return user + + +def authenticate_user_service(db_session: Session, user_in: UserLogin) -> tuple[SessionsCreate, str] | None: + user = crud.get_user_by_username(db_session, user_in.username) + if not user: + return None + + if not security.verify_password(user_in.password, user.password_hash): + return None + + token = security.generate_session_token() + token_hashed = security.hash_session_token_sha256(token) + try: + session = crud.create_login_session( + session=db_session, + user_id=user.id, + session_token_hash=token_hashed, + session_length_seconds=settings.settings.session_expiry_seconds, + ) + except Exception as e: + logger.exception("Failed to create login session: \n") + raise ServiceError("Failed to create login session") from e + return SessionsCreate.model_validate(session), token + + +def get_trades_service(db_session: Session, user_id: int) -> list: + return crud.get_trades_by_user_id(db_session, user_id) diff --git a/backend/utils/__init__.py b/backend/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/utils/db_mirgration.py b/backend/utils/db_mirgration.py new file mode 100644 index 0000000..1103349 --- /dev/null +++ b/backend/utils/db_mirgration.py @@ -0,0 +1,13 @@ +import sys +from pathlib import Path + +from sqlmodel import create_engine + +project_parent = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(project_parent)) + +import settings # noqa: E402 +from trading_journal import db_migration # noqa: E402 + +db_engine = create_engine(settings.settings.database_url, echo=True) +db_migration.run_migrations(db_engine) -- 2.49.1 From b68249f9f143655f6bb7cc5d33fa89a094560a1a Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Mon, 22 Sep 2025 23:07:28 +0200 Subject: [PATCH 08/18] add create get exchange endpoint --- backend/app.py | 43 +++++++++++++++++++++++------- backend/trading_journal/dto.py | 9 +++++++ backend/trading_journal/service.py | 34 ++++++++++++++++++++++- 3 files changed, 76 insertions(+), 10 deletions(-) diff --git a/backend/app.py b/backend/app.py index 93a9186..1d30b28 100644 --- a/backend/app.py +++ b/backend/app.py @@ -12,7 +12,7 @@ from fastapi.responses import JSONResponse import settings from trading_journal import db, service from trading_journal.db import Database -from trading_journal.dto import SessionsBase, SessionsCreate, UserCreate, UserLogin, UserRead +from trading_journal.dto import ExchangesBase, SessionsBase, SessionsCreate, UserCreate, UserLogin, UserRead _db = db.create_database(settings.settings.database_url) @@ -96,12 +96,37 @@ async def login(request: Request, user_in: UserLogin) -> SessionsBase: # Exchange -# @app.post(f"{settings.settings.api_base}/exchanges") -# async def create_exchange(request: Request, name: str, notes: str | None) -> dict: - - -@app.get(f"{settings.settings.api_base}/trades") -async def get_trades(request: Request) -> list: +@app.post(f"{settings.settings.api_base}/exchanges") +async def create_exchange(request: Request, exchange_data: ExchangesBase) -> dict: db_factory: Database = request.app.state.db_factory - with db_factory.get_session_ctx_manager() as db: - return service.get_trades_service(db, request.state.user_id) + + def sync_work() -> ExchangesBase: + with db_factory.get_session_ctx_manager() as db: + return service.create_exchange_service(db, request.state.user_id, exchange_data.name, exchange_data.notes) + + try: + exchange = await asyncio.to_thread(sync_work) + return JSONResponse(status_code=status.HTTP_201_CREATED, content=exchange.model_dump()) + except service.ExchangeAlreadyExistsError as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e + except Exception as e: + logger.exception("Failed to create exchange: \n") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e + + +@app.get(f"{settings.settings.api_base}/exchanges") +async def get_exchanges(request: Request) -> list[ExchangesBase]: + db_factory: Database = request.app.state.db_factory + + def sync_work() -> list[ExchangesBase]: + with db_factory.get_session_ctx_manager() as db: + return service.get_exchanges_by_user_service(db, request.state.user_id) + + try: + return await asyncio.to_thread(sync_work) + except Exception as e: + logger.exception("Failed to get exchanges: \n") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e + + +# Trade diff --git a/backend/trading_journal/dto.py b/backend/trading_journal/dto.py index 9c4ce25..5ea709c 100644 --- a/backend/trading_journal/dto.py +++ b/backend/trading_journal/dto.py @@ -10,6 +10,15 @@ if TYPE_CHECKING: from trading_journal.models import TradeStrategy, TradeType, UnderlyingCurrency +class ExchangesBase(SQLModel): + name: str + notes: str | None = None + + +class ExchangesCreate(ExchangesBase): + user_id: int + + class TradeBase(SQLModel): user_id: int friendly_name: str | None diff --git a/backend/trading_journal/service.py b/backend/trading_journal/service.py index 8fd8cd4..881fe53 100644 --- a/backend/trading_journal/service.py +++ b/backend/trading_journal/service.py @@ -12,7 +12,7 @@ from starlette.middleware.base import BaseHTTPMiddleware import settings from trading_journal import crud, security from trading_journal.db import Database -from trading_journal.dto import SessionsCreate, SessionsUpdate, UserCreate, UserLogin, UserRead +from trading_journal.dto import ExchangesBase, ExchangesCreate, SessionsCreate, SessionsUpdate, UserCreate, UserLogin, UserRead from trading_journal.models import Sessions SessionsCreate.model_rebuild() @@ -88,6 +88,10 @@ class UserAlreadyExistsError(ServiceError): pass +class ExchangeAlreadyExistsError(ServiceError): + pass + + def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead: if crud.get_user_by_username(db_session, user_in.username): raise UserAlreadyExistsError("username already exists") @@ -133,5 +137,33 @@ def authenticate_user_service(db_session: Session, user_in: UserLogin) -> tuple[ return SessionsCreate.model_validate(session), token +# Exchanges service +def create_exchange_service(db_session: Session, user_id: int, name: str, notes: str | None) -> ExchangesCreate: + existing_exchange = crud.get_exchange_by_name_and_user_id(db_session, name, user_id) + if existing_exchange: + raise ExchangeAlreadyExistsError("Exchange with the same name already exists for this user") + exchange_data = ExchangesCreate( + user_id=user_id, + name=name, + notes=notes, + ) + try: + exchange = crud.create_exchange(db_session, exchange_data=exchange_data) + try: + exchange_dto = ExchangesCreate.model_validate(exchange) + except Exception as e: + logger.exception("Failed to convert exchange to ExchangesCreate: ") + raise ServiceError("Failed to convert exchange to ExchangesCreate") from e + except Exception as e: + logger.exception("Failed to create exchange:") + raise ServiceError("Failed to create exchange") from e + return exchange_dto + + +def get_exchanges_by_user_service(db_session: Session, user_id: int) -> list[ExchangesBase]: + exchanges = crud.get_all_exchanges_by_user_id(db_session, user_id) + return [ExchangesBase.model_validate(exchange) for exchange in exchanges] + + def get_trades_service(db_session: Session, user_id: int) -> list: return crud.get_trades_by_user_id(db_session, user_id) -- 2.49.1 From 92c4e0d4fc5b47732e2b3343e6056dbcd9cde555 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Tue, 23 Sep 2025 17:37:14 +0200 Subject: [PATCH 09/18] refine type checking --- backend/.vscode/settings.json | 5 +- backend/app.py | 16 +-- backend/tests/test_crud.py | 21 ++-- backend/trading_journal/crud.py | 137 ++++++++++++++---------- backend/trading_journal/db_migration.py | 10 +- backend/trading_journal/models.py | 10 +- backend/trading_journal/models_v1.py | 10 +- backend/trading_journal/service.py | 26 +++-- 8 files changed, 132 insertions(+), 103 deletions(-) diff --git a/backend/.vscode/settings.json b/backend/.vscode/settings.json index 96661cd..839030b 100644 --- a/backend/.vscode/settings.json +++ b/backend/.vscode/settings.json @@ -11,5 +11,6 @@ "tests" ], "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true -} + "python.testing.pytestEnabled": true, + "python.analysis.typeCheckingMode": "standard", +} \ No newline at end of file diff --git a/backend/app.py b/backend/app.py index 1d30b28..4c6f61d 100644 --- a/backend/app.py +++ b/backend/app.py @@ -2,18 +2,22 @@ from __future__ import annotations import asyncio import logging -from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from datetime import datetime, timezone +from typing import TYPE_CHECKING from fastapi import FastAPI, HTTPException, Request, status -from fastapi.responses import JSONResponse +from fastapi.responses import JSONResponse, Response import settings from trading_journal import db, service -from trading_journal.db import Database from trading_journal.dto import ExchangesBase, SessionsBase, SessionsCreate, UserCreate, UserLogin, UserRead +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + + from trading_journal.db import Database + _db = db.create_database(settings.settings.database_url) logging.basicConfig( @@ -43,7 +47,7 @@ async def get_status() -> dict[str, str]: @app.post(f"{settings.settings.api_base}/register") -async def register_user(request: Request, user_in: UserCreate) -> UserRead: +async def register_user(request: Request, user_in: UserCreate) -> Response: db_factory: Database = request.app.state.db_factory def sync_work() -> UserRead: @@ -61,7 +65,7 @@ async def register_user(request: Request, user_in: UserCreate) -> UserRead: @app.post(f"{settings.settings.api_base}/login") -async def login(request: Request, user_in: UserLogin) -> SessionsBase: +async def login(request: Request, user_in: UserLogin) -> Response: db_factory: Database = request.app.state.db_factory def sync_work() -> tuple[SessionsCreate, str] | None: @@ -97,7 +101,7 @@ async def login(request: Request, user_in: UserLogin) -> SessionsBase: # Exchange @app.post(f"{settings.settings.api_base}/exchanges") -async def create_exchange(request: Request, exchange_data: ExchangesBase) -> dict: +async def create_exchange(request: Request, exchange_data: ExchangesBase) -> Response: db_factory: Database = request.app.state.db_factory def sync_work() -> ExchangesBase: diff --git a/backend/tests/test_crud.py b/backend/tests/test_crud.py index 3e02227..0620fed 100644 --- a/backend/tests/test_crud.py +++ b/backend/tests/test_crud.py @@ -1,7 +1,7 @@ from __future__ import annotations from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import pytest from sqlalchemy import create_engine @@ -45,7 +45,7 @@ def make_user(session: Session, username: str = "testuser") -> int: session.add(user) session.commit() session.refresh(user) - return user.id + return cast("int", user.id) def make_exchange(session: Session, user_id: int, name: str = "NASDAQ") -> int: @@ -53,7 +53,7 @@ def make_exchange(session: Session, user_id: int, name: str = "NASDAQ") -> int: session.add(exchange) session.commit() session.refresh(exchange) - return exchange.id + return cast("int", exchange.id) def make_cycle(session: Session, user_id: int, exchange_id: int, friendly_name: str = "Test Cycle") -> int: @@ -65,15 +65,16 @@ def make_cycle(session: Session, user_id: int, exchange_id: int, friendly_name: underlying_currency=models.UnderlyingCurrency.USD, status=models.CycleStatus.OPEN, start_date=datetime.now(timezone.utc).date(), - ) + ) # type: ignore[arg-type] session.add(cycle) session.commit() session.refresh(cycle) - return cycle.id + return cast("int", cycle.id) def make_trade(session: Session, user_id: int, cycle_id: int, friendly_name: str = "Test Trade") -> int: - cycle: models.Cycles = session.get(models.Cycles, cycle_id) + cycle: models.Cycles | None = session.get(models.Cycles, cycle_id) + assert cycle is not None exchange_id = cycle.exchange_id trade = models.Trades( user_id=user_id, @@ -96,7 +97,7 @@ def make_trade(session: Session, user_id: int, cycle_id: int, friendly_name: str session.add(trade) session.commit() session.refresh(trade) - return trade.id + return cast("int", trade.id) def make_trade_by_trade_data(session: Session, trade_data: dict) -> int: @@ -104,7 +105,7 @@ def make_trade_by_trade_data(session: Session, trade_data: dict) -> int: session.add(trade) session.commit() session.refresh(trade) - return trade.id + return cast("int", trade.id) def make_login_session(session: Session, created_at: datetime) -> models.Sessions: @@ -128,7 +129,7 @@ def make_login_session(session: Session, created_at: datetime) -> models.Session return login_session -def _ensure_utc_aware(dt: datetime) -> datetime | None: +def _ensure_utc_aware(dt: datetime | None) -> datetime | None: if dt is None: return None if dt.tzinfo is None: @@ -219,7 +220,7 @@ def test_create_trade_with_auto_created_cycle(session: Session) -> None: assert auto_cycle.symbol == trade_data["symbol"] assert auto_cycle.underlying_currency == trade_data["underlying_currency"] assert auto_cycle.status == models.CycleStatus.OPEN - assert auto_cycle.friendly_name.startswith("Auto-created Cycle by trade") + assert auto_cycle.friendly_name.startswith("Auto-created Cycle by trade") # type: ignore[union-attr] def test_create_trade_missing_required_fields(session: Session) -> None: diff --git a/backend/trading_journal/crud.py b/backend/trading_journal/crud.py index 75bb8ad..1918b5d 100644 --- a/backend/trading_journal/crud.py +++ b/backend/trading_journal/crud.py @@ -1,8 +1,9 @@ from __future__ import annotations from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, TypeVar, cast +from pydantic import BaseModel from sqlalchemy.exc import IntegrityError from sqlmodel import Session, select @@ -10,9 +11,14 @@ from trading_journal import models if TYPE_CHECKING: from collections.abc import Mapping + from enum import Enum -def _check_enum(enum_cls: any, value: any, field_name: str) -> any: +# Generic enum member type +T = TypeVar("T", bound="Enum") + + +def _check_enum(enum_cls: type[T], value: object, field_name: str) -> T: if value is None: raise ValueError(f"{field_name} is required") # already an enum member @@ -27,19 +33,41 @@ def _check_enum(enum_cls: any, value: any, field_name: str) -> any: raise ValueError(f"Invalid {field_name!s}: {value!r}. Allowed: {allowed}") +def _allowed_columns(model: type[models.SQLModel]) -> set[str]: + tbl = cast("models.SQLModel", model).__table__ # type: ignore[attr-defined] + return {c.name for c in tbl.columns} + + +AnyModel = Any + + +def _data_to_dict(data: AnyModel) -> dict[str, AnyModel]: + if isinstance(data, BaseModel): + return data.model_dump(exclude_unset=True) + if hasattr(data, "dict"): + return data.dict(exclude_unset=True) + return dict(data) + + # Trades -def create_trade(session: Session, trade_data: Mapping) -> models.Trades: - if hasattr(trade_data, "dict"): - data = trade_data.dict(exclude_unset=True) - else: - data = dict(trade_data) - allowed = {c.name for c in models.Trades.__table__.columns} +def create_trade(session: Session, trade_data: Mapping[str, Any] | BaseModel) -> models.Trades: + data = _data_to_dict(trade_data) + allowed = _allowed_columns(models.Trades) payload = {k: v for k, v in data.items() if k in allowed} cycle_id = payload.get("cycle_id") if "symbol" not in payload: raise ValueError("symbol is required") if "exchange_id" not in payload and cycle_id is None: raise ValueError("exchange_id is required when no cycle is attached") + # If an exchange_id is provided (and no cycle is attached), ensure the exchange exists + # and belongs to the same user as the trade (if user_id is provided). + if cycle_id is None and "exchange_id" in payload: + ex = session.get(models.Exchanges, payload["exchange_id"]) + if ex is None: + raise ValueError("exchange_id does not exist") + user_id = payload.get("user_id") + if user_id is not None and ex.user_id != user_id: + raise ValueError("exchange.user_id does not match trade.user_id") if "underlying_currency" not in payload: raise ValueError("underlying_currency is required") payload["underlying_currency"] = _check_enum(models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency") @@ -132,7 +160,7 @@ def get_trades_by_user_id(session: Session, user_id: int) -> list[models.Trades] statement = select(models.Trades).where( models.Trades.user_id == user_id, ) - return session.exec(statement).all() + return list(session.exec(statement).all()) def update_trade_note(session: Session, trade_id: int, note: str) -> models.Trades: @@ -168,23 +196,17 @@ def invalidate_trade(session: Session, trade_id: int) -> models.Trades: return trade -def replace_trade(session: Session, old_trade_id: int, new_trade_data: Mapping) -> models.Trades: +def replace_trade(session: Session, old_trade_id: int, new_trade_data: Mapping[str, Any] | BaseModel) -> models.Trades: invalidate_trade(session, old_trade_id) - if hasattr(new_trade_data, "dict"): - data = new_trade_data.dict(exclude_unset=True) - else: - data = dict(new_trade_data) + data = _data_to_dict(new_trade_data) data["replaced_by_trade_id"] = old_trade_id return create_trade(session, data) # Cycles -def create_cycle(session: Session, cycle_data: Mapping) -> models.Cycles: - if hasattr(cycle_data, "dict"): - data = cycle_data.dict(exclude_unset=True) - else: - data = dict(cycle_data) - allowed = {c.name for c in models.Cycles.__table__.columns} +def create_cycle(session: Session, cycle_data: Mapping[str, Any] | BaseModel) -> models.Cycles: + data = _data_to_dict(cycle_data) + allowed = _allowed_columns(models.Cycles) payload = {k: v for k, v in data.items() if k in allowed} if "user_id" not in payload: raise ValueError("user_id is required") @@ -192,6 +214,12 @@ def create_cycle(session: Session, cycle_data: Mapping) -> models.Cycles: raise ValueError("symbol is required") if "exchange_id" not in payload: raise ValueError("exchange_id is required") + # ensure the exchange exists and belongs to the same user + ex = session.get(models.Exchanges, payload["exchange_id"]) + if ex is None: + raise ValueError("exchange_id does not exist") + if ex.user_id != payload.get("user_id"): + raise ValueError("exchange.user_id does not match cycle.user_id") if "underlying_currency" not in payload: raise ValueError("underlying_currency is required") payload["underlying_currency"] = _check_enum(models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency") @@ -215,21 +243,26 @@ def create_cycle(session: Session, cycle_data: Mapping) -> models.Cycles: IMMUTABLE_CYCLE_FIELDS = {"id", "user_id", "start_date", "created_at"} -def update_cycle(session: Session, cycle_id: int, update_data: Mapping) -> models.Cycles: +def update_cycle(session: Session, cycle_id: int, update_data: Mapping[str, Any] | BaseModel) -> models.Cycles: cycle: models.Cycles | None = session.get(models.Cycles, cycle_id) if cycle is None: raise ValueError("cycle_id does not exist") - if hasattr(update_data, "dict"): - data = update_data.dict(exclude_unset=True) - else: - data = dict(update_data) + data = _data_to_dict(update_data) - allowed = {c.name for c in models.Cycles.__table__.columns} + allowed = _allowed_columns(models.Cycles) for k, v in data.items(): if k in IMMUTABLE_CYCLE_FIELDS: raise ValueError(f"field {k!r} is immutable") if k not in allowed: continue + # If trying to change exchange_id, ensure the new exchange exists and belongs to + # the same user as the cycle. + if k == "exchange_id": + ex = session.get(models.Exchanges, v) + if ex is None: + raise ValueError("exchange_id does not exist") + if ex.user_id != cycle.user_id: + raise ValueError("exchange.user_id does not match cycle.user_id") if k == "underlying_currency": v = _check_enum(models.UnderlyingCurrency, v, "underlying_currency") # noqa: PLW2901 if k == "status": @@ -249,12 +282,9 @@ def update_cycle(session: Session, cycle_id: int, update_data: Mapping) -> model IMMUTABLE_EXCHANGE_FIELDS = {"id"} -def create_exchange(session: Session, exchange_data: Mapping) -> models.Exchanges: - if hasattr(exchange_data, "dict"): - data = exchange_data.dict(exclude_unset=True) - else: - data = dict(exchange_data) - allowed = {c.name for c in models.Exchanges.__table__.columns} +def create_exchange(session: Session, exchange_data: Mapping[str, Any] | BaseModel) -> models.Exchanges: + data = _data_to_dict(exchange_data) + allowed = _allowed_columns(models.Exchanges) payload = {k: v for k, v in data.items() if k in allowed} if "name" not in payload: raise ValueError("name is required") @@ -284,25 +314,22 @@ def get_exchange_by_name_and_user_id(session: Session, name: str, user_id: int) def get_all_exchanges(session: Session) -> list[models.Exchanges]: statement = select(models.Exchanges) - return session.exec(statement).all() + return list(session.exec(statement).all()) def get_all_exchanges_by_user_id(session: Session, user_id: int) -> list[models.Exchanges]: statement = select(models.Exchanges).where( models.Exchanges.user_id == user_id, ) - return session.exec(statement).all() + return list(session.exec(statement).all()) -def update_exchange(session: Session, exchange_id: int, update_data: Mapping) -> models.Exchanges: +def update_exchange(session: Session, exchange_id: int, update_data: Mapping[str, Any] | BaseModel) -> models.Exchanges: exchange: models.Exchanges | None = session.get(models.Exchanges, exchange_id) if exchange is None: raise ValueError("exchange_id does not exist") - if hasattr(update_data, "dict"): - data = update_data.dict(exclude_unset=True) - else: - data = dict(update_data) - allowed = {c.name for c in models.Exchanges.__table__.columns} + data = _data_to_dict(update_data) + allowed = _allowed_columns(models.Exchanges) for k, v in data.items(): if k in IMMUTABLE_EXCHANGE_FIELDS: raise ValueError(f"field {k!r} is immutable") @@ -334,12 +361,9 @@ def delete_exchange(session: Session, exchange_id: int) -> None: IMMUTABLE_USER_FIELDS = {"id", "username", "created_at"} -def create_user(session: Session, user_data: Mapping) -> models.Users: - if hasattr(user_data, "dict"): - data = user_data.dict(exclude_unset=True) - else: - data = dict(user_data) - allowed = {c.name for c in models.Users.__table__.columns} +def create_user(session: Session, user_data: Mapping[str, Any] | BaseModel) -> models.Users: + data = _data_to_dict(user_data) + allowed = _allowed_columns(models.Users) payload = {k: v for k, v in data.items() if k in allowed} if "username" not in payload: raise ValueError("username is required") @@ -368,15 +392,12 @@ def get_user_by_username(session: Session, username: str) -> models.Users | None return session.exec(statement).first() -def update_user(session: Session, user_id: int, update_data: Mapping) -> models.Users: +def update_user(session: Session, user_id: int, update_data: Mapping[str, Any] | BaseModel) -> models.Users: user: models.Users | None = session.get(models.Users, user_id) if user is None: raise ValueError("user_id does not exist") - if hasattr(update_data, "dict"): - data = update_data.dict(exclude_unset=True) - else: - data = dict(update_data) - allowed = {c.name for c in models.Users.__table__.columns} + data = _data_to_dict(update_data) + allowed = _allowed_columns(models.Users) for k, v in data.items(): if k in IMMUTABLE_USER_FIELDS: raise ValueError(f"field {k!r} is immutable") @@ -405,10 +426,11 @@ def create_login_session( user: models.Users | None = session.get(models.Users, user_id) if user is None: raise ValueError("user_id does not exist") + user_id_val = cast("int", user.id) now = datetime.now(timezone.utc) expires_at = now + timedelta(seconds=session_length_seconds) s = models.Sessions( - user_id=user.id, + user_id=user_id_val, session_token_hash=session_token_hash, created_at=now, expires_at=expires_at, @@ -449,7 +471,7 @@ def get_login_session_by_token_hash(session: Session, session_token_hash: str) - IMMUTABLE_SESSION_FIELDS = {"id", "user_id", "session_token_hash", "created_at"} -def update_login_session(session: Session, session_token_hashed: str, update_session: Mapping) -> models.Sessions | None: +def update_login_session(session: Session, session_token_hashed: str, update_session: Mapping[str, Any] | BaseModel) -> models.Sessions | None: login_session: models.Sessions | None = session.exec( select(models.Sessions).where( models.Sessions.session_token_hash == session_token_hashed, @@ -458,11 +480,8 @@ def update_login_session(session: Session, session_token_hashed: str, update_ses ).first() if login_session is None: return None - if hasattr(update_session, "dict"): - data = update_session.dict(exclude_unset=True) - else: - data = dict(update_session) - allowed = {c.name for c in models.Sessions.__table__.columns} + data = _data_to_dict(update_session) + allowed = _allowed_columns(models.Sessions) for k, v in data.items(): if k in allowed and k not in IMMUTABLE_SESSION_FIELDS: setattr(login_session, k, v) diff --git a/backend/trading_journal/db_migration.py b/backend/trading_journal/db_migration.py index e3766a2..8a63250 100644 --- a/backend/trading_journal/db_migration.py +++ b/backend/trading_journal/db_migration.py @@ -23,11 +23,11 @@ def _mig_0_1(engine: Engine) -> None: SQLModel.metadata.create_all( bind=engine, tables=[ - models_v1.Trades.__table__, - models_v1.Cycles.__table__, - models_v1.Users.__table__, - models_v1.Sessions.__table__, - models_v1.Exchanges.__table__, + models_v1.Trades.__table__, # type: ignore[attr-defined] + models_v1.Cycles.__table__, # type: ignore[attr-defined] + models_v1.Users.__table__, # type: ignore[attr-defined] + models_v1.Sessions.__table__, # type: ignore[attr-defined] + models_v1.Exchanges.__table__, # type: ignore[attr-defined] ], ) diff --git a/backend/trading_journal/models.py b/backend/trading_journal/models.py index 9bdc57a..a7d364d 100644 --- a/backend/trading_journal/models.py +++ b/backend/trading_journal/models.py @@ -64,7 +64,7 @@ class FundingSource(str, Enum): class Trades(SQLModel, table=True): - __tablename__ = "trades" + __tablename__ = "trades" # type: ignore[attr-defined] __table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_trades_user_friendly_name"),) id: int | None = Field(default=None, primary_key=True) @@ -95,7 +95,7 @@ class Trades(SQLModel, table=True): class Cycles(SQLModel, table=True): - __tablename__ = "cycles" + __tablename__ = "cycles" # type: ignore[attr-defined] __table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_cycles_user_friendly_name"),) id: int | None = Field(default=None, primary_key=True) @@ -116,7 +116,7 @@ class Cycles(SQLModel, table=True): class Exchanges(SQLModel, table=True): - __tablename__ = "exchanges" + __tablename__ = "exchanges" # type: ignore[attr-defined] __table_args__ = (UniqueConstraint("user_id", "name", name="uq_exchanges_user_name"),) id: int | None = Field(default=None, primary_key=True) user_id: int = Field(foreign_key="users.id", nullable=False, index=True) @@ -128,7 +128,7 @@ class Exchanges(SQLModel, table=True): class Users(SQLModel, table=True): - __tablename__ = "users" + __tablename__ = "users" # type: ignore[attr-defined] id: int | None = Field(default=None, primary_key=True) # unique=True already creates an index; no need to also set index=True username: str = Field(sa_column=Column(Text, nullable=False, unique=True)) @@ -139,7 +139,7 @@ class Users(SQLModel, table=True): class Sessions(SQLModel, table=True): - __tablename__ = "sessions" + __tablename__ = "sessions" # type: ignore[attr-defined] id: int | None = Field(default=None, primary_key=True) user_id: int = Field(foreign_key="users.id", nullable=False, index=True) session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True)) diff --git a/backend/trading_journal/models_v1.py b/backend/trading_journal/models_v1.py index 9bdc57a..a7d364d 100644 --- a/backend/trading_journal/models_v1.py +++ b/backend/trading_journal/models_v1.py @@ -64,7 +64,7 @@ class FundingSource(str, Enum): class Trades(SQLModel, table=True): - __tablename__ = "trades" + __tablename__ = "trades" # type: ignore[attr-defined] __table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_trades_user_friendly_name"),) id: int | None = Field(default=None, primary_key=True) @@ -95,7 +95,7 @@ class Trades(SQLModel, table=True): class Cycles(SQLModel, table=True): - __tablename__ = "cycles" + __tablename__ = "cycles" # type: ignore[attr-defined] __table_args__ = (UniqueConstraint("user_id", "friendly_name", name="uq_cycles_user_friendly_name"),) id: int | None = Field(default=None, primary_key=True) @@ -116,7 +116,7 @@ class Cycles(SQLModel, table=True): class Exchanges(SQLModel, table=True): - __tablename__ = "exchanges" + __tablename__ = "exchanges" # type: ignore[attr-defined] __table_args__ = (UniqueConstraint("user_id", "name", name="uq_exchanges_user_name"),) id: int | None = Field(default=None, primary_key=True) user_id: int = Field(foreign_key="users.id", nullable=False, index=True) @@ -128,7 +128,7 @@ class Exchanges(SQLModel, table=True): class Users(SQLModel, table=True): - __tablename__ = "users" + __tablename__ = "users" # type: ignore[attr-defined] id: int | None = Field(default=None, primary_key=True) # unique=True already creates an index; no need to also set index=True username: str = Field(sa_column=Column(Text, nullable=False, unique=True)) @@ -139,7 +139,7 @@ class Users(SQLModel, table=True): class Sessions(SQLModel, table=True): - __tablename__ = "sessions" + __tablename__ = "sessions" # type: ignore[attr-defined] id: int | None = Field(default=None, primary_key=True) user_id: int = Field(foreign_key="users.id", nullable=False, index=True) session_token_hash: str = Field(sa_column=Column(Text, nullable=False, unique=True)) diff --git a/backend/trading_journal/service.py b/backend/trading_journal/service.py index 881fe53..a9ba8c5 100644 --- a/backend/trading_journal/service.py +++ b/backend/trading_journal/service.py @@ -2,21 +2,24 @@ from __future__ import annotations import logging from datetime import datetime, timedelta, timezone -from typing import Callable +from typing import TYPE_CHECKING, cast from fastapi import Request, Response, status from fastapi.responses import JSONResponse -from sqlmodel import Session -from starlette.middleware.base import BaseHTTPMiddleware +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint import settings from trading_journal import crud, security -from trading_journal.db import Database from trading_journal.dto import ExchangesBase, ExchangesCreate, SessionsCreate, SessionsUpdate, UserCreate, UserLogin, UserRead -from trading_journal.models import Sessions SessionsCreate.model_rebuild() +if TYPE_CHECKING: + from sqlmodel import Session + + from trading_journal.db import Database + from trading_journal.models import Sessions + EXCEPT_PATHS = [ f"{settings.settings.api_base}/status", @@ -28,7 +31,7 @@ logger = logging.getLogger(__name__) class AuthMiddleWare(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next: Callable[[Request], Response]) -> Response: # noqa: PLR0911 + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: # noqa: PLR0911 if request.url.path in EXCEPT_PATHS: return await call_next(request) @@ -51,12 +54,12 @@ class AuthMiddleWare(BaseHTTPMiddleware): with db_factory.get_session_ctx_manager() as request_session: hashed_token = security.hash_session_token_sha256(token) request.state.db_session = request_session - login_session: Sessions | None = crud.get_login_session_by_token_hash(request.state.db_session, hashed_token) + login_session: Sessions | None = crud.get_login_session_by_token_hash(request_session, hashed_token) if not login_session: return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"}) session_expires_utc = login_session.expires_at.replace(tzinfo=timezone.utc) if session_expires_utc < datetime.now(timezone.utc): - crud.delete_login_session(request.state.db_session, login_session) + crud.delete_login_session(request_session, login_session.session_token_hash) return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"}) if login_session.user.is_active is False: return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Unauthorized"}) @@ -72,7 +75,7 @@ class AuthMiddleWare(BaseHTTPMiddleware): ) user_id = login_session.user_id request.state.user_id = user_id - crud.update_login_session(request.state.db_session, hashed_token, update_session=updated_session) + crud.update_login_session(request_session, hashed_token, update_session=updated_session) except Exception: logger.exception("Failed to authenticate user: \n") return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "Internal server error"}) @@ -106,7 +109,7 @@ def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead: # prefer pydantic's from_orm if DTO supports orm_mode user = UserRead.model_validate(user) except Exception as e: - logger.exception("Failed to convert user to UserRead: %s", e) + logger.exception("Failed to convert user to UserRead: ") raise ServiceError("Failed to convert user to UserRead") from e except Exception as e: logger.exception("Failed to create user:") @@ -118,6 +121,7 @@ def authenticate_user_service(db_session: Session, user_in: UserLogin) -> tuple[ user = crud.get_user_by_username(db_session, user_in.username) if not user: return None + user_id_val = cast("int", user.id) if not security.verify_password(user_in.password, user.password_hash): return None @@ -127,7 +131,7 @@ def authenticate_user_service(db_session: Session, user_in: UserLogin) -> tuple[ try: session = crud.create_login_session( session=db_session, - user_id=user.id, + user_id=user_id_val, session_token_hash=token_hashed, session_length_seconds=settings.settings.session_expiry_seconds, ) -- 2.49.1 From a6592bd1408a10839c966b9442ea2d8347b8f005 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Tue, 23 Sep 2025 23:35:15 +0200 Subject: [PATCH 10/18] wip --- backend/app.py | 44 ++++++++++-- backend/trading_journal/dto.py | 111 +++++++++++++++++------------ backend/trading_journal/service.py | 58 +++++++++++++-- 3 files changed, 160 insertions(+), 53 deletions(-) diff --git a/backend/app.py b/backend/app.py index 4c6f61d..aae91b7 100644 --- a/backend/app.py +++ b/backend/app.py @@ -7,11 +7,12 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING from fastapi import FastAPI, HTTPException, Request, status +from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse, Response import settings from trading_journal import db, service -from trading_journal.dto import ExchangesBase, SessionsBase, SessionsCreate, UserCreate, UserLogin, UserRead +from trading_journal.dto import CycleBase, ExchangesBase, ExchangesRead, SessionsBase, SessionsCreate, UserCreate, UserLogin, UserRead if TYPE_CHECKING: from collections.abc import AsyncGenerator @@ -119,10 +120,10 @@ async def create_exchange(request: Request, exchange_data: ExchangesBase) -> Res @app.get(f"{settings.settings.api_base}/exchanges") -async def get_exchanges(request: Request) -> list[ExchangesBase]: +async def get_exchanges(request: Request) -> list[ExchangesRead]: db_factory: Database = request.app.state.db_factory - def sync_work() -> list[ExchangesBase]: + def sync_work() -> list[ExchangesRead]: with db_factory.get_session_ctx_manager() as db: return service.get_exchanges_by_user_service(db, request.state.user_id) @@ -133,4 +134,39 @@ async def get_exchanges(request: Request) -> list[ExchangesBase]: raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e -# Trade +@app.patch(f"{settings.settings.api_base}/exchanges/{{exchange_id}}") +async def update_exchange(request: Request, exchange_id: int, exchange_data: ExchangesBase) -> Response: + db_factory: Database = request.app.state.db_factory + + def sync_work() -> ExchangesBase: + with db_factory.get_session_ctx_manager() as db: + return service.update_exchanges_service(db, request.state.user_id, exchange_id, exchange_data.name, exchange_data.notes) + + try: + exchange = await asyncio.to_thread(sync_work) + return JSONResponse(status_code=status.HTTP_200_OK, content=exchange.model_dump()) + except service.ExchangeNotFoundError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e + except service.ExchangeAlreadyExistsError as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e + except Exception as e: + logger.exception("Failed to update exchange: \n") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e + + +# Cycle +@app.post(f"{settings.settings.api_base}/cycles") +async def create_cycle(request: Request, cycle_data: CycleBase) -> Response: + return JSONResponse(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, content="Not supported.") + db_factory: Database = request.app.state.db_factory + + def sync_work() -> CycleBase: + with db_factory.get_session_ctx_manager() as db: + return service.create_cycle_service(db, request.state.user_id, cycle_data) + + try: + cycle = await asyncio.to_thread(sync_work) + return JSONResponse(status_code=status.HTTP_201_CREATED, content=jsonable_encoder(cycle)) + except Exception as e: + logger.exception("Failed to create cycle: \n") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e diff --git a/backend/trading_journal/dto.py b/backend/trading_journal/dto.py index 5ea709c..7b377f9 100644 --- a/backend/trading_journal/dto.py +++ b/backend/trading_journal/dto.py @@ -1,55 +1,11 @@ from __future__ import annotations from datetime import date, datetime # noqa: TC003 -from typing import TYPE_CHECKING from pydantic import BaseModel from sqlmodel import SQLModel -if TYPE_CHECKING: - from trading_journal.models import TradeStrategy, TradeType, UnderlyingCurrency - - -class ExchangesBase(SQLModel): - name: str - notes: str | None = None - - -class ExchangesCreate(ExchangesBase): - user_id: int - - -class TradeBase(SQLModel): - user_id: int - friendly_name: str | None - symbol: str - exchange: str - underlying_currency: UnderlyingCurrency - trade_type: TradeType - trade_strategy: TradeStrategy - trade_date: date - trade_time_utc: datetime - quantity: int - price_cents: int - gross_cash_flow_cents: int - commission_cents: int - net_cash_flow_cents: int - notes: str | None - cycle_id: int | None = None - - -class TradeCreate(TradeBase): - expiry_date: date | None = None - strike_price_cents: int | None = None - is_invalidated: bool = False - invalidated_at: datetime | None = None - replaced_by_trade_id: int | None = None - - -class TradeRead(TradeBase): - id: int - is_invalidated: bool - invalidated_at: datetime | None +from trading_journal.models import TradeStrategy, TradeType, UnderlyingCurrency # noqa: TC001 class UserBase(SQLModel): @@ -91,3 +47,68 @@ class SessionsUpdate(SQLModel): last_seen_at: datetime | None = None last_used_ip: str | None = None user_agent: str | None = None + + +class ExchangesBase(SQLModel): + name: str + notes: str | None = None + + +class ExchangesCreate(ExchangesBase): + user_id: int + + +class ExchangesRead(ExchangesBase): + id: int + + +class CycleBase(SQLModel): + friendly_name: str | None = None + symbol: str + exchange_id: int + underlying_currency: UnderlyingCurrency + status: str + start_date: date + end_date: date | None = None + funding_source: str | None = None + capital_exposure_cents: int | None = None + loan_amount_cents: int | None = None + loan_interest_rate_bps: int | None = None + trades: list[TradeRead] | None = None + + +class CycleCreate(CycleBase): + user_id: int + + +class TradeBase(SQLModel): + user_id: int + friendly_name: str | None + symbol: str + exchange: str + underlying_currency: UnderlyingCurrency + trade_type: TradeType + trade_strategy: TradeStrategy + trade_date: date + trade_time_utc: datetime + quantity: int + price_cents: int + gross_cash_flow_cents: int + commission_cents: int + net_cash_flow_cents: int + notes: str | None + cycle_id: int | None = None + + +class TradeCreate(TradeBase): + expiry_date: date | None = None + strike_price_cents: int | None = None + is_invalidated: bool = False + invalidated_at: datetime | None = None + replaced_by_trade_id: int | None = None + + +class TradeRead(TradeBase): + id: int + is_invalidated: bool + invalidated_at: datetime | None diff --git a/backend/trading_journal/service.py b/backend/trading_journal/service.py index a9ba8c5..5650515 100644 --- a/backend/trading_journal/service.py +++ b/backend/trading_journal/service.py @@ -10,9 +10,21 @@ from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoin import settings from trading_journal import crud, security -from trading_journal.dto import ExchangesBase, ExchangesCreate, SessionsCreate, SessionsUpdate, UserCreate, UserLogin, UserRead +from trading_journal.dto import ( + CycleBase, + CycleCreate, + ExchangesBase, + ExchangesCreate, + ExchangesRead, + SessionsCreate, + SessionsUpdate, + UserCreate, + UserLogin, + UserRead, +) SessionsCreate.model_rebuild() +CycleBase.model_rebuild() if TYPE_CHECKING: from sqlmodel import Session @@ -95,6 +107,11 @@ class ExchangeAlreadyExistsError(ServiceError): pass +class ExchangeNotFoundError(ServiceError): + pass + + +# User service def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead: if crud.get_user_by_username(db_session, user_in.username): raise UserAlreadyExistsError("username already exists") @@ -156,7 +173,7 @@ def create_exchange_service(db_session: Session, user_id: int, name: str, notes: try: exchange_dto = ExchangesCreate.model_validate(exchange) except Exception as e: - logger.exception("Failed to convert exchange to ExchangesCreate: ") + logger.exception("Failed to convert exchange to ExchangesCreate:") raise ServiceError("Failed to convert exchange to ExchangesCreate") from e except Exception as e: logger.exception("Failed to create exchange:") @@ -164,9 +181,42 @@ def create_exchange_service(db_session: Session, user_id: int, name: str, notes: return exchange_dto -def get_exchanges_by_user_service(db_session: Session, user_id: int) -> list[ExchangesBase]: +def get_exchanges_by_user_service(db_session: Session, user_id: int) -> list[ExchangesRead]: exchanges = crud.get_all_exchanges_by_user_id(db_session, user_id) - return [ExchangesBase.model_validate(exchange) for exchange in exchanges] + return [ExchangesRead.model_validate(exchange) for exchange in exchanges] + + +def update_exchanges_service(db_session: Session, user_id: int, exchange_id: int, name: str | None, notes: str | None) -> ExchangesBase: + existing_exchange = crud.get_exchange_by_id(db_session, exchange_id) + if not existing_exchange: + raise ExchangeNotFoundError("Exchange not found") + if existing_exchange.user_id != user_id: + raise ExchangeNotFoundError("Exchange not found") + + if name: + other_exchange = crud.get_exchange_by_name_and_user_id(db_session, name, user_id) + if other_exchange and other_exchange.id != existing_exchange.id: + raise ExchangeAlreadyExistsError("Another exchange with the same name already exists for this user") + + exchange_data = ExchangesBase( + name=name or existing_exchange.name, + notes=notes or existing_exchange.notes, + ) + try: + exchange = crud.update_exchange(db_session, cast("int", existing_exchange.id), update_data=exchange_data) + except Exception as e: + logger.exception("Failed to update exchange: \n") + raise ServiceError("Failed to update exchange") from e + return ExchangesBase.model_validate(exchange) + + +# Cycle Service +def create_cycle_service(db_session: Session, user_id: int, cycle_data: CycleBase) -> CycleBase: + cycle_data_dict = cycle_data.model_dump() + cycle_data_dict["user_id"] = user_id + cycle_data_with_user_id: CycleCreate = CycleCreate.model_validate(cycle_data_dict) + crud.create_cycle(db_session, cycle_data=cycle_data_with_user_id) + return cycle_data def get_trades_service(db_session: Session, user_id: int) -> list: -- 2.49.1 From cf6c8264684574428edb203eeb77ef66b1c7fa63 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Wed, 24 Sep 2025 10:44:32 +0200 Subject: [PATCH 11/18] use utils module --- backend/utils/db_migration.py | 7 +++++++ backend/utils/db_mirgration.py | 13 ------------- 2 files changed, 7 insertions(+), 13 deletions(-) create mode 100644 backend/utils/db_migration.py delete mode 100644 backend/utils/db_mirgration.py diff --git a/backend/utils/db_migration.py b/backend/utils/db_migration.py new file mode 100644 index 0000000..0fc58f8 --- /dev/null +++ b/backend/utils/db_migration.py @@ -0,0 +1,7 @@ +from sqlmodel import create_engine + +import settings +from trading_journal import db_migration + +db_engine = create_engine(settings.settings.database_url, echo=True) +db_migration.run_migrations(db_engine) diff --git a/backend/utils/db_mirgration.py b/backend/utils/db_mirgration.py deleted file mode 100644 index 1103349..0000000 --- a/backend/utils/db_mirgration.py +++ /dev/null @@ -1,13 +0,0 @@ -import sys -from pathlib import Path - -from sqlmodel import create_engine - -project_parent = Path(__file__).resolve().parents[1] -sys.path.insert(0, str(project_parent)) - -import settings # noqa: E402 -from trading_journal import db_migration # noqa: E402 - -db_engine = create_engine(settings.settings.database_url, echo=True) -db_migration.run_migrations(db_engine) -- 2.49.1 From 80fc405bf6831d94be95d7c0817371f20b46124e Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Wed, 24 Sep 2025 17:33:27 +0200 Subject: [PATCH 12/18] Almost finish basic functionalities --- backend/.vscode/launch.json | 10 +- backend/app.py | 148 ++++++++++++++++++++++++++- backend/tests/test_crud.py | 94 +++++++++++++++++ backend/tests/test_db_migration.py | 3 +- backend/trading_journal/crud.py | 35 +++++-- backend/trading_journal/dto.py | 50 ++++++--- backend/trading_journal/models.py | 3 +- backend/trading_journal/models_v1.py | 3 +- backend/trading_journal/service.py | 144 ++++++++++++++++++++++++-- 9 files changed, 455 insertions(+), 35 deletions(-) diff --git a/backend/.vscode/launch.json b/backend/.vscode/launch.json index 93b32ee..929a971 100644 --- a/backend/.vscode/launch.json +++ b/backend/.vscode/launch.json @@ -13,10 +13,14 @@ "app:app", "--host=0.0.0.0", "--reload", - "--port=5000" + "--port=18881" ], "jinja": true, - "autoStartBrowser": true + "autoStartBrowser": false, + "env": { + "CONFIG_FILE": "devsettings.yaml" + }, + "console": "integratedTerminal" } ] -} +} \ No newline at end of file diff --git a/backend/app.py b/backend/app.py index aae91b7..e0e2799 100644 --- a/backend/app.py +++ b/backend/app.py @@ -12,7 +12,22 @@ from fastapi.responses import JSONResponse, Response import settings from trading_journal import db, service -from trading_journal.dto import CycleBase, ExchangesBase, ExchangesRead, SessionsBase, SessionsCreate, UserCreate, UserLogin, UserRead +from trading_journal.dto import ( + CycleBase, + CycleRead, + CycleUpdate, + ExchangesBase, + ExchangesRead, + SessionsBase, + SessionsCreate, + TradeCreate, + TradeFriendlyNameUpdate, + TradeNoteUpdate, + TradeRead, + UserCreate, + UserLogin, + UserRead, +) if TYPE_CHECKING: from collections.abc import AsyncGenerator @@ -170,3 +185,134 @@ async def create_cycle(request: Request, cycle_data: CycleBase) -> Response: except Exception as e: logger.exception("Failed to create cycle: \n") raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e + + +@app.get(f"{settings.settings.api_base}/cycles/{{cycle_id}}") +async def get_cycle_by_id(request: Request, cycle_id: int) -> Response: + db_factory: Database = request.app.state.db_factory + + def sync_work() -> CycleBase: + with db_factory.get_session_ctx_manager() as db: + return service.get_cycle_by_id_service(db, request.state.user_id, cycle_id) + + try: + cycle = await asyncio.to_thread(sync_work) + return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(cycle)) + except service.CycleNotFoundError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e + except Exception as e: + logger.exception("Failed to get cycle by id: \n") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e + + +@app.get(f"{settings.settings.api_base}/cycles/user/{{user_id}}") +async def get_cycles_by_user(request: Request, user_id: int) -> Response: + db_factory: Database = request.app.state.db_factory + + def sync_work() -> list[CycleRead]: + with db_factory.get_session_ctx_manager() as db: + return service.get_cycles_by_user_service(db, user_id) + + try: + cycles = await asyncio.to_thread(sync_work) + return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(cycles)) + except Exception as e: + logger.exception("Failed to get cycles by user: \n") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e + + +@app.patch(f"{settings.settings.api_base}/cycles") +async def update_cycle(request: Request, cycle_data: CycleUpdate) -> Response: + db_factory: Database = request.app.state.db_factory + + def sync_work() -> CycleRead: + with db_factory.get_session_ctx_manager() as db: + return service.update_cycle_service(db, request.state.user_id, cycle_data) + + try: + cycle = await asyncio.to_thread(sync_work) + return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(cycle)) + except service.InvalidCycleDataError as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e + except service.CycleNotFoundError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e + except Exception as e: + logger.exception("Failed to update cycle: \n") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e + + +@app.post(f"{settings.settings.api_base}/trades") +async def create_trade(request: Request, trade_data: TradeCreate) -> Response: + db_factory: Database = request.app.state.db_factory + + def sync_work() -> TradeRead: + with db_factory.get_session_ctx_manager() as db: + return service.create_trade_service(db, request.state.user_id, trade_data) + + try: + trade = await asyncio.to_thread(sync_work) + return JSONResponse(status_code=status.HTTP_201_CREATED, content=jsonable_encoder(trade)) + except service.InvalidTradeDataError as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e + except Exception as e: + logger.exception("Failed to create trade: \n") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e + + +@app.get(f"{settings.settings.api_base}/trades/{{trade_id}}") +async def get_trade_by_id(request: Request, trade_id: int) -> Response: + db_factory: Database = request.app.state.db_factory + + def sync_work() -> TradeRead: + with db_factory.get_session_ctx_manager() as db: + return service.get_trade_by_id_service(db, request.state.user_id, trade_id) + + try: + trade = await asyncio.to_thread(sync_work) + return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(trade)) + except service.TradeNotFoundError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e + except Exception as e: + logger.exception("Failed to get trade by id: \n") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e + + +@app.patch(f"{settings.settings.api_base}/trades/friendlyname") +async def update_trade_friendly_name(request: Request, friendly_name_update: TradeFriendlyNameUpdate) -> Response: + db_factory: Database = request.app.state.db_factory + + def sync_work() -> TradeRead: + with db_factory.get_session_ctx_manager() as db: + return service.update_trade_friendly_name_service( + db, + request.state.user_id, + friendly_name_update.id, + friendly_name_update.friendly_name, + ) + + try: + trade = await asyncio.to_thread(sync_work) + return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(trade)) + except service.TradeNotFoundError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e + except Exception as e: + logger.exception("Failed to update trade friendly name: \n") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e + + +@app.patch(f"{settings.settings.api_base}/trades/notes") +async def update_trade_note(request: Request, note_update: TradeNoteUpdate) -> Response: + db_factory: Database = request.app.state.db_factory + + def sync_work() -> TradeRead: + with db_factory.get_session_ctx_manager() as db: + return service.update_trade_note_service(db, request.state.user_id, note_update.id, note_update.notes) + + try: + trade = await asyncio.to_thread(sync_work) + return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder(trade)) + except service.TradeNotFoundError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e + except Exception as e: + logger.exception("Failed to update trade note: \n") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e diff --git a/backend/tests/test_crud.py b/backend/tests/test_crud.py index 0620fed..d7ae484 100644 --- a/backend/tests/test_crud.py +++ b/backend/tests/test_crud.py @@ -172,6 +172,51 @@ def test_create_trade_success_with_cycle(session: Session) -> None: assert actual_trade.trade_type == trade_data["trade_type"] assert actual_trade.trade_strategy == trade_data["trade_strategy"] assert actual_trade.quantity == trade_data["quantity"] + assert actual_trade.quantity_multiplier == 1 + assert actual_trade.price_cents == trade_data["price_cents"] + assert actual_trade.gross_cash_flow_cents == trade_data["gross_cash_flow_cents"] + assert actual_trade.commission_cents == trade_data["commission_cents"] + assert actual_trade.net_cash_flow_cents == trade_data["net_cash_flow_cents"] + assert actual_trade.cycle_id == trade_data["cycle_id"] + + +def test_create_trade_with_custom_multipler(session: Session) -> None: + user_id = make_user(session) + exchange_id = make_exchange(session, user_id) + cycle_id = make_cycle(session, user_id, exchange_id) + + trade_data = { + "user_id": user_id, + "friendly_name": "Test Trade with Multiplier", + "symbol": "AAPL", + "underlying_currency": models.UnderlyingCurrency.USD, + "trade_type": models.TradeType.LONG_SPOT, + "trade_strategy": models.TradeStrategy.SPOT, + "trade_time_utc": datetime.now(timezone.utc), + "quantity": 10, + "quantity_multiplier": 100, + "price_cents": 15000, + "gross_cash_flow_cents": -1500000, + "commission_cents": 50000, + "net_cash_flow_cents": -1550000, + "cycle_id": cycle_id, + } + + trade = crud.create_trade(session, trade_data) + assert trade.id is not None + assert trade.user_id == user_id + assert trade.cycle_id == cycle_id + session.refresh(trade) + + actual_trade = session.get(models.Trades, trade.id) + assert actual_trade is not None + assert actual_trade.friendly_name == trade_data["friendly_name"] + assert actual_trade.symbol == trade_data["symbol"] + assert actual_trade.underlying_currency == trade_data["underlying_currency"] + assert actual_trade.trade_type == trade_data["trade_type"] + assert actual_trade.trade_strategy == trade_data["trade_strategy"] + assert actual_trade.quantity == trade_data["quantity"] + assert actual_trade.quantity_multiplier == trade_data["quantity_multiplier"] assert actual_trade.price_cents == trade_data["price_cents"] assert actual_trade.gross_cash_flow_cents == trade_data["gross_cash_flow_cents"] assert actual_trade.commission_cents == trade_data["commission_cents"] @@ -194,6 +239,9 @@ def test_create_trade_with_auto_created_cycle(session: Session) -> None: "trade_time_utc": datetime.now(timezone.utc), "quantity": 5, "price_cents": 15500, + "gross_cash_flow_cents": -77500, + "commission_cents": 300, + "net_cash_flow_cents": -77800, } trade = crud.create_trade(session, trade_data) @@ -405,6 +453,24 @@ def test_get_trades_by_user_id(session: Session) -> None: assert friendly_names == {"Trade One", "Trade Two"} +def test_update_trade_friendly_name(session: Session) -> None: + user_id = make_user(session) + exchange_id = make_exchange(session, user_id) + cycle_id = make_cycle(session, user_id, exchange_id) + trade_id = make_trade(session, user_id, cycle_id) + + new_friendly_name = "Updated Trade Name" + updated_trade = crud.update_trade_friendly_name(session, trade_id, new_friendly_name) + assert updated_trade is not None + assert updated_trade.id == trade_id + assert updated_trade.friendly_name == new_friendly_name + + session.refresh(updated_trade) + actual_trade = session.get(models.Trades, trade_id) + assert actual_trade is not None + assert actual_trade.friendly_name == new_friendly_name + + def test_update_trade_note(session: Session) -> None: user_id = make_user(session) exchange_id = make_exchange(session, user_id) @@ -457,6 +523,9 @@ def test_replace_trade(session: Session) -> None: "trade_time_utc": datetime.now(timezone.utc), "quantity": 20, "price_cents": 25000, + "gross_cash_flow_cents": -500000, + "commission_cents": 1000, + "net_cash_flow_cents": -501000, } new_trade = crud.replace_trade(session, old_trade_id, new_trade_data) @@ -516,6 +585,31 @@ def test_create_cycle(session: Session) -> None: assert actual_cycle.start_date == cycle_data["start_date"] +def test_get_cycle_by_id(session: Session) -> None: + user_id = make_user(session) + exchange_id = make_exchange(session, user_id) + cycle_id = make_cycle(session, user_id, exchange_id, friendly_name="Cycle to Get") + cycle = crud.get_cycle_by_id(session, cycle_id) + assert cycle is not None + assert cycle.id == cycle_id + assert cycle.friendly_name == "Cycle to Get" + assert cycle.user_id == user_id + + +def test_get_cycles_by_user_id(session: Session) -> None: + user_id = make_user(session) + exchange_id = make_exchange(session, user_id) + cycle_names = ["Cycle One", "Cycle Two", "Cycle Three"] + for name in cycle_names: + make_cycle(session, user_id, exchange_id, friendly_name=name) + + cycles = crud.get_cycles_by_user_id(session, user_id) + assert len(cycles) == len(cycle_names) + fetched_names = {cycle.friendly_name for cycle in cycles} + for name in cycle_names: + assert name in fetched_names + + def test_update_cycle(session: Session) -> None: user_id = make_user(session) exchange_id = make_exchange(session, user_id) diff --git a/backend/tests/test_db_migration.py b/backend/tests/test_db_migration.py index 343214b..042bb54 100644 --- a/backend/tests/test_db_migration.py +++ b/backend/tests/test_db_migration.py @@ -42,7 +42,7 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None: "funding_source": ("TEXT", 0, 0), "capital_exposure_cents": ("INTEGER", 0, 0), "loan_amount_cents": ("INTEGER", 0, 0), - "loan_interest_rate_bps": ("INTEGER", 0, 0), + "loan_interest_rate_tenth_bps": ("INTEGER", 0, 0), "start_date": ("DATE", 1, 0), "end_date": ("DATE", 0, 0), }, @@ -60,6 +60,7 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None: "expiry_date": ("DATE", 0, 0), "strike_price_cents": ("INTEGER", 0, 0), "quantity": ("INTEGER", 1, 0), + "quantity_multiplier": ("INTEGER", 1, 0), "price_cents": ("INTEGER", 1, 0), "gross_cash_flow_cents": ("INTEGER", 1, 0), "commission_cents": ("INTEGER", 1, 0), diff --git a/backend/trading_journal/crud.py b/backend/trading_journal/crud.py index 1918b5d..da8e237 100644 --- a/backend/trading_journal/crud.py +++ b/backend/trading_journal/crud.py @@ -90,13 +90,10 @@ def create_trade(session: Session, trade_data: Mapping[str, Any] | BaseModel) -> raise ValueError("price_cents is required") if "commission_cents" not in payload: payload["commission_cents"] = 0 - quantity: int = payload["quantity"] - price_cents: int = payload["price_cents"] - commission_cents: int = payload["commission_cents"] if "gross_cash_flow_cents" not in payload: - payload["gross_cash_flow_cents"] = -quantity * price_cents + raise ValueError("gross_cash_flow_cents is required") if "net_cash_flow_cents" not in payload: - payload["net_cash_flow_cents"] = payload["gross_cash_flow_cents"] - commission_cents + raise ValueError("net_cash_flow_cents is required") # If no cycle_id provided, create Cycle instance but don't call create_cycle() created_cycle = None @@ -163,6 +160,21 @@ def get_trades_by_user_id(session: Session, user_id: int) -> list[models.Trades] return list(session.exec(statement).all()) +def update_trade_friendly_name(session: Session, trade_id: int, friendly_name: str) -> models.Trades: + trade: models.Trades | None = session.get(models.Trades, trade_id) + if trade is None: + raise ValueError("trade_id does not exist") + trade.friendly_name = friendly_name + session.add(trade) + try: + session.flush() + except IntegrityError as e: + session.rollback() + raise ValueError("update_trade_friendly_name integrity error") from e + session.refresh(trade) + return trade + + def update_trade_note(session: Session, trade_id: int, note: str) -> models.Trades: trade: models.Trades | None = session.get(models.Trades, trade_id) if trade is None: @@ -240,7 +252,18 @@ def create_cycle(session: Session, cycle_data: Mapping[str, Any] | BaseModel) -> return c -IMMUTABLE_CYCLE_FIELDS = {"id", "user_id", "start_date", "created_at"} +IMMUTABLE_CYCLE_FIELDS = {"id", "user_id", "start_date"} + + +def get_cycle_by_id(session: Session, cycle_id: int) -> models.Cycles | None: + return session.get(models.Cycles, cycle_id) + + +def get_cycles_by_user_id(session: Session, user_id: int) -> list[models.Cycles]: + statement = select(models.Cycles).where( + models.Cycles.user_id == user_id, + ) + return list(session.exec(statement).all()) def update_cycle(session: Session, cycle_id: int, update_data: Mapping[str, Any] | BaseModel) -> models.Cycles: diff --git a/backend/trading_journal/dto.py b/backend/trading_journal/dto.py index 7b377f9..1851b0d 100644 --- a/backend/trading_journal/dto.py +++ b/backend/trading_journal/dto.py @@ -64,43 +64,53 @@ class ExchangesRead(ExchangesBase): class CycleBase(SQLModel): friendly_name: str | None = None - symbol: str - exchange_id: int - underlying_currency: UnderlyingCurrency status: str - start_date: date end_date: date | None = None funding_source: str | None = None capital_exposure_cents: int | None = None loan_amount_cents: int | None = None loan_interest_rate_bps: int | None = None trades: list[TradeRead] | None = None + exchange: ExchangesRead | None = None class CycleCreate(CycleBase): user_id: int + symbol: str + exchange_id: int + underlying_currency: UnderlyingCurrency + start_date: date + + +class CycleUpdate(CycleBase): + id: int + + +class CycleRead(CycleCreate): + id: int class TradeBase(SQLModel): - user_id: int - friendly_name: str | None + friendly_name: str | None = None symbol: str - exchange: str + exchange_id: int underlying_currency: UnderlyingCurrency trade_type: TradeType trade_strategy: TradeStrategy trade_date: date - trade_time_utc: datetime quantity: int price_cents: int - gross_cash_flow_cents: int commission_cents: int - net_cash_flow_cents: int - notes: str | None + notes: str | None = None cycle_id: int | None = None class TradeCreate(TradeBase): + user_id: int | None = None + trade_time_utc: datetime | None = None + gross_cash_flow_cents: int | None = None + net_cash_flow_cents: int | None = None + quantity_multiplier: int = 1 expiry_date: date | None = None strike_price_cents: int | None = None is_invalidated: bool = False @@ -108,7 +118,19 @@ class TradeCreate(TradeBase): replaced_by_trade_id: int | None = None -class TradeRead(TradeBase): +class TradeNoteUpdate(BaseModel): id: int - is_invalidated: bool - invalidated_at: datetime | None + notes: str | None = None + + +class TradeFriendlyNameUpdate(BaseModel): + id: int + friendly_name: str + + +class TradeRead(TradeCreate): + id: int + + +SessionsCreate.model_rebuild() +CycleBase.model_rebuild() diff --git a/backend/trading_journal/models.py b/backend/trading_journal/models.py index a7d364d..e8dc281 100644 --- a/backend/trading_journal/models.py +++ b/backend/trading_journal/models.py @@ -82,6 +82,7 @@ class Trades(SQLModel, table=True): expiry_date: date | None = Field(default=None, nullable=True) strike_price_cents: int | None = Field(default=None, nullable=True) quantity: int = Field(sa_column=Column(Integer, nullable=False)) + quantity_multiplier: int = Field(sa_column=Column(Integer, nullable=False), default=1) price_cents: int = Field(sa_column=Column(Integer, nullable=False)) gross_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False)) commission_cents: int = Field(sa_column=Column(Integer, nullable=False)) @@ -109,7 +110,7 @@ class Cycles(SQLModel, table=True): funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True)) capital_exposure_cents: int | None = Field(default=None, nullable=True) loan_amount_cents: int | None = Field(default=None, nullable=True) - loan_interest_rate_bps: int | None = Field(default=None, nullable=True) + loan_interest_rate_tenth_bps: int | None = Field(default=None, nullable=True) start_date: date = Field(sa_column=Column(Date, nullable=False)) end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True)) trades: list["Trades"] = Relationship(back_populates="cycle") diff --git a/backend/trading_journal/models_v1.py b/backend/trading_journal/models_v1.py index a7d364d..e8dc281 100644 --- a/backend/trading_journal/models_v1.py +++ b/backend/trading_journal/models_v1.py @@ -82,6 +82,7 @@ class Trades(SQLModel, table=True): expiry_date: date | None = Field(default=None, nullable=True) strike_price_cents: int | None = Field(default=None, nullable=True) quantity: int = Field(sa_column=Column(Integer, nullable=False)) + quantity_multiplier: int = Field(sa_column=Column(Integer, nullable=False), default=1) price_cents: int = Field(sa_column=Column(Integer, nullable=False)) gross_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False)) commission_cents: int = Field(sa_column=Column(Integer, nullable=False)) @@ -109,7 +110,7 @@ class Cycles(SQLModel, table=True): funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True)) capital_exposure_cents: int | None = Field(default=None, nullable=True) loan_amount_cents: int | None = Field(default=None, nullable=True) - loan_interest_rate_bps: int | None = Field(default=None, nullable=True) + loan_interest_rate_tenth_bps: int | None = Field(default=None, nullable=True) start_date: date = Field(sa_column=Column(Date, nullable=False)) end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True)) trades: list["Trades"] = Relationship(back_populates="cycle") diff --git a/backend/trading_journal/service.py b/backend/trading_journal/service.py index 5650515..07a3b19 100644 --- a/backend/trading_journal/service.py +++ b/backend/trading_journal/service.py @@ -13,19 +13,20 @@ from trading_journal import crud, security from trading_journal.dto import ( CycleBase, CycleCreate, + CycleRead, + CycleUpdate, ExchangesBase, ExchangesCreate, ExchangesRead, SessionsCreate, SessionsUpdate, + TradeCreate, + TradeRead, UserCreate, UserLogin, UserRead, ) -SessionsCreate.model_rebuild() -CycleBase.model_rebuild() - if TYPE_CHECKING: from sqlmodel import Session @@ -111,6 +112,22 @@ class ExchangeNotFoundError(ServiceError): pass +class CycleNotFoundError(ServiceError): + pass + + +class TradeNotFoundError(ServiceError): + pass + + +class InvalidTradeDataError(ServiceError): + pass + + +class InvalidCycleDataError(ServiceError): + pass + + # User service def register_user_service(db_session: Session, user_in: UserCreate) -> UserRead: if crud.get_user_by_username(db_session, user_in.username): @@ -211,13 +228,124 @@ def update_exchanges_service(db_session: Session, user_id: int, exchange_id: int # Cycle Service -def create_cycle_service(db_session: Session, user_id: int, cycle_data: CycleBase) -> CycleBase: +def create_cycle_service(db_session: Session, user_id: int, cycle_data: CycleBase) -> CycleRead: cycle_data_dict = cycle_data.model_dump() cycle_data_dict["user_id"] = user_id cycle_data_with_user_id: CycleCreate = CycleCreate.model_validate(cycle_data_dict) - crud.create_cycle(db_session, cycle_data=cycle_data_with_user_id) - return cycle_data + created_cycle = crud.create_cycle(db_session, cycle_data=cycle_data_with_user_id) + return CycleRead.model_validate(created_cycle) -def get_trades_service(db_session: Session, user_id: int) -> list: - return crud.get_trades_by_user_id(db_session, user_id) +def get_cycle_by_id_service(db_session: Session, user_id: int, cycle_id: int) -> CycleRead: + cycle = crud.get_cycle_by_id(db_session, cycle_id) + if not cycle: + raise CycleNotFoundError("Cycle not found") + if cycle.user_id != user_id: + raise CycleNotFoundError("Cycle not found") + return CycleRead.model_validate(cycle) + + +def get_cycles_by_user_service(db_session: Session, user_id: int) -> list[CycleRead]: + cycles = crud.get_cycles_by_user_id(db_session, user_id) + return [CycleRead.model_validate(cycle) for cycle in cycles] + + +def _validate_cycle_update_data(cycle_data: CycleUpdate) -> tuple[bool, str]: + if cycle_data.status == "CLOSED" and cycle_data.end_date is None: + return False, "end_date is required when status is CLOSED" + if cycle_data.status == "OPEN" and cycle_data.end_date is not None: + return False, "end_date must be empty when status is OPEN" + return True, "" + + +def update_cycle_service(db_session: Session, user_id: int, cycle_data: CycleUpdate) -> CycleRead: + is_valid, err_msg = _validate_cycle_update_data(cycle_data) + if not is_valid: + raise InvalidCycleDataError(err_msg) + cycle_id = cast("int", cycle_data.id) + existing_cycle = crud.get_cycle_by_id(db_session, cycle_id) + if not existing_cycle: + raise CycleNotFoundError("Cycle not found") + if existing_cycle.user_id != user_id: + raise CycleNotFoundError("Cycle not found") + + provided_data_dict = cycle_data.model_dump(exclude_unset=True) + cycle_data_with_user_id: CycleBase = CycleBase.model_validate(provided_data_dict) + + try: + updated_cycle = crud.update_cycle(db_session, cycle_id, update_data=cycle_data_with_user_id) + except Exception as e: + logger.exception("Failed to update cycle: \n") + raise ServiceError("Failed to update cycle") from e + return CycleRead.model_validate(updated_cycle) + + +# Trades service +def _append_cashflows(trade_data: TradeCreate) -> TradeCreate: + sign_multipler: int + if trade_data.trade_type in ("SELL_PUT", "SELL_CALL", "EXERCISE_CALL", "CLOSE_LONG_SPOT", "SHORT_SPOT"): + sign_multipler = 1 + else: + sign_multipler = -1 + quantity = trade_data.quantity * trade_data.quantity_multiplier + gross_cash_flow_cents = quantity * trade_data.price_cents * sign_multipler + net_cash_flow_cents = gross_cash_flow_cents - trade_data.commission_cents + trade_data.gross_cash_flow_cents = gross_cash_flow_cents + trade_data.net_cash_flow_cents = net_cash_flow_cents + return trade_data + + +def _validate_trade_data(trade_data: TradeCreate) -> bool: + return not ( + trade_data.trade_type in ("SELL_PUT", "SELL_CALL") and (trade_data.expiry_date is None or trade_data.strike_price_cents is None) + ) + + +def create_trade_service(db_session: Session, user_id: int, trade_data: TradeCreate) -> TradeRead: + if not _validate_trade_data(trade_data): + raise InvalidTradeDataError("Invalid trade data: expiry_date and strike_price_cents are required for SELL_PUT and SELL_CALL trades") + trade_data_dict = trade_data.model_dump() + trade_data_dict["user_id"] = user_id + trade_data_with_user_id: TradeCreate = TradeCreate.model_validate(trade_data_dict) + trade_data_with_user_id = _append_cashflows(trade_data_with_user_id) + created_trade = crud.create_trade(db_session, trade_data=trade_data_with_user_id) + return TradeRead.model_validate(created_trade) + + +def get_trade_by_id_service(db_session: Session, user_id: int, trade_id: int) -> TradeRead: + trade = crud.get_trade_by_id(db_session, trade_id) + if not trade: + raise TradeNotFoundError("Trade not found") + if trade.user_id != user_id: + raise TradeNotFoundError("Trade not found") + return TradeRead.model_validate(trade) + + +def update_trade_friendly_name_service(db_session: Session, user_id: int, trade_id: int, friendly_name: str) -> TradeRead: + existing_trade = crud.get_trade_by_id(db_session, trade_id) + if not existing_trade: + raise TradeNotFoundError("Trade not found") + if existing_trade.user_id != user_id: + raise TradeNotFoundError("Trade not found") + try: + updated_trade = crud.update_trade_friendly_name(db_session, trade_id, friendly_name) + except Exception as e: + logger.exception("Failed to update trade friendly name: \n") + raise ServiceError("Failed to update trade friendly name") from e + return TradeRead.model_validate(updated_trade) + + +def update_trade_note_service(db_session: Session, user_id: int, trade_id: int, note: str | None) -> TradeRead: + existing_trade = crud.get_trade_by_id(db_session, trade_id) + if not existing_trade: + raise TradeNotFoundError("Trade not found") + if existing_trade.user_id != user_id: + raise TradeNotFoundError("Trade not found") + if note is None: + note = "" + try: + updated_trade = crud.update_trade_note(db_session, trade_id, note) + except Exception as e: + logger.exception("Failed to update trade notes: \n") + raise ServiceError("Failed to update trade notes") from e + return TradeRead.model_validate(updated_trade) -- 2.49.1 From e66aab99ea852e46d15ce1002bc0975728420e16 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Wed, 24 Sep 2025 21:02:21 +0200 Subject: [PATCH 13/18] basic api is there --- backend/testhelpers/tradecycles.sh | 56 ++++++++++++++++++++++++++++++ backend/trading_journal/dto.py | 2 +- backend/trading_journal/service.py | 14 +++++++- 3 files changed, 70 insertions(+), 2 deletions(-) create mode 100755 backend/testhelpers/tradecycles.sh diff --git a/backend/testhelpers/tradecycles.sh b/backend/testhelpers/tradecycles.sh new file mode 100755 index 0000000..f2f5831 --- /dev/null +++ b/backend/testhelpers/tradecycles.sh @@ -0,0 +1,56 @@ +curl --location '127.0.0.1:18881/api/v1/trades' \ +--header 'Content-Type: application/json' \ +--header 'Cookie: session_token=uYsEZZdH9ecQ432HQUdfab292I14suk4GuI12-cAyuw' \ +--data '{ + "friendly_name": "20250908-CA-PUT", + "symbol": "CA", + "exchange_id": 1, + "underlying_currency": "EUR", + "trade_type": "SELL_PUT", + "trade_strategy": "WHEEL", + "trade_date": "2025-09-08", + "quantity": 1, + "quantity_multiplier": 100, + "price_cents": 17, + "expiry_date": "2025-09-09", + "strike_price_cents": 1220, + "commission_cents": 114 +}' + +curl --location '127.0.0.1:18881/api/v1/trades' \ +--header 'Content-Type: application/json' \ +--header 'Cookie: session_token=uYsEZZdH9ecQ432HQUdfab292I14suk4GuI12-cAyuw' \ +--data '{ + "friendly_name": "20250920-CA-ASSIGN", + "symbol": "CA", + "exchange_id": 1, + "cycle_id": 1, + "underlying_currency": "EUR", + "trade_type": "ASSIGNMENT", + "trade_strategy": "WHEEL", + "trade_date": "2025-09-20", + "quantity": 100, + "quantity_multiplier": 1, + "price_cents": 1220, + "commission_cents": 0 +}' + +curl --location '127.0.0.1:18881/api/v1/trades' \ +--header 'Content-Type: application/json' \ +--header 'Cookie: session_token=uYsEZZdH9ecQ432HQUdfab292I14suk4GuI12-cAyuw' \ +--data '{ + "friendly_name": "20250923-CA-CALL", + "symbol": "CA", + "exchange_id": 1, + "cycle_id": 1, + "underlying_currency": "EUR", + "trade_type": "SELL_CALL", + "trade_strategy": "WHEEL", + "trade_date": "2025-09-23", + "quantity": 1, + "quantity_multiplier": 100, + "price_cents": 31, + "expiry_date": "2025-10-10", + "strike_price_cents": 1200, + "commission_cents": 114 +}' \ No newline at end of file diff --git a/backend/trading_journal/dto.py b/backend/trading_journal/dto.py index 1851b0d..0dd227f 100644 --- a/backend/trading_journal/dto.py +++ b/backend/trading_journal/dto.py @@ -69,7 +69,7 @@ class CycleBase(SQLModel): funding_source: str | None = None capital_exposure_cents: int | None = None loan_amount_cents: int | None = None - loan_interest_rate_bps: int | None = None + loan_interest_rate_tenth_bps: int | None = None trades: list[TradeRead] | None = None exchange: ExchangesRead | None = None diff --git a/backend/trading_journal/service.py b/backend/trading_journal/service.py index 07a3b19..4855c3e 100644 --- a/backend/trading_journal/service.py +++ b/backend/trading_journal/service.py @@ -250,11 +250,23 @@ def get_cycles_by_user_service(db_session: Session, user_id: int) -> list[CycleR return [CycleRead.model_validate(cycle) for cycle in cycles] -def _validate_cycle_update_data(cycle_data: CycleUpdate) -> tuple[bool, str]: +def _validate_cycle_update_data(cycle_data: CycleUpdate) -> tuple[bool, str]: # noqa: PLR0911 if cycle_data.status == "CLOSED" and cycle_data.end_date is None: return False, "end_date is required when status is CLOSED" if cycle_data.status == "OPEN" and cycle_data.end_date is not None: return False, "end_date must be empty when status is OPEN" + if cycle_data.capital_exposure_cents is not None and cycle_data.capital_exposure_cents < 0: + return False, "capital_exposure_cents must be non-negative" + if ( + cycle_data.funding_source is not None + and cycle_data.funding_source != "CASH" + and (cycle_data.loan_amount_cents is None or cycle_data.loan_interest_rate_tenth_bps is None) + ): + return False, "loan_amount_cents and loan_interest_rate_tenth_bps are required when funding_source is not CASH" + if cycle_data.loan_amount_cents is not None and cycle_data.loan_amount_cents < 0: + return False, "loan_amount_cents must be non-negative" + if cycle_data.loan_interest_rate_tenth_bps is not None and cycle_data.loan_interest_rate_tenth_bps < 0: + return False, "loan_interest_rate_tenth_bps must be non-negative" return True, "" -- 2.49.1 From 27b4adaca4cb40995e9fe0bf3f2386c91c5585c0 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Thu, 25 Sep 2025 12:08:07 +0200 Subject: [PATCH 14/18] add interest change tables --- backend/tests/test_db_migration.py | 24 +++++++++++++++ backend/tests/test_service.py | 5 ++++ backend/trading_journal/db_migration.py | 2 ++ backend/trading_journal/models.py | 40 +++++++++++++++++++++++++ backend/trading_journal/models_v1.py | 40 +++++++++++++++++++++++++ 5 files changed, 111 insertions(+) create mode 100644 backend/tests/test_service.py diff --git a/backend/tests/test_db_migration.py b/backend/tests/test_db_migration.py index 042bb54..f14b2da 100644 --- a/backend/tests/test_db_migration.py +++ b/backend/tests/test_db_migration.py @@ -46,6 +46,23 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None: "start_date": ("DATE", 1, 0), "end_date": ("DATE", 0, 0), }, + "cycle_loan_change_events": { + "id": ("INTEGER", 1, 1), + "cycle_id": ("INTEGER", 1, 0), + "effective_date": ("DATE", 1, 0), + "loan_amount_cents": ("INTEGER", 0, 0), + "loan_interest_rate_tenth_bps": ("INTEGER", 0, 0), + "related_trade_id": ("INTEGER", 0, 0), + "notes": ("TEXT", 0, 0), + "created_at": ("DATETIME", 1, 0), + }, + "cycle_daily_accrual": { + "id": ("INTEGER", 1, 1), + "cycle_id": ("INTEGER", 1, 0), + "accrual_date": ("DATE", 1, 0), + "accrual_amount_cents": ("INTEGER", 1, 0), + "created_at": ("DATETIME", 1, 0), + }, "trades": { "id": ("INTEGER", 1, 1), "user_id": ("INTEGER", 1, 0), @@ -100,6 +117,13 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None: {"table": "users", "from": "user_id", "to": "id"}, {"table": "exchanges", "from": "exchange_id", "to": "id"}, ], + "cycle_loan_change_events": [ + {"table": "cycles", "from": "cycle_id", "to": "id"}, + {"table": "trades", "from": "related_trade_id", "to": "id"}, + ], + "cycle_daily_accrual": [ + {"table": "cycles", "from": "cycle_id", "to": "id"}, + ], "sessions": [ {"table": "users", "from": "user_id", "to": "id"}, ], diff --git a/backend/tests/test_service.py b/backend/tests/test_service.py new file mode 100644 index 0000000..b23e2be --- /dev/null +++ b/backend/tests/test_service.py @@ -0,0 +1,5 @@ +import pytest + +from trading_journal import crud, service + +monkeypatch = pytest.MonkeyPatch() diff --git a/backend/trading_journal/db_migration.py b/backend/trading_journal/db_migration.py index 8a63250..b6a78ea 100644 --- a/backend/trading_journal/db_migration.py +++ b/backend/trading_journal/db_migration.py @@ -28,6 +28,8 @@ def _mig_0_1(engine: Engine) -> None: models_v1.Users.__table__, # type: ignore[attr-defined] models_v1.Sessions.__table__, # type: ignore[attr-defined] models_v1.Exchanges.__table__, # type: ignore[attr-defined] + models_v1.CycleLoanChangeEvents.__table__, # type: ignore[attr-defined] + models_v1.CycleDailyAccrual.__table__, # type: ignore[attr-defined] ], ) diff --git a/backend/trading_journal/models.py b/backend/trading_journal/models.py index e8dc281..81d98e1 100644 --- a/backend/trading_journal/models.py +++ b/backend/trading_journal/models.py @@ -1,11 +1,13 @@ from datetime import date, datetime from enum import Enum +from typing import Optional from sqlmodel import ( Column, Date, DateTime, Field, + ForeignKey, Integer, Relationship, SQLModel, @@ -92,8 +94,14 @@ class Trades(SQLModel, table=True): replaced_by_trade_id: int | None = Field(default=None, foreign_key="trades.id", nullable=True) notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) cycle_id: int | None = Field(default=None, foreign_key="cycles.id", nullable=True, index=True) + cycle: "Cycles" = Relationship(back_populates="trades") + related_loan_change_event: Optional["CycleLoanChangeEvents"] = Relationship( + back_populates="trade", + sa_relationship_kwargs={"uselist": False}, + ) + class Cycles(SQLModel, table=True): __tablename__ = "cycles" # type: ignore[attr-defined] @@ -113,8 +121,40 @@ class Cycles(SQLModel, table=True): loan_interest_rate_tenth_bps: int | None = Field(default=None, nullable=True) start_date: date = Field(sa_column=Column(Date, nullable=False)) end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True)) + trades: list["Trades"] = Relationship(back_populates="cycle") + loan_change_events: list["CycleLoanChangeEvents"] = Relationship(back_populates="cycle") + daily_accruals: list["CycleDailyAccrual"] = Relationship(back_populates="cycle") + + +class CycleLoanChangeEvents(SQLModel, table=True): + __tablename__ = "cycle_loan_change_events" # type: ignore[attr-defined] + id: int | None = Field(default=None, primary_key=True) + cycle_id: int = Field(foreign_key="cycles.id", nullable=False, index=True) + effective_date: date = Field(sa_column=Column(Date, nullable=False)) + loan_amount_cents: int | None = Field(default=None, sa_column=Column(Integer, nullable=True)) + loan_interest_rate_tenth_bps: int | None = Field(default=None, sa_column=Column(Integer, nullable=True)) + related_trade_id: int | None = Field(default=None, sa_column=Column(Integer, ForeignKey("trades.id"), nullable=True, unique=True)) + notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) + created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False)) + + cycle: "Cycles" = Relationship(back_populates="loan_change_events") + trade: Optional["Trades"] = Relationship(back_populates="related_loan_change_event") + + +class CycleDailyAccrual(SQLModel, table=True): + __tablename__ = "cycle_daily_accrual" # type: ignore[attr-defined] + __table_args__ = (UniqueConstraint("cycle_id", "accrual_date", name="uq_cycle_daily_accruals_cycle_date"),) + + id: int | None = Field(default=None, primary_key=True) + cycle_id: int = Field(foreign_key="cycles.id", nullable=False, index=True) + accrual_date: date = Field(sa_column=Column(Date, nullable=False)) + accrual_amount_cents: int = Field(sa_column=Column(Integer, nullable=False)) + created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False)) + + cycle: "Cycles" = Relationship(back_populates="daily_accruals") + class Exchanges(SQLModel, table=True): __tablename__ = "exchanges" # type: ignore[attr-defined] diff --git a/backend/trading_journal/models_v1.py b/backend/trading_journal/models_v1.py index e8dc281..81d98e1 100644 --- a/backend/trading_journal/models_v1.py +++ b/backend/trading_journal/models_v1.py @@ -1,11 +1,13 @@ from datetime import date, datetime from enum import Enum +from typing import Optional from sqlmodel import ( Column, Date, DateTime, Field, + ForeignKey, Integer, Relationship, SQLModel, @@ -92,8 +94,14 @@ class Trades(SQLModel, table=True): replaced_by_trade_id: int | None = Field(default=None, foreign_key="trades.id", nullable=True) notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) cycle_id: int | None = Field(default=None, foreign_key="cycles.id", nullable=True, index=True) + cycle: "Cycles" = Relationship(back_populates="trades") + related_loan_change_event: Optional["CycleLoanChangeEvents"] = Relationship( + back_populates="trade", + sa_relationship_kwargs={"uselist": False}, + ) + class Cycles(SQLModel, table=True): __tablename__ = "cycles" # type: ignore[attr-defined] @@ -113,8 +121,40 @@ class Cycles(SQLModel, table=True): loan_interest_rate_tenth_bps: int | None = Field(default=None, nullable=True) start_date: date = Field(sa_column=Column(Date, nullable=False)) end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True)) + trades: list["Trades"] = Relationship(back_populates="cycle") + loan_change_events: list["CycleLoanChangeEvents"] = Relationship(back_populates="cycle") + daily_accruals: list["CycleDailyAccrual"] = Relationship(back_populates="cycle") + + +class CycleLoanChangeEvents(SQLModel, table=True): + __tablename__ = "cycle_loan_change_events" # type: ignore[attr-defined] + id: int | None = Field(default=None, primary_key=True) + cycle_id: int = Field(foreign_key="cycles.id", nullable=False, index=True) + effective_date: date = Field(sa_column=Column(Date, nullable=False)) + loan_amount_cents: int | None = Field(default=None, sa_column=Column(Integer, nullable=True)) + loan_interest_rate_tenth_bps: int | None = Field(default=None, sa_column=Column(Integer, nullable=True)) + related_trade_id: int | None = Field(default=None, sa_column=Column(Integer, ForeignKey("trades.id"), nullable=True, unique=True)) + notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) + created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False)) + + cycle: "Cycles" = Relationship(back_populates="loan_change_events") + trade: Optional["Trades"] = Relationship(back_populates="related_loan_change_event") + + +class CycleDailyAccrual(SQLModel, table=True): + __tablename__ = "cycle_daily_accrual" # type: ignore[attr-defined] + __table_args__ = (UniqueConstraint("cycle_id", "accrual_date", name="uq_cycle_daily_accruals_cycle_date"),) + + id: int | None = Field(default=None, primary_key=True) + cycle_id: int = Field(foreign_key="cycles.id", nullable=False, index=True) + accrual_date: date = Field(sa_column=Column(Date, nullable=False)) + accrual_amount_cents: int = Field(sa_column=Column(Integer, nullable=False)) + created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False)) + + cycle: "Cycles" = Relationship(back_populates="daily_accruals") + class Exchanges(SQLModel, table=True): __tablename__ = "exchanges" # type: ignore[attr-defined] -- 2.49.1 From 6a5f160d83b9ba39411c2ff27c87d1f94699446c Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Thu, 25 Sep 2025 22:16:24 +0200 Subject: [PATCH 15/18] add interest accural test, improve migration tests --- backend/tests/test_crud.py | 221 +++++++++++++++++++++++++++ backend/tests/test_db_migration.py | 35 +++++ backend/trading_journal/crud.py | 91 ++++++++++- backend/trading_journal/models.py | 14 +- backend/trading_journal/models_v1.py | 14 +- 5 files changed, 366 insertions(+), 9 deletions(-) diff --git a/backend/tests/test_crud.py b/backend/tests/test_crud.py index d7ae484..9f343f6 100644 --- a/backend/tests/test_crud.py +++ b/backend/tests/test_crud.py @@ -137,6 +137,16 @@ def _ensure_utc_aware(dt: datetime | None) -> datetime | None: return dt.astimezone(timezone.utc) +def _validate_timestamp(actual: datetime, expected: datetime, tolerance: timedelta) -> None: + actual_utc = _ensure_utc_aware(actual) + expected_utc = _ensure_utc_aware(expected) + assert actual_utc is not None + assert expected_utc is not None + delta = abs(actual_utc - expected_utc) + assert delta <= tolerance, f"Timestamps differ by {delta}, which exceeds tolerance of {tolerance}" + + +# Trades def test_create_trade_success_with_cycle(session: Session) -> None: user_id = make_user(session) exchange_id = make_exchange(session, user_id) @@ -554,6 +564,7 @@ def test_replace_trade(session: Session) -> None: assert actual_new_trade.replaced_by_trade_id == old_trade_id +# Cycles def test_create_cycle(session: Session) -> None: user_id = make_user(session) exchange_id = make_exchange(session, user_id) @@ -656,6 +667,216 @@ def test_update_cycle_immutable_fields(session: Session) -> None: ) +# Cycle loans +def test_create_cycle_loan_event(session: Session) -> None: + user_id = make_user(session) + exchange_id = make_exchange(session, user_id) + cycle_id = make_cycle(session, user_id, exchange_id) + + loan_data = { + "cycle_id": cycle_id, + "loan_amount_cents": 100000, + "loan_interest_rate_tenth_bps": 5000, # 5% + "notes": "Test loan change for the cycle", + } + + loan_event = crud.create_cycle_loan_event(session, loan_data) + now = datetime.now(timezone.utc) + assert loan_event.id is not None + assert loan_event.cycle_id == cycle_id + assert loan_event.loan_amount_cents == loan_data["loan_amount_cents"] + assert loan_event.loan_interest_rate_tenth_bps == loan_data["loan_interest_rate_tenth_bps"] + assert loan_event.notes == loan_data["notes"] + assert loan_event.effective_date == now.date() + _validate_timestamp(loan_event.created_at, now, timedelta(seconds=1)) + + session.refresh(loan_event) + actual_loan_event = session.get(models.CycleLoanChangeEvents, loan_event.id) + assert actual_loan_event is not None + assert actual_loan_event.cycle_id == cycle_id + assert actual_loan_event.loan_amount_cents == loan_data["loan_amount_cents"] + assert actual_loan_event.loan_interest_rate_tenth_bps == loan_data["loan_interest_rate_tenth_bps"] + assert actual_loan_event.notes == loan_data["notes"] + assert actual_loan_event.effective_date == now.date() + _validate_timestamp(actual_loan_event.created_at, now, timedelta(seconds=1)) + + +def test_get_cycle_loan_events_by_cycle_id(session: Session) -> None: + user_id = make_user(session) + exchange_id = make_exchange(session, user_id) + cycle_id = make_cycle(session, user_id, exchange_id) + + loan_data_1 = { + "cycle_id": cycle_id, + "loan_amount_cents": 100000, + "loan_interest_rate_tenth_bps": 5000, + "notes": "First loan event", + } + yesterday = (datetime.now(timezone.utc) - timedelta(days=1)).date() + loan_data_2 = { + "cycle_id": cycle_id, + "loan_amount_cents": 150000, + "loan_interest_rate_tenth_bps": 4500, + "effective_date": yesterday, + "notes": "Second loan event", + } + + crud.create_cycle_loan_event(session, loan_data_1) + crud.create_cycle_loan_event(session, loan_data_2) + + loan_events = crud.get_loan_events_by_cycle_id(session, cycle_id) + assert len(loan_events) == 2 + notes = [event.notes for event in loan_events] + assert loan_events[0].notes == loan_data_2["notes"] + assert loan_events[0].effective_date == yesterday + assert notes == ["Second loan event", "First loan event"] # Ordered by effective_date desc + + +def test_get_cycle_loan_events_by_cycle_id_same_date(session: Session) -> None: + user_id = make_user(session) + exchange_id = make_exchange(session, user_id) + cycle_id = make_cycle(session, user_id, exchange_id) + + loan_data_1 = { + "cycle_id": cycle_id, + "loan_amount_cents": 100000, + "loan_interest_rate_tenth_bps": 5000, + "notes": "First loan event", + } + loan_data_2 = { + "cycle_id": cycle_id, + "loan_amount_cents": 150000, + "loan_interest_rate_tenth_bps": 4500, + "notes": "Second loan event", + } + + crud.create_cycle_loan_event(session, loan_data_1) + crud.create_cycle_loan_event(session, loan_data_2) + + loan_events = crud.get_loan_events_by_cycle_id(session, cycle_id) + assert len(loan_events) == 2 + notes = [event.notes for event in loan_events] + assert notes == ["First loan event", "Second loan event"] # Ordered by id desc when effective_date is same + + +def test_create_cycle_loan_event_single_field(session: Session) -> None: + user_id = make_user(session) + exchange_id = make_exchange(session, user_id) + cycle_id = make_cycle(session, user_id, exchange_id) + + loan_data = { + "cycle_id": cycle_id, + "loan_amount_cents": 200000, + } + + loan_event = crud.create_cycle_loan_event(session, loan_data) + now = datetime.now(timezone.utc) + assert loan_event.id is not None + assert loan_event.cycle_id == cycle_id + assert loan_event.loan_amount_cents == loan_data["loan_amount_cents"] + assert loan_event.loan_interest_rate_tenth_bps is None + assert loan_event.notes is None + assert loan_event.effective_date == now.date() + _validate_timestamp(loan_event.created_at, now, timedelta(seconds=1)) + + session.refresh(loan_event) + actual_loan_event = session.get(models.CycleLoanChangeEvents, loan_event.id) + assert actual_loan_event is not None + assert actual_loan_event.cycle_id == cycle_id + assert actual_loan_event.loan_amount_cents == loan_data["loan_amount_cents"] + assert actual_loan_event.loan_interest_rate_tenth_bps is None + assert actual_loan_event.notes is None + assert actual_loan_event.effective_date == now.date() + _validate_timestamp(actual_loan_event.created_at, now, timedelta(seconds=1)) + + +def test_create_cycle_daily_accrual(session: Session) -> None: + user_id = make_user(session) + exchange_id = make_exchange(session, user_id) + cycle_id = make_cycle(session, user_id, exchange_id) + today = datetime.now(timezone.utc).date() + accrual_data = { + "cycle_id": cycle_id, + "accrual_date": today, + "accrued_interest_cents": 150, + "notes": "Daily interest accrual", + } + + accrual = crud.create_cycle_daily_accrual(session, cycle_id, accrual_data["accrual_date"], accrual_data["accrued_interest_cents"]) + assert accrual.id is not None + assert accrual.cycle_id == cycle_id + assert accrual.accrual_date == accrual_data["accrual_date"] + assert accrual.accrual_amount_cents == accrual_data["accrued_interest_cents"] + + session.refresh(accrual) + actual_accrual = session.get(models.CycleDailyAccrual, accrual.id) + assert actual_accrual is not None + assert actual_accrual.cycle_id == cycle_id + assert actual_accrual.accrual_date == accrual_data["accrual_date"] + assert actual_accrual.accrual_amount_cents == accrual_data["accrued_interest_cents"] + + +def test_get_cycle_daily_accruals_by_cycle_id(session: Session) -> None: + user_id = make_user(session) + exchange_id = make_exchange(session, user_id) + cycle_id = make_cycle(session, user_id, exchange_id) + + today = datetime.now(timezone.utc).date() + yesterday = today - timedelta(days=1) + + accrual_data_1 = { + "cycle_id": cycle_id, + "accrual_date": yesterday, + "accrued_interest_cents": 100, + } + accrual_data_2 = { + "cycle_id": cycle_id, + "accrual_date": today, + "accrued_interest_cents": 150, + } + + crud.create_cycle_daily_accrual(session, cycle_id, accrual_data_1["accrual_date"], accrual_data_1["accrued_interest_cents"]) + crud.create_cycle_daily_accrual(session, cycle_id, accrual_data_2["accrual_date"], accrual_data_2["accrued_interest_cents"]) + + accruals = crud.get_cycle_daily_accruals_by_cycle_id(session, cycle_id) + assert len(accruals) == 2 + dates = [accrual.accrual_date for accrual in accruals] + assert dates == [yesterday, today] # Ordered by accrual_date asc + + +def test_get_cycle_daily_accruals_by_cycle_id_and_date(session: Session) -> None: + user_id = make_user(session) + exchange_id = make_exchange(session, user_id) + cycle_id = make_cycle(session, user_id, exchange_id) + + today = datetime.now(timezone.utc).date() + yesterday = today - timedelta(days=1) + + accrual_data_1 = { + "cycle_id": cycle_id, + "accrual_date": yesterday, + "accrued_interest_cents": 100, + } + accrual_data_2 = { + "cycle_id": cycle_id, + "accrual_date": today, + "accrued_interest_cents": 150, + } + + crud.create_cycle_daily_accrual(session, cycle_id, accrual_data_1["accrual_date"], accrual_data_1["accrued_interest_cents"]) + crud.create_cycle_daily_accrual(session, cycle_id, accrual_data_2["accrual_date"], accrual_data_2["accrued_interest_cents"]) + + accruals_today = crud.get_cycle_daily_accrual_by_cycle_id_and_date(session, cycle_id, today) + assert accruals_today is not None + assert accruals_today.accrual_date == today + assert accruals_today.accrual_amount_cents == accrual_data_2["accrued_interest_cents"] + + accruals_yesterday = crud.get_cycle_daily_accrual_by_cycle_id_and_date(session, cycle_id, yesterday) + assert accruals_yesterday is not None + assert accruals_yesterday.accrual_date == yesterday + assert accruals_yesterday.accrual_amount_cents == accrual_data_1["accrued_interest_cents"] + + # Exchanges def test_create_exchange(session: Session) -> None: user_id = make_user(session) diff --git a/backend/tests/test_db_migration.py b/backend/tests/test_db_migration.py index f14b2da..69fd63a 100644 --- a/backend/tests/test_db_migration.py +++ b/backend/tests/test_db_migration.py @@ -45,6 +45,8 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None: "loan_interest_rate_tenth_bps": ("INTEGER", 0, 0), "start_date": ("DATE", 1, 0), "end_date": ("DATE", 0, 0), + "latest_interest_accrued_date": ("DATE", 0, 0), + "total_accrued_amount_cents": ("INTEGER", 1, 0), }, "cycle_loan_change_events": { "id": ("INTEGER", 1, 1), @@ -170,6 +172,39 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None: actual_fk_list = [{"table": r[2], "from": r[3], "to": r[4]} for r in fk_rows] for efk in fks: assert efk in actual_fk_list, f"missing FK on {tbl_name}: {efk}" + + # check trades.replaced_by_trade_id self-referential FK + fk_rows = conn.execute(text("PRAGMA foreign_key_list('trades')")).fetchall() + actual_fk_list = [{"table": r[2], "from": r[3], "to": r[4]} for r in fk_rows] + assert {"table": "trades", "from": "replaced_by_trade_id", "to": "id"} in actual_fk_list, ( + "missing self FK trades.replaced_by_trade_id -> trades.id" + ) + + # helper to find unique index on a column + def has_unique_index(table: str, column: str) -> bool: + idx_rows = conn.execute(text(f"PRAGMA index_list('{table}')")).fetchall() + for idx in idx_rows: + idx_name = idx[1] + is_unique = bool(idx[2]) + if not is_unique: + continue + info = conn.execute(text(f"PRAGMA index_info('{idx_name}')")).fetchall() + cols = [r[2] for r in info] + if column in cols: + return True + return False + + assert has_unique_index("trades", "friendly_name"), ( + "expected unique index on trades(friendly_name) per uq_trades_user_friendly_name" + ) + assert has_unique_index("cycles", "friendly_name"), ( + "expected unique index on cycles(friendly_name) per uq_cycles_user_friendly_name" + ) + assert has_unique_index("exchanges", "name"), "expected unique index on exchanges(name) per uq_exchanges_user_name" + assert has_unique_index("sessions", "session_token_hash"), "expected unique index on sessions(session_token_hash)" + assert has_unique_index("cycle_loan_change_events", "related_trade_id"), ( + "expected unique index on cycle_loan_change_events(related_trade_id)" + ) finally: engine.dispose() SQLModel.metadata.clear() diff --git a/backend/trading_journal/crud.py b/backend/trading_journal/crud.py index da8e237..66e2c4f 100644 --- a/backend/trading_journal/crud.py +++ b/backend/trading_journal/crud.py @@ -1,6 +1,6 @@ from __future__ import annotations -from datetime import datetime, timedelta, timezone +from datetime import date, datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, TypeVar, cast from pydantic import BaseModel @@ -13,6 +13,8 @@ if TYPE_CHECKING: from collections.abc import Mapping from enum import Enum + from sqlalchemy.sql.elements import ColumnElement + # Generic enum member type T = TypeVar("T", bound="Enum") @@ -301,6 +303,93 @@ def update_cycle(session: Session, cycle_id: int, update_data: Mapping[str, Any] return cycle +# Cycle loan and interest +def create_cycle_loan_event(session: Session, loan_data: Mapping[str, Any] | BaseModel) -> models.CycleLoanChangeEvents: + data = _data_to_dict(loan_data) + allowed = _allowed_columns(models.CycleLoanChangeEvents) + payload = {k: v for k, v in data.items() if k in allowed} + if "cycle_id" not in payload: + raise ValueError("cycle_id is required") + cycle = session.get(models.Cycles, payload["cycle_id"]) + if cycle is None: + raise ValueError("cycle_id does not exist") + + payload["effective_date"] = payload.get("effective_date") or datetime.now(timezone.utc).date() + payload["created_at"] = datetime.now(timezone.utc) + cle = models.CycleLoanChangeEvents(**payload) + session.add(cle) + try: + session.flush() + except IntegrityError as e: + session.rollback() + raise ValueError("create_cycle_loan_event integrity error") from e + session.refresh(cle) + return cle + + +def get_loan_events_by_cycle_id(session: Session, cycle_id: int) -> list[models.CycleLoanChangeEvents]: + eff_col = cast("ColumnElement", models.CycleLoanChangeEvents.effective_date) + id_col = cast("ColumnElement", models.CycleLoanChangeEvents.id) + statement = ( + select(models.CycleLoanChangeEvents) + .where( + models.CycleLoanChangeEvents.cycle_id == cycle_id, + ) + .order_by(eff_col, id_col.asc()) + ) + return list(session.exec(statement).all()) + + +def create_cycle_daily_accrual(session: Session, cycle_id: int, accrual_date: date, accrual_amount_cents: int) -> models.CycleDailyAccrual: + cycle = session.get(models.Cycles, cycle_id) + if cycle is None: + raise ValueError("cycle_id does not exist") + existing = session.exec( + select(models.CycleDailyAccrual).where( + models.CycleDailyAccrual.cycle_id == cycle_id, + models.CycleDailyAccrual.accrual_date == accrual_date, + ), + ).first() + if existing: + return existing + if accrual_amount_cents < 0: + raise ValueError("accrual_amount_cents must be non-negative") + row = models.CycleDailyAccrual( + cycle_id=cycle_id, + accrual_date=accrual_date, + accrual_amount_cents=accrual_amount_cents, + created_at=datetime.now(timezone.utc), + ) + session.add(row) + try: + session.flush() + except IntegrityError as e: + session.rollback() + raise ValueError("create_cycle_daily_accrual integrity error") from e + session.refresh(row) + return row + + +def get_cycle_daily_accruals_by_cycle_id(session: Session, cycle_id: int) -> list[models.CycleDailyAccrual]: + date_col = cast("ColumnElement", models.CycleDailyAccrual.accrual_date) + statement = ( + select(models.CycleDailyAccrual) + .where( + models.CycleDailyAccrual.cycle_id == cycle_id, + ) + .order_by(date_col.asc()) + ) + return list(session.exec(statement).all()) + + +def get_cycle_daily_accrual_by_cycle_id_and_date(session: Session, cycle_id: int, accrual_date: date) -> models.CycleDailyAccrual | None: + statement = select(models.CycleDailyAccrual).where( + models.CycleDailyAccrual.cycle_id == cycle_id, + models.CycleDailyAccrual.accrual_date == accrual_date, + ) + return session.exec(statement).first() + + # Exchanges IMMUTABLE_EXCHANGE_FIELDS = {"id"} diff --git a/backend/trading_journal/models.py b/backend/trading_journal/models.py index 81d98e1..f060c26 100644 --- a/backend/trading_journal/models.py +++ b/backend/trading_journal/models.py @@ -18,8 +18,10 @@ from sqlmodel import ( class TradeType(str, Enum): SELL_PUT = "SELL_PUT" + CLOSE_SELL_PUT = "CLOSE_SELL_PUT" ASSIGNMENT = "ASSIGNMENT" SELL_CALL = "SELL_CALL" + CLOSE_SELL_CALL = "CLOSE_SELL_CALL" EXERCISE_CALL = "EXERCISE_CALL" LONG_SPOT = "LONG_SPOT" CLOSE_LONG_SPOT = "CLOSE_LONG_SPOT" @@ -117,13 +119,17 @@ class Cycles(SQLModel, table=True): status: CycleStatus = Field(sa_column=Column(Text, nullable=False)) funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True)) capital_exposure_cents: int | None = Field(default=None, nullable=True) - loan_amount_cents: int | None = Field(default=None, nullable=True) - loan_interest_rate_tenth_bps: int | None = Field(default=None, nullable=True) start_date: date = Field(sa_column=Column(Date, nullable=False)) end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True)) trades: list["Trades"] = Relationship(back_populates="cycle") + loan_amount_cents: int | None = Field(default=None, nullable=True) + loan_interest_rate_tenth_bps: int | None = Field(default=None, nullable=True) + + latest_interest_accrued_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True)) + total_accrued_amount_cents: int = Field(default=0, sa_column=Column(Integer, nullable=False)) + loan_change_events: list["CycleLoanChangeEvents"] = Relationship(back_populates="cycle") daily_accruals: list["CycleDailyAccrual"] = Relationship(back_populates="cycle") @@ -131,7 +137,7 @@ class Cycles(SQLModel, table=True): class CycleLoanChangeEvents(SQLModel, table=True): __tablename__ = "cycle_loan_change_events" # type: ignore[attr-defined] id: int | None = Field(default=None, primary_key=True) - cycle_id: int = Field(foreign_key="cycles.id", nullable=False, index=True) + cycle_id: int = Field(sa_column=Column(Integer, ForeignKey("cycles.id", ondelete="CASCADE"), nullable=False, index=True)) effective_date: date = Field(sa_column=Column(Date, nullable=False)) loan_amount_cents: int | None = Field(default=None, sa_column=Column(Integer, nullable=True)) loan_interest_rate_tenth_bps: int | None = Field(default=None, sa_column=Column(Integer, nullable=True)) @@ -148,7 +154,7 @@ class CycleDailyAccrual(SQLModel, table=True): __table_args__ = (UniqueConstraint("cycle_id", "accrual_date", name="uq_cycle_daily_accruals_cycle_date"),) id: int | None = Field(default=None, primary_key=True) - cycle_id: int = Field(foreign_key="cycles.id", nullable=False, index=True) + cycle_id: int = Field(sa_column=Column(Integer, ForeignKey("cycles.id", ondelete="CASCADE"), nullable=False, index=True)) accrual_date: date = Field(sa_column=Column(Date, nullable=False)) accrual_amount_cents: int = Field(sa_column=Column(Integer, nullable=False)) created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False)) diff --git a/backend/trading_journal/models_v1.py b/backend/trading_journal/models_v1.py index 81d98e1..f060c26 100644 --- a/backend/trading_journal/models_v1.py +++ b/backend/trading_journal/models_v1.py @@ -18,8 +18,10 @@ from sqlmodel import ( class TradeType(str, Enum): SELL_PUT = "SELL_PUT" + CLOSE_SELL_PUT = "CLOSE_SELL_PUT" ASSIGNMENT = "ASSIGNMENT" SELL_CALL = "SELL_CALL" + CLOSE_SELL_CALL = "CLOSE_SELL_CALL" EXERCISE_CALL = "EXERCISE_CALL" LONG_SPOT = "LONG_SPOT" CLOSE_LONG_SPOT = "CLOSE_LONG_SPOT" @@ -117,13 +119,17 @@ class Cycles(SQLModel, table=True): status: CycleStatus = Field(sa_column=Column(Text, nullable=False)) funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True)) capital_exposure_cents: int | None = Field(default=None, nullable=True) - loan_amount_cents: int | None = Field(default=None, nullable=True) - loan_interest_rate_tenth_bps: int | None = Field(default=None, nullable=True) start_date: date = Field(sa_column=Column(Date, nullable=False)) end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True)) trades: list["Trades"] = Relationship(back_populates="cycle") + loan_amount_cents: int | None = Field(default=None, nullable=True) + loan_interest_rate_tenth_bps: int | None = Field(default=None, nullable=True) + + latest_interest_accrued_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True)) + total_accrued_amount_cents: int = Field(default=0, sa_column=Column(Integer, nullable=False)) + loan_change_events: list["CycleLoanChangeEvents"] = Relationship(back_populates="cycle") daily_accruals: list["CycleDailyAccrual"] = Relationship(back_populates="cycle") @@ -131,7 +137,7 @@ class Cycles(SQLModel, table=True): class CycleLoanChangeEvents(SQLModel, table=True): __tablename__ = "cycle_loan_change_events" # type: ignore[attr-defined] id: int | None = Field(default=None, primary_key=True) - cycle_id: int = Field(foreign_key="cycles.id", nullable=False, index=True) + cycle_id: int = Field(sa_column=Column(Integer, ForeignKey("cycles.id", ondelete="CASCADE"), nullable=False, index=True)) effective_date: date = Field(sa_column=Column(Date, nullable=False)) loan_amount_cents: int | None = Field(default=None, sa_column=Column(Integer, nullable=True)) loan_interest_rate_tenth_bps: int | None = Field(default=None, sa_column=Column(Integer, nullable=True)) @@ -148,7 +154,7 @@ class CycleDailyAccrual(SQLModel, table=True): __table_args__ = (UniqueConstraint("cycle_id", "accrual_date", name="uq_cycle_daily_accruals_cycle_date"),) id: int | None = Field(default=None, primary_key=True) - cycle_id: int = Field(foreign_key="cycles.id", nullable=False, index=True) + cycle_id: int = Field(sa_column=Column(Integer, ForeignKey("cycles.id", ondelete="CASCADE"), nullable=False, index=True)) accrual_date: date = Field(sa_column=Column(Date, nullable=False)) accrual_amount_cents: int = Field(sa_column=Column(Integer, nullable=False)) created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False)) -- 2.49.1 From 5eae75b23ecfe0c3f9329163e4ea0c7e47d31151 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Fri, 26 Sep 2025 22:37:26 +0200 Subject: [PATCH 16/18] wip service test --- backend/tests/test_service.py | 252 +++++++++++++++++++++++++++++++++- 1 file changed, 250 insertions(+), 2 deletions(-) diff --git a/backend/tests/test_service.py b/backend/tests/test_service.py index b23e2be..08974b1 100644 --- a/backend/tests/test_service.py +++ b/backend/tests/test_service.py @@ -1,5 +1,253 @@ +import asyncio +import json +from collections.abc import Generator +from contextlib import contextmanager +from datetime import datetime, timedelta, timezone +from types import SimpleNamespace +from unittest.mock import ANY, patch + import pytest +from fastapi import FastAPI, status +from fastapi.requests import Request +from fastapi.responses import Response -from trading_journal import crud, service +from settings import settings +from trading_journal import service -monkeypatch = pytest.MonkeyPatch() + +# --- Auth middleware --------------------------------------------------------- +class FakeDBFactory: + @contextmanager + def get_session_ctx_manager(self) -> Generator[SimpleNamespace, None, None]: + yield SimpleNamespace(name="fakesession") + + +def verify_json_response(response: Response, expected_status: int, expected_detail: str) -> None: + assert response.status_code == expected_status + body_bytes = response.body.tobytes() if isinstance(response.body, memoryview) else response.body + body_text = body_bytes.decode("utf-8") + body_json = json.loads(body_text) + assert body_json.get("detail") == expected_detail + + +def test_auth_middleware_allows_public_path() -> None: + app = FastAPI() + middleware = service.AuthMiddleWare(app) + + for p in service.EXCEPT_PATHS: + scope = { + "type": "http", + "method": "GET", + "path": p, + "headers": [], + "client": ("testclient", 50000), + } + request = Request(scope) + + async def call_next(req: Request, expected: Request = request) -> Response: + assert req is expected + return Response(status_code=status.HTTP_204_NO_CONTENT) + + response = asyncio.run(middleware.dispatch(request, call_next)) + assert response.status_code == status.HTTP_204_NO_CONTENT + + +def test_auth_middleware_rejects_missing_token() -> None: + app = FastAPI() + middleware = service.AuthMiddleWare(app) + + scope = { + "type": "http", + "method": "GET", + "path": f"/{settings.api_base}/protected", + "headers": [], + "client": ("testclient", 50000), + } + request = Request(scope) + + async def call_next(req: Request) -> Response: # noqa: ARG001 + pytest.fail("call_next should not be called for missing token") + + response = asyncio.run(middleware.dispatch(request, call_next)) + verify_json_response(response, status.HTTP_401_UNAUTHORIZED, "Unauthorized") + + +def test_auth_middleware_no_db() -> None: + app = FastAPI() + middleware = service.AuthMiddleWare(app) + + scope = { + "type": "http", + "method": "GET", + "path": f"/{settings.api_base}/protected", + "headers": [(b"authorization", b"Bearer invalidtoken")], + "client": ("testclient", 50000), + "app": app, + } + request = Request(scope) + + async def call_next(req: Request) -> Response: # noqa: ARG001 + pytest.fail("call_next should not be called for invalid token") + + response = asyncio.run(middleware.dispatch(request, call_next)) + verify_json_response(response, status.HTTP_500_INTERNAL_SERVER_ERROR, "db factory not configured") + + +def test_auth_middleware_rejects_invalid_token() -> None: + app = FastAPI() + app.state.db_factory = FakeDBFactory() + middleware = service.AuthMiddleWare(app) + + scope = { + "type": "http", + "method": "GET", + "path": f"/{settings.api_base}/protected", + "headers": [(b"authorization", b"Bearer invalidtoken")], + "client": ("testclient", 50000), + "app": app, + } + request = Request(scope) + + async def call_next(req: Request) -> Response: # noqa: ARG001 + pytest.fail("call_next should not be called for invalid token") + + with patch("trading_journal.crud.get_login_session_by_token_hash", return_value=None): + response = asyncio.run(middleware.dispatch(request, call_next)) + verify_json_response(response, status.HTTP_401_UNAUTHORIZED, "Unauthorized") + + +def test_auth_middleware_rejects_expired_token() -> None: + app = FastAPI() + app.state.db_factory = FakeDBFactory() + middleware = service.AuthMiddleWare(app) + fake_token_orig = "expiredtoken" + + scope = { + "type": "http", + "method": "GET", + "path": f"/{settings.api_base}/protected", + "headers": [(b"cookie", f"session_token={fake_token_orig}".encode())], + "client": ("testclient", 50000), + "app": app, + } + request = Request(scope) + + async def call_next(req: Request) -> Response: # noqa: ARG001 + pytest.fail("call_next should not be called for expired token") + + expired_session = SimpleNamespace( + id=1, + user_id=1, + session_token_hash="expiredtokenhash", + created_at=None, + expires_at=(datetime.now(timezone.utc) - timedelta(days=1)), + ) + + with ( + patch("trading_journal.security.hash_session_token_sha256", return_value=expired_session.session_token_hash) as mock_hash, + patch("trading_journal.crud.get_login_session_by_token_hash", return_value=expired_session), + patch("trading_journal.crud.delete_login_session") as mock_delete, + ): + response = asyncio.run(middleware.dispatch(request, call_next)) + + verify_json_response(response, status.HTTP_401_UNAUTHORIZED, "Unauthorized") + mock_hash.assert_called_once_with(fake_token_orig) + mock_delete.assert_called_once_with(ANY, expired_session.session_token_hash) + + +def test_auth_middleware_reject_inactive_user() -> None: + app = FastAPI() + app.state.db_factory = FakeDBFactory() + middleware = service.AuthMiddleWare(app) + fake_token_orig = "validtoken" + + scope = { + "type": "http", + "method": "GET", + "path": f"/{settings.api_base}/protected", + "headers": [(b"cookie", f"session_token={fake_token_orig}".encode())], + "client": ("testclient", 50000), + "app": app, + } + request = Request(scope) + + async def call_next(req: Request) -> Response: # noqa: ARG001 + pytest.fail("call_next should not be called for inactive user") + + inactive_user = SimpleNamespace( + id=1, + username="inactiveuser", + is_active=False, + ) + valid_session = SimpleNamespace( + id=1, + user_id=1, + session_token_hash="validtokenhash", + created_at=None, + expires_at=(datetime.now(timezone.utc) + timedelta(days=1)), + user=inactive_user, + ) + + with ( + patch("trading_journal.security.hash_session_token_sha256", return_value=valid_session.session_token_hash) as mock_hash, + patch("trading_journal.crud.get_login_session_by_token_hash", return_value=valid_session), + ): + response = asyncio.run(middleware.dispatch(request, call_next)) + + verify_json_response(response, status.HTTP_401_UNAUTHORIZED, "Unauthorized") + + +# --- User services ----------------------------------------------------------- +def test_register_user_success(): + pytest.fail("TODO: mock crud/security, assert UserRead username") + + +def test_register_user_exists_raises(): + pytest.fail("TODO: mock get_user_by_username to return obj and expect UserAlreadyExistsError") + + +def test_authenticate_user_success(): + pytest.fail("TODO: mock crud/security, expect token + SessionsCreate DTO") + + +def test_authenticate_user_invalid_password_returns_none(): + pytest.fail("TODO: mock verify_password False") + + +# --- Exchange services ------------------------------------------------------- +def test_create_exchange_duplicate_raises(): + pytest.fail("TODO: mock get_exchange_by_name_and_user_id and expect ExchangeAlreadyExistsError") + + +def test_update_exchange_not_found(): + pytest.fail("TODO: mock get_exchange_by_id None and expect ExchangeNotFoundError") + + +# --- Cycle services ---------------------------------------------------------- +def test_validate_cycle_update_rules(): + pytest.fail("TODO: call _validate_cycle_update_data with invalid combos") + + +def test_update_cycle_owner_mismatch_raises(): + pytest.fail("TODO: mock get_cycle_by_id owned by other user, expect CycleNotFoundError") + + +# --- Trade services ---------------------------------------------------------- +def test_create_trade_invalid_sell_requires_expiry(): + pytest.fail("TODO: build SELL_PUT without expiry/strike, expect InvalidTradeDataError") + + +def test_create_trade_appends_cashflow_and_calls_crud(): + pytest.fail("TODO: mock crud.create_trade, assert net_cash_flow_cents in result") + + +def test_get_trade_by_id_missing_raises(): + pytest.fail("TODO: mock get_trade_by_id None, expect TradeNotFoundError") + + +def test_update_trade_friendly_name_not_found(): + pytest.fail("TODO: mock get_trade_by_id None, expect TradeNotFoundError") + + +def test_update_trade_note_sets_empty_string_when_none(): + pytest.fail("TODO: mock update_trade_note to return note '', assert DTO note") -- 2.49.1 From bb87b902858c65f61faabfc08e84c98bea1956de Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Mon, 29 Sep 2025 16:48:28 +0200 Subject: [PATCH 17/18] service layer add all tests for existing code --- backend/tests/test_service.py | 910 ++++++++++++++++++++++++++++- backend/trading_journal/service.py | 1 + 2 files changed, 880 insertions(+), 31 deletions(-) diff --git a/backend/tests/test_service.py b/backend/tests/test_service.py index 08974b1..0248e18 100644 --- a/backend/tests/test_service.py +++ b/backend/tests/test_service.py @@ -4,7 +4,7 @@ from collections.abc import Generator from contextlib import contextmanager from datetime import datetime, timedelta, timezone from types import SimpleNamespace -from unittest.mock import ANY, patch +from unittest.mock import ANY, MagicMock, patch import pytest from fastapi import FastAPI, status @@ -12,14 +12,17 @@ from fastapi.requests import Request from fastapi.responses import Response from settings import settings -from trading_journal import service +from trading_journal import dto, service +from trading_journal.crud import Session # --- Auth middleware --------------------------------------------------------- class FakeDBFactory: @contextmanager - def get_session_ctx_manager(self) -> Generator[SimpleNamespace, None, None]: - yield SimpleNamespace(name="fakesession") + def get_session_ctx_manager(self) -> Generator[Session, None, None]: + fake_session = MagicMock(spec=Session) + fake_session.name = "FakeDBSession" + yield fake_session def verify_json_response(response: Response, expected_status: int, expected_detail: str) -> None: @@ -189,7 +192,7 @@ def test_auth_middleware_reject_inactive_user() -> None: ) with ( - patch("trading_journal.security.hash_session_token_sha256", return_value=valid_session.session_token_hash) as mock_hash, + patch("trading_journal.security.hash_session_token_sha256", return_value=valid_session.session_token_hash), patch("trading_journal.crud.get_login_session_by_token_hash", return_value=valid_session), ): response = asyncio.run(middleware.dispatch(request, call_next)) @@ -197,57 +200,902 @@ def test_auth_middleware_reject_inactive_user() -> None: verify_json_response(response, status.HTTP_401_UNAUTHORIZED, "Unauthorized") +def test_auth_middleware_allows_valid_token_and_no_update_expires() -> None: + app = FastAPI() + app.state.db_factory = FakeDBFactory() + middleware = service.AuthMiddleWare(app) + fake_token_orig = "validtoken" + + scope = { + "type": "http", + "method": "GET", + "path": f"/{settings.api_base}/protected", + "headers": [(b"cookie", f"session_token={fake_token_orig}".encode()), (b"user-agent", b"test-agent")], + "client": ("testclient", 50000), + "app": app, + } + request = Request(scope) + + async def call_next(req: Request, expected: Request = request) -> Response: + assert req is expected + assert hasattr(req.state, "user_id") + assert req.state.user_id == 1 + return Response(status_code=status.HTTP_204_NO_CONTENT) + + active_user = SimpleNamespace( + id=1, + username="activeuser", + is_active=True, + ) + valid_session = SimpleNamespace( + id=1, + user_id=1, + session_token_hash="validtokenhash", + expires_at=(datetime.now(timezone.utc) + timedelta(days=1)), + user=active_user, + ) + + with ( + patch("trading_journal.security.hash_session_token_sha256", return_value=valid_session.session_token_hash), + patch("trading_journal.crud.get_login_session_by_token_hash", return_value=valid_session), + patch("trading_journal.crud.update_login_session") as mock_update, + ): + response = asyncio.run(middleware.dispatch(request, call_next)) + assert response.status_code == status.HTTP_204_NO_CONTENT + mock_update.assert_called_once() + _, kwargs = mock_update.call_args + update_session = kwargs.get("update_session") + assert update_session is not None + assert update_session.expires_at == valid_session.expires_at + + +def test_auth_middleware_allows_valid_token_and_updates_expires() -> None: + app = FastAPI() + app.state.db_factory = FakeDBFactory() + middleware = service.AuthMiddleWare(app) + fake_token_orig = "validtoken" + + scope = { + "type": "http", + "method": "GET", + "path": f"/{settings.api_base}/protected", + "headers": [(b"cookie", f"session_token={fake_token_orig}".encode()), (b"user-agent", b"test-agent")], + "client": ("testclient", 50000), + "app": app, + } + request = Request(scope) + + async def call_next(req: Request, expected: Request = request) -> Response: + assert req is expected + assert hasattr(req.state, "user_id") + assert req.state.user_id == 1 + return Response(status_code=status.HTTP_204_NO_CONTENT) + + active_user = SimpleNamespace( + id=1, + username="activeuser", + is_active=True, + ) + valid_session = SimpleNamespace( + id=1, + user_id=1, + session_token_hash="validtokenhash", + expires_at=(datetime.now(timezone.utc) + timedelta(minutes=10)), + user=active_user, + ) + + with ( + patch("trading_journal.security.hash_session_token_sha256", return_value=valid_session.session_token_hash), + patch("trading_journal.crud.get_login_session_by_token_hash", return_value=valid_session), + patch("trading_journal.crud.update_login_session") as mock_update, + ): + response = asyncio.run(middleware.dispatch(request, call_next)) + assert response.status_code == status.HTTP_204_NO_CONTENT + mock_update.assert_called_once() + _, kwargs = mock_update.call_args + update_session = kwargs.get("update_session") + assert update_session is not None + assert (update_session.expires_at - datetime.now(timezone.utc)).total_seconds() > settings.session_expiry_seconds - 1 + assert (update_session.last_seen_at - datetime.now(timezone.utc)).total_seconds() < 1 + assert update_session.last_used_ip == "testclient" + assert update_session.user_agent == "test-agent" + + # --- User services ----------------------------------------------------------- -def test_register_user_success(): - pytest.fail("TODO: mock crud/security, assert UserRead username") +def test_register_user_success() -> None: + user_in = dto.UserCreate(username="newuser", password="newpassword") + user_in_with_hashed_password = { + "username": user_in.username, + "password_hash": "hashednewpassword", + } + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_user_by_username", return_value=None) as mock_get, + patch( + "trading_journal.crud.create_user", + return_value=SimpleNamespace(id=1, username=user_in.username, is_active=True), + ) as mock_create, + patch("trading_journal.security.hash_password", return_value=user_in_with_hashed_password["password_hash"]), + ): + user_out = service.register_user_service(db, user_in) + assert user_out.id is not None + assert user_out.username == user_in.username + mock_get.assert_called_once_with(db, user_in.username) + mock_create.assert_called_once_with(db, user_data=user_in_with_hashed_password) -def test_register_user_exists_raises(): - pytest.fail("TODO: mock get_user_by_username to return obj and expect UserAlreadyExistsError") +def test_register_user_exists_raises() -> None: + user_in = dto.UserCreate(username="existinguser", password="newpassword") + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch( + "trading_journal.crud.get_user_by_username", + return_value=SimpleNamespace(id=1, username=user_in.username, is_active=True), + ) as mock_get, + ): + with pytest.raises(service.UserAlreadyExistsError) as exc_info: + service.register_user_service(db, user_in) + assert str(exc_info.value) == "username already exists" + mock_get.assert_called_once_with(db, user_in.username) -def test_authenticate_user_success(): - pytest.fail("TODO: mock crud/security, expect token + SessionsCreate DTO") +def test_authenticate_user_success() -> None: + user_in = dto.UserLogin(username="validuser", password="validpassword") + stored_user = SimpleNamespace(id=1, username=user_in.username, is_active=True, password_hash="hashedpassword") + expected_login_session = dto.SessionsCreate( + user_id=stored_user.id, + expires_at=datetime.now(timezone.utc) + timedelta(seconds=settings.session_expiry_seconds), + ) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch( + "trading_journal.crud.get_user_by_username", + return_value=stored_user, + ) as mock_get, + patch("trading_journal.security.verify_password", return_value=True) as mock_verify, + patch("trading_journal.security.generate_session_token", return_value="newsessiontoken") as mock_token, + patch("trading_journal.security.hash_session_token_sha256", return_value="newsessiontokenhash") as mock_hash_session_token, + patch( + "trading_journal.crud.create_login_session", + return_value=SimpleNamespace(user_id=stored_user.id, expires_at=expected_login_session.expires_at), + ) as mock_create_session, + ): + user_out = service.authenticate_user_service(db, user_in) + assert user_out is not None + login_session, token = user_out + # assert fields instead of direct equality to avoid pydantic/model issues + assert getattr(login_session, "user_id", None) == stored_user.id + assert isinstance(getattr(login_session, "expires_at", None), datetime) + assert abs((login_session.expires_at - expected_login_session.expires_at).total_seconds()) < 2 + assert token == "newsessiontoken" + assert login_session.user_id == stored_user.id + mock_get.assert_called_once_with(db, user_in.username) + mock_verify.assert_called_once_with(user_in.password, stored_user.password_hash) + mock_token.assert_called_once() + mock_hash_session_token.assert_called_once_with("newsessiontoken") + mock_create_session.assert_called_once_with( + session=db, + user_id=stored_user.id, + session_token_hash="newsessiontokenhash", + session_length_seconds=settings.session_expiry_seconds, + ) -def test_authenticate_user_invalid_password_returns_none(): - pytest.fail("TODO: mock verify_password False") +def test_authenticate_user_not_found_returns_none() -> None: + user_in = dto.UserLogin(username="nonexistentuser", password="anypassword") + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch( + "trading_journal.crud.get_user_by_username", + return_value=None, + ) as mock_get, + ): + user_out = service.authenticate_user_service(db, user_in) + assert user_out is None + mock_get.assert_called_once_with(db, user_in.username) + + +def test_authenticate_user_invalid_password_returns_none() -> None: + user_in = dto.UserLogin(username="validuser", password="invalidpassword") + stored_user = SimpleNamespace(id=1, username=user_in.username, is_active=True, password_hash="hashedpassword") + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch( + "trading_journal.crud.get_user_by_username", + return_value=stored_user, + ) as mock_get, + patch("trading_journal.security.verify_password", return_value=False) as mock_verify, + ): + user_out = service.authenticate_user_service(db, user_in) + assert user_out is None + mock_get.assert_called_once_with(db, user_in.username) + mock_verify.assert_called_once_with(user_in.password, stored_user.password_hash) # --- Exchange services ------------------------------------------------------- -def test_create_exchange_duplicate_raises(): - pytest.fail("TODO: mock get_exchange_by_name_and_user_id and expect ExchangeAlreadyExistsError") +def test_create_exchange_duplicate_raises() -> None: + exchange_in = dto.ExchangesCreate(user_id=1, name="NYSE", notes="Test exchange") + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch( + "trading_journal.crud.get_exchange_by_name_and_user_id", + return_value=SimpleNamespace(id=1, user_id=1, name=exchange_in.name, notes="Existing exchange"), + ) as mock_get, + ): + with pytest.raises(service.ExchangeAlreadyExistsError) as exc_info: + service.create_exchange_service(db, user_id=exchange_in.user_id, name=exchange_in.name, notes=exchange_in.notes) + assert str(exc_info.value) == "Exchange with the same name already exists for this user" + mock_get.assert_called_once_with(db, exchange_in.name, exchange_in.user_id) -def test_update_exchange_not_found(): - pytest.fail("TODO: mock get_exchange_by_id None and expect ExchangeNotFoundError") +def test_create_exchange_success() -> None: + exchange_in = dto.ExchangesCreate(user_id=1, name="NASDAQ", notes="New exchange") + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch( + "trading_journal.crud.get_exchange_by_name_and_user_id", + return_value=None, + ) as mock_get, + patch( + "trading_journal.crud.create_exchange", + return_value=SimpleNamespace(id=2, user_id=exchange_in.user_id, name=exchange_in.name, notes=exchange_in.notes), + ) as mock_create, + ): + exchange_out = service.create_exchange_service(db, user_id=exchange_in.user_id, name=exchange_in.name, notes=exchange_in.notes) + assert exchange_out.name == exchange_in.name + assert exchange_out.notes == exchange_in.notes + mock_get.assert_called_once_with(db, exchange_in.name, exchange_in.user_id) + mock_create.assert_called_once_with(db, exchange_data=exchange_in) + + +def test_get_exchanges_by_user_id() -> None: + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch( + "trading_journal.crud.get_all_exchanges_by_user_id", + return_value=[ + SimpleNamespace(id=1, user_id=1, name="NYSE", notes="First exchange"), + SimpleNamespace(id=2, user_id=1, name="NASDAQ", notes="Second exchange"), + ], + ) as mock_get, + ): + exchanges = service.get_exchanges_by_user_service(db, user_id=1) + assert len(exchanges) == 2 + assert exchanges[0].name == "NYSE" + assert exchanges[1].name == "NASDAQ" + mock_get.assert_called_once_with(db, 1) + + +def test_get_exchanges_by_user_no_exchanges() -> None: + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch( + "trading_journal.crud.get_all_exchanges_by_user_id", + return_value=[], + ) as mock_get, + ): + exchanges = service.get_exchanges_by_user_service(db, user_id=1) + assert len(exchanges) == 0 + mock_get.assert_called_once_with(db, 1) + + +def test_update_exchange_not_found() -> None: + exchange_update = dto.ExchangesBase(name="UpdatedName", notes="Updated notes") + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch( + "trading_journal.crud.get_exchange_by_id", + return_value=None, + ) as mock_get, + ): + with pytest.raises(service.ExchangeNotFoundError) as exc_info: + service.update_exchanges_service(db, exchange_id=1, user_id=1, name=exchange_update.name, notes=exchange_update.notes) + assert str(exc_info.value) == "Exchange not found" + mock_get.assert_called_once_with(db, 1) + + +def test_update_exchange_owner_mismatch_raises() -> None: + exchange_update = dto.ExchangesBase(name="UpdatedName", notes="Updated notes") + existing_exchange = SimpleNamespace(id=1, user_id=2, name="OldName", notes="Old notes") + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch( + "trading_journal.crud.get_exchange_by_id", + return_value=existing_exchange, + ) as mock_get, + ): + with pytest.raises(service.ExchangeNotFoundError) as exc_info: + service.update_exchanges_service(db, exchange_id=1, user_id=1, name=exchange_update.name, notes=exchange_update.notes) + assert str(exc_info.value) == "Exchange not found" + mock_get.assert_called_once_with(db, 1) + + +def test_update_exchange_duplication() -> None: + exchange_update = dto.ExchangesBase(name="DuplicateName", notes="Updated notes") + existing_exchange = SimpleNamespace(id=1, user_id=1, name="OldName", notes="Old notes") + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch( + "trading_journal.crud.get_exchange_by_id", + return_value=existing_exchange, + ) as mock_get, + patch( + "trading_journal.crud.get_exchange_by_name_and_user_id", + return_value=SimpleNamespace(id=2, user_id=1, name="DuplicateName", notes="Another exchange"), + ) as mock_get_by_name, + ): + with pytest.raises(service.ExchangeAlreadyExistsError) as exc_info: + service.update_exchanges_service(db, exchange_id=1, user_id=1, name=exchange_update.name, notes=exchange_update.notes) + assert str(exc_info.value) == "Another exchange with the same name already exists for this user" + mock_get.assert_called_once_with(db, 1) + mock_get_by_name.assert_called_once_with(db, "DuplicateName", 1) + + +def test_update_exchange_success() -> None: + exchange_update = dto.ExchangesBase(name="UpdatedName", notes="Updated notes") + existing_exchange = SimpleNamespace(id=1, user_id=1, name="OldName", notes="Old notes") + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch( + "trading_journal.crud.get_exchange_by_id", + return_value=existing_exchange, + ) as mock_get, + patch( + "trading_journal.crud.get_exchange_by_name_and_user_id", + return_value=None, + ) as mock_get_by_name, + patch( + "trading_journal.crud.update_exchange", + return_value=SimpleNamespace(id=1, user_id=1, name=exchange_update.name, notes=exchange_update.notes), + ) as mock_update, + ): + exchange_out = service.update_exchanges_service(db, exchange_id=1, user_id=1, name=exchange_update.name, notes=exchange_update.notes) + assert exchange_out.name == exchange_update.name + assert exchange_out.notes == exchange_update.notes + mock_get.assert_called_once_with(db, 1) + mock_get_by_name.assert_called_once_with(db, "UpdatedName", 1) + mock_update.assert_called_once_with(db, 1, update_data=exchange_update) # --- Cycle services ---------------------------------------------------------- -def test_validate_cycle_update_rules(): - pytest.fail("TODO: call _validate_cycle_update_data with invalid combos") +def test_get_cycle_by_id_not_found_raises() -> None: + user_id = 1 + cycle_id = 1 + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_cycle_by_id", return_value=None) as mock_get, + ): + with pytest.raises(service.CycleNotFoundError) as exc_info: + service.get_cycle_by_id_service(db, user_id=user_id, cycle_id=cycle_id) + assert str(exc_info.value) == "Cycle not found" + mock_get.assert_called_once_with(db, cycle_id) -def test_update_cycle_owner_mismatch_raises(): - pytest.fail("TODO: mock get_cycle_by_id owned by other user, expect CycleNotFoundError") +def test_get_cycle_by_id_owner_mismatch_raises() -> None: + user_id = 1 + cycle_id = 1 + cycle = SimpleNamespace(id=cycle_id, user_id=2) # Owned by different user + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_cycle_by_id", return_value=cycle) as mock_get, + ): + with pytest.raises(service.CycleNotFoundError) as exc_info: + service.get_cycle_by_id_service(db, user_id=user_id, cycle_id=cycle_id) + assert str(exc_info.value) == "Cycle not found" + mock_get.assert_called_once_with(db, cycle_id) + + +def test_get_cycle_by_id_success() -> None: + user_id = 1 + cycle_id = 1 + cycle = SimpleNamespace( + id=cycle_id, + friendly_name="Test Cycle", + status="OPEN", + funding_source="MIXED", + user_id=user_id, + symbol="AAPL", + exchange_id=1, + underlying_currency="USD", + start_date=datetime.now(timezone.utc).date(), + trades=[], + exchange=None, + ) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_cycle_by_id", return_value=cycle) as mock_get, + ): + cycle_out = service.get_cycle_by_id_service(db, user_id=user_id, cycle_id=cycle_id) + assert cycle_out.id == cycle_id + assert cycle_out.user_id == user_id + assert cycle_out.friendly_name == "Test Cycle" + assert cycle_out.status == "OPEN" + assert cycle_out.funding_source == "MIXED" + assert cycle_out.symbol == "AAPL" + assert cycle_out.exchange_id == 1 + assert cycle_out.underlying_currency == "USD" + assert cycle_out.trades == [] + mock_get.assert_called_once_with(db, cycle_id) + + +def test_get_cycles_by_user_no_cycles() -> None: + user_id = 1 + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_cycles_by_user_id", return_value=[]) as mock_get, + ): + cycles = service.get_cycles_by_user_service(db, user_id=user_id) + assert isinstance(cycles, list) + assert len(cycles) == 0 + mock_get.assert_called_once_with(db, user_id) + + +def test_get_cycles_by_user_with_cycles() -> None: + user_id = 1 + cycle1 = SimpleNamespace( + id=1, + friendly_name="Cycle 1", + status="OPEN", + funding_source="MIXED", + user_id=user_id, + symbol="AAPL", + exchange_id=1, + underlying_currency="USD", + start_date=datetime.now(timezone.utc).date(), + trades=[], + exchange=None, + ) + cycle2 = SimpleNamespace( + id=2, + friendly_name="Cycle 2", + status="CLOSED", + funding_source="LOAN", + user_id=user_id, + symbol="TSLA", + exchange_id=2, + underlying_currency="USD", + start_date=datetime.now(timezone.utc).date() - timedelta(days=30), + trades=[], + exchange=None, + ) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_cycles_by_user_id", return_value=[cycle1, cycle2]) as mock_get, + ): + cycles = service.get_cycles_by_user_service(db, user_id=user_id) + assert isinstance(cycles, list) + assert len(cycles) == 2 + assert cycles[0].id == 1 + assert cycles[0].friendly_name == "Cycle 1" + assert cycles[1].id == 2 + assert cycles[1].friendly_name == "Cycle 2" + mock_get.assert_called_once_with(db, user_id) + + +def test_update_cycle_closed_status_mismatch_raises() -> None: + cycle_data = dto.CycleUpdate(id=1, friendly_name="Updated Cycle", status="CLOSED") + with ( + FakeDBFactory().get_session_ctx_manager() as db, + ): + with pytest.raises(service.InvalidCycleDataError) as exc_info: + service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) + assert str(exc_info.value) == "end_date is required when status is CLOSED" + + +def test_update_cycle_open_status_mismatch_raises() -> None: + cycle_data = dto.CycleUpdate(id=1, friendly_name="Updated Cycle", status="OPEN", end_date=datetime.now(timezone.utc).date()) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + ): + with pytest.raises(service.InvalidCycleDataError) as exc_info: + service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) + assert str(exc_info.value) == "end_date must be empty when status is OPEN" + + +def test_update_cycle_invalid_capital_exposure_raises() -> None: + cycle_data = dto.CycleUpdate(id=1, friendly_name="Updated Cycle", status="OPEN", capital_exposure_cents=-100) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + ): + with pytest.raises(service.InvalidCycleDataError) as exc_info: + service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) + assert str(exc_info.value) == "capital_exposure_cents must be non-negative" + + +def test_update_cycle_no_cash_no_loan_raises() -> None: + cycle_data = dto.CycleUpdate(id=1, friendly_name="Updated Cycle", status="OPEN", funding_source="LOAN", loan_amount_cents=None) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + ): + with pytest.raises(service.InvalidCycleDataError) as exc_info: + service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) + assert str(exc_info.value) == "loan_amount_cents and loan_interest_rate_tenth_bps are required when funding_source is not CASH" + + +def test_update_cycle_loan_missing_interest_raises() -> None: + cycle_data = dto.CycleUpdate( + id=1, + friendly_name="Updated Cycle", + status="OPEN", + funding_source="LOAN", + loan_amount_cents=10000, + ) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + ): + with pytest.raises(service.InvalidCycleDataError) as exc_info: + service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) + assert str(exc_info.value) == "loan_amount_cents and loan_interest_rate_tenth_bps are required when funding_source is not CASH" + + +def test_update_cycle_loan_negative_loan_raises() -> None: + cycle_data = dto.CycleUpdate( + id=1, + friendly_name="Updated Cycle", + status="OPEN", + funding_source="LOAN", + loan_amount_cents=-10000, + loan_interest_rate_tenth_bps=50, + ) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + ): + with pytest.raises(service.InvalidCycleDataError) as exc_info: + service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) + assert str(exc_info.value) == "loan_amount_cents must be non-negative" + + +def test_update_cycle_loan_negative_interest_raises() -> None: + cycle_data = dto.CycleUpdate( + id=1, + friendly_name="Updated Cycle", + status="OPEN", + funding_source="LOAN", + loan_amount_cents=10000, + loan_interest_rate_tenth_bps=-50, + ) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + ): + with pytest.raises(service.InvalidCycleDataError) as exc_info: + service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) + assert str(exc_info.value) == "loan_interest_rate_tenth_bps must be non-negative" + + +def test_update_cycle_not_found_raises() -> None: + cycle_data = dto.CycleUpdate(id=1, friendly_name="Updated Cycle", status="OPEN") + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_cycle_by_id", return_value=None) as mock_get, + ): + with pytest.raises(service.CycleNotFoundError) as exc_info: + service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) + assert str(exc_info.value) == "Cycle not found" + mock_get.assert_called_once_with(db, cycle_data.id) + + +def test_update_cycle_owner_mismatch_raises() -> None: + cycle_data = dto.CycleUpdate(id=1, friendly_name="Updated Cycle", status="OPEN") + existing_cycle = SimpleNamespace(id=1, user_id=2) # Owned by different user + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_cycle_by_id", return_value=existing_cycle) as mock_get, + ): + with pytest.raises(service.CycleNotFoundError) as exc_info: + service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) + assert str(exc_info.value) == "Cycle not found" + mock_get.assert_called_once_with(db, cycle_data.id) + + +def test_update_cycle_success() -> None: + cycle_data = dto.CycleUpdate(id=1, friendly_name="Updated Cycle", status="OPEN", funding_source="CASH", capital_exposure_cents=5000) + existing_cycle = SimpleNamespace( + id=1, + user_id=1, + friendly_name="Old Cycle", + symbol="AAPL", + exchange_id=1, + underlying_currency="USD", + start_date=datetime.now(timezone.utc).date(), + status="OPEN", + funding_source="MIXED", + capital_exposure_cents=10000, + loan_amount_cents=2000, + loan_interest_rate_tenth_bps=50, + ) + updated_cycle = SimpleNamespace( + id=1, + user_id=1, + symbol="AAPL", + exchange_id=1, + underlying_currency="USD", + start_date=existing_cycle.start_date, + friendly_name="Updated Cycle", + status=cycle_data.status, + funding_source=cycle_data.funding_source, + capital_exposure_cents=cycle_data.capital_exposure_cents, + loan_amount_cents=None, + loan_interest_rate_tenth_bps=None, + ) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_cycle_by_id", return_value=existing_cycle) as mock_get, + patch("trading_journal.crud.update_cycle", return_value=updated_cycle) as mock_update, + ): + cycle_out = service.update_cycle_service(db, user_id=1, cycle_data=cycle_data) + assert cycle_out.id == updated_cycle.id + assert cycle_out.friendly_name == updated_cycle.friendly_name + assert cycle_out.status == updated_cycle.status + assert cycle_out.funding_source == updated_cycle.funding_source + assert cycle_out.capital_exposure_cents == updated_cycle.capital_exposure_cents + assert cycle_out.loan_amount_cents is None + assert cycle_out.loan_interest_rate_tenth_bps is None + mock_get.assert_called_once_with(db, cycle_data.id) + update_cycle_base = dto.CycleBase( + friendly_name=cycle_data.friendly_name, + status=cycle_data.status, + funding_source=cycle_data.funding_source, + capital_exposure_cents=cycle_data.capital_exposure_cents, + loan_amount_cents=getattr(cycle_data, "loan_amount_cents", None), + loan_interest_rate_tenth_bps=getattr(cycle_data, "loan_interest_rate_tenth_bps", None), + end_date=getattr(cycle_data, "end_date", None), + ) + mock_update.assert_called_once_with(db, cycle_data.id, update_data=update_cycle_base) # --- Trade services ---------------------------------------------------------- -def test_create_trade_invalid_sell_requires_expiry(): - pytest.fail("TODO: build SELL_PUT without expiry/strike, expect InvalidTradeDataError") +def test_create_trade_short_option_no_strike() -> None: + trade_data = dto.TradeCreate( + user_id=1, + symbol="AAPL", + exchange_id=1, + underlying_currency=dto.UnderlyingCurrency.USD, + trade_type=dto.TradeType.SELL_PUT, + trade_strategy=dto.TradeStrategy.WHEEL, + trade_date=datetime.now(timezone.utc).date(), + quantity=-1, + price_cents=5000, + commission_cents=100, + cycle_id=1, + friendly_name="Short Call", + notes="Test trade", + quantity_multiplier=100, + expiry_date=datetime.now(timezone.utc).date() + timedelta(days=30), + strike_price_cents=None, # Missing strike price + ) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + ): + with pytest.raises(service.InvalidTradeDataError) as exc_info: + service.create_trade_service(db, 1, trade_data) + assert str(exc_info.value) == "Invalid trade data: expiry_date and strike_price_cents are required for SELL_PUT and SELL_CALL trades" -def test_create_trade_appends_cashflow_and_calls_crud(): - pytest.fail("TODO: mock crud.create_trade, assert net_cash_flow_cents in result") +def test_create_trade_success() -> None: + trade_data = dto.TradeCreate( + user_id=1, + symbol="AAPL", + exchange_id=1, + underlying_currency=dto.UnderlyingCurrency.USD, + trade_type=dto.TradeType.SELL_PUT, + trade_strategy=dto.TradeStrategy.WHEEL, + trade_date=datetime.now(timezone.utc).date(), + strike_price_cents=15000, + expiry_date=datetime.now(timezone.utc).date() + timedelta(days=30), + quantity=1, + price_cents=5000, + commission_cents=100, + cycle_id=1, + friendly_name="Sell put", + notes="Test trade", + quantity_multiplier=1, + ) + created_trade = SimpleNamespace( + id=1, + user_id=trade_data.user_id, + symbol=trade_data.symbol, + exchange_id=trade_data.exchange_id, + underlying_currency=trade_data.underlying_currency, + trade_type=trade_data.trade_type, + trade_strategy=trade_data.trade_strategy, + trade_date=trade_data.trade_date, + quantity=trade_data.quantity, + price_cents=trade_data.price_cents, + commission_cents=trade_data.commission_cents, + cycle_id=trade_data.cycle_id, + friendly_name=trade_data.friendly_name, + notes=trade_data.notes, + quantity_multiplier=trade_data.quantity_multiplier, + expiry_date=trade_data.expiry_date, + strike_price_cents=trade_data.strike_price_cents, + ) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.create_trade", return_value=created_trade) as mock_create_trade, + ): + trade_out = service.create_trade_service(db, user_id=1, trade_data=trade_data) + assert trade_out.id == created_trade.id + assert trade_out.user_id == created_trade.user_id + assert trade_out.symbol == created_trade.symbol + assert trade_out.trade_type == created_trade.trade_type + mock_create_trade.assert_called_once() + _, kwargs = mock_create_trade.call_args + passed_trade = kwargs.get("trade_data") or (mock_create_trade.call_args[0][1] if len(mock_create_trade.call_args[0]) > 1 else None) + assert passed_trade is not None + # expected for SELL_PUT: gross = quantity * price * quantity_multiplier (positive), net = gross - commission + expected_gross = trade_data.quantity * trade_data.price_cents * (trade_data.quantity_multiplier or 1) + expected_net = expected_gross - trade_data.commission_cents + assert getattr(passed_trade, "gross_cash_flow_cents", None) == expected_gross + assert getattr(passed_trade, "net_cash_flow_cents", None) == expected_net -def test_get_trade_by_id_missing_raises(): - pytest.fail("TODO: mock get_trade_by_id None, expect TradeNotFoundError") +def test_get_trade_by_id_not_found_when_missing() -> None: + with FakeDBFactory().get_session_ctx_manager() as db, patch("trading_journal.crud.get_trade_by_id", return_value=None) as mock_get: + with pytest.raises(service.TradeNotFoundError) as exc_info: + service.get_trade_by_id_service(db, user_id=1, trade_id=1) + assert str(exc_info.value) == "Trade not found" + mock_get.assert_called_once_with(db, 1) -def test_update_trade_friendly_name_not_found(): - pytest.fail("TODO: mock get_trade_by_id None, expect TradeNotFoundError") +def test_get_trade_by_id_not_found_owner_mismatch() -> None: + existing_trade = SimpleNamespace(id=2, user_id=2) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_trade_by_id", return_value=existing_trade) as mock_get, + ): + with pytest.raises(service.TradeNotFoundError) as exc_info: + service.get_trade_by_id_service(db, user_id=1, trade_id=2) + assert str(exc_info.value) == "Trade not found" + mock_get.assert_called_once_with(db, 2) -def test_update_trade_note_sets_empty_string_when_none(): - pytest.fail("TODO: mock update_trade_note to return note '', assert DTO note") +def test_get_trade_by_id_success() -> None: + # build a trade-like object compatible with dto.TradeRead/model_validate + trade_obj = SimpleNamespace( + id=10, + user_id=1, + friendly_name="Test Trade", + symbol="AAPL", + exchange_id=1, + underlying_currency=dto.UnderlyingCurrency.USD, + trade_type=dto.TradeType.LONG_SPOT, + trade_strategy=dto.TradeStrategy.SPOT, + trade_date=datetime.now(timezone.utc).date(), + trade_time_utc=None, + expiry_date=None, + strike_price_cents=None, + quantity=1, + quantity_multiplier=1, + price_cents=1000, + gross_cash_flow_cents=-1000, + commission_cents=10, + net_cash_flow_cents=-1010, + is_invalidated=False, + invalidated_at=None, + replaced_by_trade_id=None, + notes="ok", + cycle_id=None, + ) + with FakeDBFactory().get_session_ctx_manager() as db, patch("trading_journal.crud.get_trade_by_id", return_value=trade_obj) as mock_get: + res = service.get_trade_by_id_service(db, user_id=1, trade_id=10) + assert res.id == trade_obj.id + assert res.user_id == trade_obj.user_id + assert res.symbol == trade_obj.symbol + assert res.trade_type == trade_obj.trade_type + mock_get.assert_called_once_with(db, 10) + + +def test_update_trade_friendly_name_not_found() -> None: + with FakeDBFactory().get_session_ctx_manager() as db, patch("trading_journal.crud.get_trade_by_id", return_value=None) as mock_get: + with pytest.raises(service.TradeNotFoundError) as exc_info: + service.update_trade_friendly_name_service(db, user_id=1, trade_id=10, friendly_name="New Name") + assert str(exc_info.value) == "Trade not found" + mock_get.assert_called_once_with(db, 10) + + +def test_update_trade_friendly_name_owner_mismatch_raises() -> None: + existing_trade = SimpleNamespace(id=10, user_id=2) # owned by another user + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_trade_by_id", return_value=existing_trade) as mock_get, + ): + with pytest.raises(service.TradeNotFoundError) as exc_info: + service.update_trade_friendly_name_service(db, user_id=1, trade_id=10, friendly_name="New Name") + assert str(exc_info.value) == "Trade not found" + mock_get.assert_called_once_with(db, 10) + + +def test_update_trade_friendly_name_success() -> None: + existing_trade = SimpleNamespace( + id=10, + user_id=1, + friendly_name="Old Name", + symbol="AAPL", + exchange_id=1, + underlying_currency=dto.UnderlyingCurrency.USD, + trade_type=dto.TradeType.LONG_SPOT, + trade_strategy=dto.TradeStrategy.SPOT, + trade_date=datetime.now(timezone.utc).date(), + trade_time_utc=None, + expiry_date=None, + strike_price_cents=None, + quantity=1, + quantity_multiplier=1, + price_cents=1000, + gross_cash_flow_cents=-1000, + commission_cents=10, + net_cash_flow_cents=-1010, + is_invalidated=False, + invalidated_at=None, + replaced_by_trade_id=None, + notes="ok", + cycle_id=None, + ) + updated_trade = SimpleNamespace(**{**existing_trade.__dict__, "friendly_name": "New Friendly"}) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_trade_by_id", return_value=existing_trade) as mock_get, + patch("trading_journal.crud.update_trade_friendly_name", return_value=updated_trade) as mock_update, + ): + res = service.update_trade_friendly_name_service(db, user_id=1, trade_id=10, friendly_name="New Friendly") + assert res.friendly_name == "New Friendly" + mock_get.assert_called_once_with(db, 10) + mock_update.assert_called_once_with(db, 10, "New Friendly") + + +def test_update_trade_note_not_found() -> None: + with FakeDBFactory().get_session_ctx_manager() as db, patch("trading_journal.crud.get_trade_by_id", return_value=None) as mock_get: + with pytest.raises(service.TradeNotFoundError) as exc_info: + service.update_trade_note_service(db, user_id=1, trade_id=20, note="x") + assert str(exc_info.value) == "Trade not found" + mock_get.assert_called_once_with(db, 20) + + +def test_update_trade_note_owner_mismatch_raises() -> None: + existing_trade = SimpleNamespace(id=20, user_id=2) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_trade_by_id", return_value=existing_trade) as mock_get, + ): + with pytest.raises(service.TradeNotFoundError) as exc_info: + service.update_trade_note_service(db, user_id=1, trade_id=20, note="x") + assert str(exc_info.value) == "Trade not found" + mock_get.assert_called_once_with(db, 20) + + +def test_update_trade_note_success_and_none_becomes_empty() -> None: + existing_trade = SimpleNamespace( + id=20, + user_id=1, + friendly_name="Trade", + symbol="AAPL", + exchange_id=1, + underlying_currency=dto.UnderlyingCurrency.USD, + trade_type=dto.TradeType.LONG_SPOT, + trade_strategy=dto.TradeStrategy.SPOT, + trade_date=datetime.now(timezone.utc).date(), + trade_time_utc=None, + expiry_date=None, + strike_price_cents=None, + quantity=1, + quantity_multiplier=1, + price_cents=1000, + gross_cash_flow_cents=-1000, + commission_cents=10, + net_cash_flow_cents=-1010, + is_invalidated=False, + invalidated_at=None, + replaced_by_trade_id=None, + notes="old", + cycle_id=None, + ) + updated_trade = SimpleNamespace(**{**existing_trade.__dict__, "notes": ""}) + with ( + FakeDBFactory().get_session_ctx_manager() as db, + patch("trading_journal.crud.get_trade_by_id", return_value=existing_trade) as mock_get, + patch("trading_journal.crud.update_trade_note", return_value=updated_trade) as mock_update, + ): + res = service.update_trade_note_service(db, user_id=1, trade_id=20, note=None) + assert res.notes == "" + mock_get.assert_called_once_with(db, 20) + mock_update.assert_called_once_with(db, 20, "") diff --git a/backend/trading_journal/service.py b/backend/trading_journal/service.py index 4855c3e..79d487e 100644 --- a/backend/trading_journal/service.py +++ b/backend/trading_journal/service.py @@ -229,6 +229,7 @@ def update_exchanges_service(db_session: Session, user_id: int, exchange_id: int # Cycle Service def create_cycle_service(db_session: Session, user_id: int, cycle_data: CycleBase) -> CycleRead: + raise NotImplementedError("Cycle creation not implemented") cycle_data_dict = cycle_data.model_dump() cycle_data_dict["user_id"] = user_id cycle_data_with_user_id: CycleCreate = CycleCreate.model_validate(cycle_data_dict) -- 2.49.1 From 94fb4705ff5a25f7015ed0b76074fefc43cac29e Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Wed, 1 Oct 2025 15:53:48 +0200 Subject: [PATCH 18/18] add tests for router and openapi, still need to add routes for update interest --- backend/app.py | 12 +- backend/openapi.yaml | 554 ++++++++++++++++++++++++++++++++++++++ backend/tests/test_app.py | 404 ++++++++++++++++++++++++++- 3 files changed, 958 insertions(+), 12 deletions(-) create mode 100644 backend/openapi.yaml diff --git a/backend/app.py b/backend/app.py index e0e2799..2716a82 100644 --- a/backend/app.py +++ b/backend/app.py @@ -52,8 +52,14 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001 await asyncio.to_thread(_db.dispose) +origins = [ + "http://127.0.0.1:18881", +] + app = FastAPI(lifespan=lifespan) -app.add_middleware(service.AuthMiddleWare) +app.add_middleware( + service.AuthMiddleWare, +) app.state.db_factory = _db @@ -77,7 +83,7 @@ async def register_user(request: Request, user_in: UserCreate) -> Response: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e except Exception as e: logger.exception("Failed to register user: \n") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal Server Error") from e @app.post(f"{settings.settings.api_base}/login") @@ -110,7 +116,7 @@ async def login(request: Request, user_in: UserLogin) -> Response: ) except Exception as e: logger.exception("Failed to login user: \n") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") from e + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal Server Error") from e else: return response diff --git a/backend/openapi.yaml b/backend/openapi.yaml new file mode 100644 index 0000000..f907c4e --- /dev/null +++ b/backend/openapi.yaml @@ -0,0 +1,554 @@ +openapi: "3.0.3" +info: + title: Trading Journal API + version: "1.0.0" + description: OpenAPI description generated from [`app.py`](app.py) and DTOs in [`trading_journal/dto.py`](trading_journal/dto.py). +servers: + - url: "http://127.0.0.1:18881{basePath}" + variables: + basePath: + default: "/api/v1" + description: "API base path (matches settings.settings.api_base)" +components: + securitySchemes: + session_cookie: + type: apiKey + in: cookie + name: session_token + schemas: + UserCreate: + $ref: "#/components/schemas/UserCreate_impl" + UserCreate_impl: + type: object + required: + - username + - password + properties: + username: + type: string + is_active: + type: boolean + default: true + password: + type: string + UserLogin: + type: object + required: + - username + - password + properties: + username: + type: string + password: + type: string + UserRead: + type: object + required: + - id + - username + properties: + id: + type: integer + username: + type: string + is_active: + type: boolean + SessionsBase: + type: object + required: + - user_id + properties: + user_id: + type: integer + SessionsCreate: + allOf: + - $ref: "#/components/schemas/SessionsBase" + - type: object + required: + - expires_at + properties: + expires_at: + type: string + format: date-time + ExchangesBase: + type: object + required: + - name + properties: + name: + type: string + notes: + type: string + nullable: true + ExchangesRead: + allOf: + - $ref: "#/components/schemas/ExchangesBase" + - type: object + required: + - id + properties: + id: + type: integer + CycleBase: + type: object + properties: + friendly_name: + type: string + nullable: true + status: + type: string + end_date: + type: string + format: date + nullable: true + funding_source: + type: string + nullable: true + capital_exposure_cents: + type: integer + nullable: true + loan_amount_cents: + type: integer + nullable: true + loan_interest_rate_tenth_bps: + type: integer + nullable: true + trades: + type: array + items: + $ref: "#/components/schemas/TradeRead" + nullable: true + exchange: + $ref: "#/components/schemas/ExchangesRead" + nullable: true + CycleCreate: + allOf: + - $ref: "#/components/schemas/CycleBase" + - type: object + required: + - user_id + - symbol + - exchange_id + - underlying_currency + - start_date + properties: + user_id: + type: integer + symbol: + type: string + exchange_id: + type: integer + underlying_currency: + type: string + start_date: + type: string + format: date + CycleUpdate: + allOf: + - $ref: "#/components/schemas/CycleBase" + - type: object + required: + - id + properties: + id: + type: integer + CycleRead: + allOf: + - $ref: "#/components/schemas/CycleCreate" + - type: object + required: + - id + properties: + id: + type: integer + TradeBase: + type: object + required: + - symbol + - underlying_currency + - trade_type + - trade_strategy + - trade_date + - quantity + - price_cents + - commission_cents + properties: + friendly_name: + type: string + nullable: true + symbol: + type: string + exchange_id: + type: integer + underlying_currency: + type: string + trade_type: + type: string + trade_strategy: + type: string + trade_date: + type: string + format: date + quantity: + type: integer + price_cents: + type: integer + commission_cents: + type: integer + notes: + type: string + nullable: true + cycle_id: + type: integer + nullable: true + TradeCreate: + allOf: + - $ref: "#/components/schemas/TradeBase" + - type: object + properties: + user_id: + type: integer + nullable: true + trade_time_utc: + type: string + format: date-time + nullable: true + gross_cash_flow_cents: + type: integer + nullable: true + net_cash_flow_cents: + type: integer + nullable: true + quantity_multiplier: + type: integer + default: 1 + expiry_date: + type: string + format: date + nullable: true + strike_price_cents: + type: integer + nullable: true + is_invalidated: + type: boolean + default: false + invalidated_at: + type: string + format: date-time + nullable: true + replaced_by_trade_id: + type: integer + nullable: true + TradeNoteUpdate: + type: object + required: + - id + properties: + id: + type: integer + notes: + type: string + nullable: true + TradeFriendlyNameUpdate: + type: object + required: + - id + - friendly_name + properties: + id: + type: integer + friendly_name: + type: string + TradeRead: + allOf: + - $ref: "#/components/schemas/TradeCreate" + - type: object + required: + - id + properties: + id: + type: integer +paths: + /status: + get: + summary: "Get API status" + security: [] # no auth required + responses: + "200": + description: OK + content: + application/json: + schema: + type: object + properties: + status: + type: string + /register: + post: + summary: "Register user" + security: [] # no auth required + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/UserCreate" + responses: + "201": + description: Created + content: + application/json: + schema: + $ref: "#/components/schemas/UserRead" + "400": + description: Bad Request (user exists) + "500": + description: Internal Server Error + /login: + post: + summary: "Login" + security: [] # no auth required + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/UserLogin" + responses: + "200": + description: OK (sets session cookie) + content: + application/json: + schema: + $ref: "#/components/schemas/SessionsBase" + headers: + Set-Cookie: + description: session cookie + schema: + type: string + "401": + description: Unauthorized + "500": + description: Internal Server Error + /exchanges: + post: + summary: "Create exchange" + security: + - session_cookie: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ExchangesBase" + responses: + "201": + description: Created + content: + application/json: + schema: + $ref: "#/components/schemas/ExchangesRead" + "400": + description: Bad Request + "401": + description: Unauthorized + get: + summary: "List user exchanges" + security: + - session_cookie: [] + responses: + "200": + description: OK + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/ExchangesRead" + "401": + description: Unauthorized + /exchanges/{exchange_id}: + patch: + summary: "Update exchange" + security: + - session_cookie: [] + parameters: + - name: exchange_id + in: path + required: true + schema: + type: integer + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ExchangesBase" + responses: + "200": + description: Updated + content: + application/json: + schema: + $ref: "#/components/schemas/ExchangesRead" + "404": + description: Not found + "400": + description: Bad request + /cycles: + post: + summary: "Create cycle (currently returns 405 in code)" + security: + - session_cookie: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CycleBase" + responses: + "405": + description: Method not allowed (app currently returns 405) + patch: + summary: "Update cycle" + security: + - session_cookie: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CycleUpdate" + responses: + "200": + description: Updated + content: + application/json: + schema: + $ref: "#/components/schemas/CycleRead" + "400": + description: Invalid data + "404": + description: Not found + /cycles/{cycle_id}: + get: + summary: "Get cycle by id" + security: + - session_cookie: [] + parameters: + - name: cycle_id + in: path + required: true + schema: + type: integer + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/CycleRead" + "404": + description: Not found + /cycles/user/{user_id}: + get: + summary: "Get cycles by user id" + security: + - session_cookie: [] + parameters: + - name: user_id + in: path + required: true + schema: + type: integer + responses: + "200": + description: OK + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/CycleRead" + /trades: + post: + summary: "Create trade" + security: + - session_cookie: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/TradeCreate" + responses: + "201": + description: Created + content: + application/json: + schema: + $ref: "#/components/schemas/TradeRead" + "400": + description: Invalid trade data + "500": + description: Internal Server Error + /trades/{trade_id}: + get: + summary: "Get trade by id" + security: + - session_cookie: [] + parameters: + - name: trade_id + in: path + required: true + schema: + type: integer + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/TradeRead" + "404": + description: Not found + /trades/friendlyname: + patch: + summary: "Update trade friendly name" + security: + - session_cookie: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/TradeFriendlyNameUpdate" + responses: + "200": + description: Updated + content: + application/json: + schema: + $ref: "#/components/schemas/TradeRead" + "404": + description: Not found + /trades/notes: + patch: + summary: "Update trade notes" + security: + - session_cookie: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/TradeNoteUpdate" + responses: + "200": + description: Updated + content: + application/json: + schema: + $ref: "#/components/schemas/TradeRead" + "404": + description: Not found diff --git a/backend/tests/test_app.py b/backend/tests/test_app.py index 78cf8ad..cbc8553 100644 --- a/backend/tests/test_app.py +++ b/backend/tests/test_app.py @@ -1,19 +1,405 @@ -from collections.abc import Generator +from collections.abc import Callable +from datetime import datetime, timedelta, timezone +from types import SimpleNamespace +from unittest.mock import MagicMock import pytest +from fastapi import FastAPI, status +from fastapi.responses import JSONResponse from fastapi.testclient import TestClient import settings -from app import app +import trading_journal.service as svc @pytest.fixture -def client() -> Generator[TestClient, None, None]: - with TestClient(app) as client: - yield client +def client_factory(monkeypatch: pytest.MonkeyPatch) -> Callable[..., TestClient]: + class NoAuth: + def __init__(self, app: FastAPI, **opts) -> None: # noqa: ANN003, ARG002 + self.app = app + + async def __call__(self, scope, receive, send) -> None: # noqa: ANN001 + state = scope.get("state") + if state is None: + scope["state"] = SimpleNamespace() + scope["state"]["user_id"] = 1 + await self.app(scope, receive, send) + + class DeclineAuth: + def __init__(self, app: FastAPI, **opts) -> None: # noqa: ANN003, ARG002 + self.app = app + + async def __call__(self, scope, receive, send) -> None: # noqa: ANN001 + if scope.get("type") != "http": + await self.app(scope, receive, send) + return + path = scope.get("path", "") + # allow public/exempt paths through + if getattr(svc, "EXCEPT_PATHS", []) and path in svc.EXCEPT_PATHS: + await self.app(scope, receive, send) + return + # immediately respond 401 for protected paths + resp = JSONResponse({"detail": "Unauthorized"}, status_code=status.HTTP_401_UNAUTHORIZED) + await resp(scope, receive, send) + + def _factory(*, decline_auth: bool = False, **mocks: dict) -> TestClient: + defaults = { + "register_user_service": MagicMock(return_value=SimpleNamespace(model_dump=lambda: {"id": 1, "username": "mock"})), + "authenticate_user_service": MagicMock( + return_value=(SimpleNamespace(user_id=1, expires_at=(datetime.now(timezone.utc) + timedelta(hours=1))), "token"), + ), + "create_exchange_service": MagicMock( + return_value=SimpleNamespace(model_dump=lambda: {"name": "Binance", "notes": "some note", "user_id": 1}), + ), + "get_exchanges_by_user_service": MagicMock(return_value=[]), + } + + if decline_auth: + monkeypatch.setattr(svc, "AuthMiddleWare", DeclineAuth) + else: + monkeypatch.setattr(svc, "AuthMiddleWare", NoAuth) + merged = {**defaults, **mocks} + for name, mock in merged.items(): + monkeypatch.setattr(svc, name, mock) + import sys + + if "app" in sys.modules: + del sys.modules["app"] + from importlib import import_module + + app = import_module("app").app # re-import app module + + return TestClient(app) + + return _factory -def test_get_status(client: TestClient) -> None: - response = client.get(f"{settings.settings.api_base}/status") - assert response.status_code == 200 - assert response.json() == {"status": "ok"} +def test_get_status(client_factory: Callable[..., TestClient]) -> None: + client = client_factory() + with client as c: + response = c.get(f"{settings.settings.api_base}/status") + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + +def test_register_success(client_factory: Callable[..., TestClient]) -> None: + client = client_factory() # use defaults + with client as c: + r = c.post(f"{settings.settings.api_base}/register", json={"username": "a", "password": "b"}) + assert r.status_code == 201 + + +def test_register_user_already_exists(client_factory: Callable[..., TestClient]) -> None: + client = client_factory(register_user_service=MagicMock(side_effect=svc.UserAlreadyExistsError("username already exists"))) + with client as c: + r = c.post(f"{settings.settings.api_base}/register", json={"username": "a", "password": "b"}) + assert r.status_code == status.HTTP_400_BAD_REQUEST + assert r.json() == {"detail": "username already exists"} + + +def test_register_user_internal_server_error(client_factory: Callable[..., TestClient]) -> None: + client = client_factory(register_user_service=MagicMock(side_effect=Exception("db is down"))) + with client as c: + r = c.post(f"{settings.settings.api_base}/register", json={"username": "a", "password": "b"}) + assert r.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert r.json() == {"detail": "Internal Server Error"} + + +def test_login_success(client_factory: Callable[..., TestClient]) -> None: + client = client_factory() # use defaults + with client as c: + r = c.post(f"{settings.settings.api_base}/login", json={"username": "a", "password": "b"}) + assert r.status_code == 200 + assert r.json() == {"user_id": 1} + assert r.cookies.get("session_token") == "token" + + +def test_login_failed_auth(client_factory: Callable[..., TestClient]) -> None: + client = client_factory(authenticate_user_service=MagicMock(return_value=None)) + with client as c: + r = c.post(f"{settings.settings.api_base}/login", json={"username": "a", "password": "b"}) + assert r.status_code == status.HTTP_401_UNAUTHORIZED + assert r.json() == {"detail": "Invalid username or password, or user doesn't exist"} + + +def test_login_internal_server_error(client_factory: Callable[..., TestClient]) -> None: + client = client_factory(authenticate_user_service=MagicMock(side_effect=Exception("db is down"))) + with client as c: + r = c.post(f"{settings.settings.api_base}/login", json={"username": "a", "password": "b"}) + assert r.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert r.json() == {"detail": "Internal Server Error"} + + +def test_create_exchange_success(client_factory: Callable[..., TestClient]) -> None: + client = client_factory() + with client as c: + r = c.post(f"{settings.settings.api_base}/exchanges", json={"name": "Binance"}) + assert r.status_code == 201 + assert r.json() == {"user_id": 1, "name": "Binance", "notes": "some note"} + + +def test_create_exchange_already_exists(client_factory: Callable[..., TestClient]) -> None: + client = client_factory(create_exchange_service=MagicMock(side_effect=svc.ExchangeAlreadyExistsError("exchange already exists"))) + with client as c: + r = c.post(f"{settings.settings.api_base}/exchanges", json={"name": "Binance"}) + assert r.status_code == status.HTTP_400_BAD_REQUEST + assert r.json() == {"detail": "exchange already exists"} + + +def test_get_exchanges_unauthenticated(client_factory: Callable[..., TestClient]) -> None: + client = client_factory(decline_auth=True) + with client as c: + r = c.get(f"{settings.settings.api_base}/exchanges") + assert r.status_code == status.HTTP_401_UNAUTHORIZED + assert r.json() == {"detail": "Unauthorized"} + + +def test_get_exchanges_success(client_factory: Callable[..., TestClient]) -> None: + client = client_factory() + with client as c: + r = c.get(f"{settings.settings.api_base}/exchanges") + assert r.status_code == 200 + assert r.json() == [] + + +def test_update_exchanges_success(client_factory: Callable[..., TestClient]) -> None: + client = client_factory( + update_exchanges_service=MagicMock( + return_value=SimpleNamespace(model_dump=lambda: {"user_id": 1, "name": "BinanceUS", "notes": "updated note"}), + ), + ) + with client as c: + r = c.patch(f"{settings.settings.api_base}/exchanges/1", json={"name": "BinanceUS", "notes": "updated note"}) + assert r.status_code == 200 + assert r.json() == {"user_id": 1, "name": "BinanceUS", "notes": "updated note"} + + +def test_update_exchanges_not_found(client_factory: Callable[..., TestClient]) -> None: + client = client_factory(update_exchanges_service=MagicMock(side_effect=svc.ExchangeNotFoundError("exchange not found"))) + with client as c: + r = c.patch(f"{settings.settings.api_base}/exchanges/999", json={"name": "NonExistent", "notes": "no note"}) + assert r.status_code == status.HTTP_404_NOT_FOUND + assert r.json() == {"detail": "exchange not found"} + + +def test_get_cycles_by_id_success(client_factory: Callable[..., TestClient]) -> None: + client = client_factory( + get_cycle_by_id_service=MagicMock( + return_value=SimpleNamespace( + friendly_name="Cycle 1", + status="active", + id=1, + ), + ), + ) + with client as c: + r = c.get(f"{settings.settings.api_base}/cycles/1") + assert r.status_code == 200 + assert r.json() == {"id": 1, "friendly_name": "Cycle 1", "status": "active"} + + +def test_get_cycles_by_id_not_found(client_factory: Callable[..., TestClient]) -> None: + client = client_factory(get_cycle_by_id_service=MagicMock(side_effect=svc.CycleNotFoundError("cycle not found"))) + with client as c: + r = c.get(f"{settings.settings.api_base}/cycles/999") + assert r.status_code == status.HTTP_404_NOT_FOUND + assert r.json() == {"detail": "cycle not found"} + + +def test_get_cycles_by_user_success(client_factory: Callable[..., TestClient]) -> None: + client = client_factory( + get_cycles_by_user_service=MagicMock( + return_value=[ + SimpleNamespace( + friendly_name="Cycle 1", + status="active", + id=1, + ), + SimpleNamespace( + friendly_name="Cycle 2", + status="completed", + id=2, + ), + ], + ), + ) + with client as c: + r = c.get(f"{settings.settings.api_base}/cycles/user/1") + assert r.status_code == 200 + assert r.json() == [ + {"id": 1, "friendly_name": "Cycle 1", "status": "active"}, + {"id": 2, "friendly_name": "Cycle 2", "status": "completed"}, + ] + + +def test_update_cycles_success(client_factory: Callable[..., TestClient]) -> None: + client = client_factory( + update_cycle_service=MagicMock( + return_value=SimpleNamespace( + friendly_name="Updated Cycle", + status="completed", + id=1, + ), + ), + ) + with client as c: + r = c.patch(f"{settings.settings.api_base}/cycles", json={"friendly_name": "Updated Cycle", "status": "completed", "id": 1}) + assert r.status_code == 200 + assert r.json() == {"id": 1, "friendly_name": "Updated Cycle", "status": "completed"} + + +def test_update_cycles_invalid_cycle_data(client_factory: Callable[..., TestClient]) -> None: + client = client_factory( + update_cycle_service=MagicMock(side_effect=svc.InvalidCycleDataError("invalid cycle data")), + ) + with client as c: + r = c.patch(f"{settings.settings.api_base}/cycles", json={"friendly_name": "", "status": "unknown", "id": 1}) + assert r.status_code == status.HTTP_400_BAD_REQUEST + assert r.json() == {"detail": "invalid cycle data"} + + +def test_update_cycles_not_found(client_factory: Callable[..., TestClient]) -> None: + client = client_factory(update_cycle_service=MagicMock(side_effect=svc.CycleNotFoundError("cycle not found"))) + with client as c: + r = c.patch(f"{settings.settings.api_base}/cycles", json={"friendly_name": "NonExistent", "status": "active", "id": 999}) + assert r.status_code == status.HTTP_404_NOT_FOUND + assert r.json() == {"detail": "cycle not found"} + + +def test_create_trade_success(client_factory: Callable[..., TestClient]) -> None: + client = client_factory( + create_trade_service=MagicMock( + return_value=SimpleNamespace(), + ), + ) + with client as c: + r = c.post( + f"{settings.settings.api_base}/trades", + json={ + "cycle_id": 1, + "exchange_id": 1, + "symbol": "BTCUSD", + "underlying_currency": "USD", + "trade_type": "LONG_SPOT", + "trade_strategy": "FX", + "quantity": 1, + "price_cents": 15, + "commission_cents": 100, + "trade_date": "2025-10-01", + }, + ) + assert r.status_code == 201 + + +def test_create_trade_invalid_trade_data(client_factory: Callable[..., TestClient]) -> None: + client = client_factory( + create_trade_service=MagicMock(side_effect=svc.InvalidTradeDataError("invalid trade data")), + ) + with client as c: + r = c.post( + f"{settings.settings.api_base}/trades", + json={ + "cycle_id": 1, + "exchange_id": 1, + "symbol": "BTCUSD", + "underlying_currency": "USD", + "trade_type": "LONG_SPOT", + "trade_strategy": "FX", + "quantity": 1, + "price_cents": 15, + "commission_cents": 100, + "trade_date": "2025-10-01", + }, + ) + assert r.status_code == status.HTTP_400_BAD_REQUEST + assert r.json() == {"detail": "invalid trade data"} + + +def test_get_trade_by_id_success(client_factory: Callable[..., TestClient]) -> None: + client = client_factory( + get_trade_by_id_service=MagicMock( + return_value=SimpleNamespace( + id=1, + cycle_id=1, + exchange_id=1, + symbol="BTCUSD", + underlying_currency="USD", + trade_type="LONG_SPOT", + trade_strategy="FX", + quantity=1, + price_cents=1500, + commission_cents=100, + trade_date=datetime(2025, 10, 1, tzinfo=timezone.utc), + ), + ), + ) + with client as c: + r = c.get(f"{settings.settings.api_base}/trades/1") + assert r.status_code == 200 + assert r.json() == { + "id": 1, + "cycle_id": 1, + "exchange_id": 1, + "symbol": "BTCUSD", + "underlying_currency": "USD", + "trade_type": "LONG_SPOT", + "trade_strategy": "FX", + "quantity": 1, + "price_cents": 1500, + "commission_cents": 100, + "trade_date": "2025-10-01T00:00:00+00:00", + } + + +def test_get_trade_by_id_not_found(client_factory: Callable[..., TestClient]) -> None: + client = client_factory(get_trade_by_id_service=MagicMock(side_effect=svc.TradeNotFoundError("trade not found"))) + with client as c: + r = c.get(f"{settings.settings.api_base}/trades/999") + assert r.status_code == status.HTTP_404_NOT_FOUND + assert r.json() == {"detail": "trade not found"} + + +def test_update_trade_friendly_name_success(client_factory: Callable[..., TestClient]) -> None: + client = client_factory( + update_trade_friendly_name_service=MagicMock( + return_value=SimpleNamespace( + id=1, + friendly_name="Updated Trade Name", + ), + ), + ) + with client as c: + r = c.patch(f"{settings.settings.api_base}/trades/friendlyname", json={"id": 1, "friendly_name": "Updated Trade Name"}) + assert r.status_code == 200 + assert r.json() == {"id": 1, "friendly_name": "Updated Trade Name"} + + +def test_update_trade_friendly_name_not_found(client_factory: Callable[..., TestClient]) -> None: + client = client_factory(update_trade_friendly_name_service=MagicMock(side_effect=svc.TradeNotFoundError("trade not found"))) + with client as c: + r = c.patch(f"{settings.settings.api_base}/trades/friendlyname", json={"id": 999, "friendly_name": "NonExistent Trade"}) + assert r.status_code == status.HTTP_404_NOT_FOUND + assert r.json() == {"detail": "trade not found"} + + +def test_update_trade_note_success(client_factory: Callable[..., TestClient]) -> None: + client = client_factory( + update_trade_note_service=MagicMock( + return_value=SimpleNamespace( + id=1, + note="Updated trade note", + ), + ), + ) + with client as c: + r = c.patch(f"{settings.settings.api_base}/trades/notes", json={"id": 1, "note": "Updated trade note"}) + assert r.status_code == 200 + assert r.json() == {"id": 1, "note": "Updated trade note"} + + +def test_update_trade_note_not_found(client_factory: Callable[..., TestClient]) -> None: + client = client_factory(update_trade_note_service=MagicMock(side_effect=svc.TradeNotFoundError("trade not found"))) + with client as c: + r = c.patch(f"{settings.settings.api_base}/trades/notes", json={"id": 999, "note": "NonExistent Trade Note"}) + assert r.status_code == status.HTTP_404_NOT_FOUND + assert r.json() == {"detail": "trade not found"} -- 2.49.1