From 479d5cd230b3841353ae36a45f1cea65a6aa1967 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Sun, 14 Sep 2025 15:40:11 +0200 Subject: [PATCH] add user table --- backend/.gitignore | 4 + backend/tests/test_db_migration.py | 115 +++++++++++++++++------- backend/trading_journal/crud.py | 1 + backend/trading_journal/db.py | 40 +++++++-- backend/trading_journal/db_migration.py | 15 +++- backend/trading_journal/models.py | 78 ++++++++++++---- backend/trading_journal/models_v1.py | 78 ++++++++++++---- 7 files changed, 253 insertions(+), 78 deletions(-) create mode 100644 backend/trading_journal/crud.py diff --git a/backend/.gitignore b/backend/.gitignore index 51f9037..6321b92 100644 --- a/backend/.gitignore +++ b/backend/.gitignore @@ -11,3 +11,7 @@ venv.bak/ __pycache__/ .pytest_cache/ + +*.db +*.db-shm +*.db-wal \ No newline at end of file diff --git a/backend/tests/test_db_migration.py b/backend/tests/test_db_migration.py index f68bdc4..d61768d 100644 --- a/backend/tests/test_db_migration.py +++ b/backend/tests/test_db_migration.py @@ -18,20 +18,59 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None: connect_args={"check_same_thread": False}, poolclass=StaticPool, ) - - # ensure target is the LATEST_VERSION we expect for the test monkeypatch.setattr(db_migration, "LATEST_VERSION", 1) - - # run real migrations (will import trading_journal.models_v1 inside _mig_0_1) final_version = db_migration.run_migrations(engine) assert final_version == 1 - # import snapshot models to validate schema - from trading_journal import models_v1 + expected_schema = { + "users": { + "id": ("INTEGER", 1, 1), + "username": ("TEXT", 1, 0), + "password_hash": ("TEXT", 1, 0), + "is_active": ("BOOLEAN", 1, 0), + }, + "cycles": { + "id": ("INTEGER", 1, 1), + "user_id": ("INTEGER", 1, 0), + "friendly_name": ("TEXT", 0, 0), + "symbol": ("TEXT", 1, 0), + "underlying_currency": ("TEXT", 1, 0), + "status": ("TEXT", 1, 0), + "funding_source": ("TEXT", 1, 0), + "capital_exposure_cents": ("INTEGER", 1, 0), + "loan_amount_cents": ("INTEGER", 0, 0), + "loan_interest_rate_bps": ("INTEGER", 0, 0), + "start_date": ("DATE", 1, 0), + "end_date": ("DATE", 0, 0), + }, + "trades": { + "id": ("INTEGER", 1, 1), + "user_id": ("INTEGER", 1, 0), + "friendly_name": ("TEXT", 0, 0), + "symbol": ("TEXT", 1, 0), + "underlying_currency": ("TEXT", 1, 0), + "trade_type": ("TEXT", 1, 0), + "trade_strategy": ("TEXT", 1, 0), + "trade_time_utc": ("DATETIME", 1, 0), + "expiry_date": ("DATE", 0, 0), + "strike_price_cents": ("INTEGER", 0, 0), + "quantity": ("INTEGER", 1, 0), + "price_cents": ("INTEGER", 1, 0), + "gross_cash_flow_cents": ("INTEGER", 1, 0), + "commission_cents": ("INTEGER", 1, 0), + "net_cash_flow_cents": ("INTEGER", 1, 0), + "cycle_id": ("INTEGER", 0, 0), + }, + } - expected_tables = { - "trades": models_v1.Trades.__table__, - "cycles": models_v1.Cycles.__table__, + expected_fks = { + "trades": [ + {"table": "cycles", "from": "cycle_id", "to": "id"}, + {"table": "users", "from": "user_id", "to": "id"}, + ], + "cycles": [ + {"table": "users", "from": "user_id", "to": "id"}, + ], } with engine.connect() as conn: @@ -40,8 +79,8 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None: text("SELECT name FROM sqlite_master WHERE type='table'") ).fetchall() found_tables = {r[0] for r in rows} - assert set(expected_tables.keys()).issubset(found_tables), ( - f"missing tables: {set(expected_tables.keys()) - found_tables}" + assert set(expected_schema.keys()).issubset(found_tables), ( + f"missing tables: {set(expected_schema.keys()) - found_tables}" ) # check user_version @@ -49,29 +88,37 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None: assert uv is not None assert int(uv[0]) == 1 - # validate columns and (base) types for each expected table - dialect = conn.dialect - for tbl_name, table in expected_tables.items(): + # validate each table columns + for tbl_name, cols in expected_schema.items(): info_rows = conn.execute(text(f"PRAGMA table_info({tbl_name})")).fetchall() - # build mapping: column name -> declared type (upper) - actual_cols = {r[1]: (r[2] or "").upper() for r in info_rows} - for col in table.columns: - assert col.name in actual_cols, ( - f"column {col.name} missing in table {tbl_name}" + # map: name -> (type, notnull, pk) + actual = { + r[1]: ((r[2] or "").upper(), int(r[3]), int(r[5])) for r in info_rows + } + for colname, (exp_type, exp_notnull, exp_pk) in cols.items(): + assert colname in actual, f"{tbl_name}: missing column {colname}" + act_type, act_notnull, act_pk = actual[colname] + # compare base type (e.g. VARCHAR(13) -> VARCHAR) + if act_type: + act_base = _base_type_of(act_type) + else: + act_base = "" + assert exp_type in act_base or act_base in exp_type, ( + f"type mismatch {tbl_name}.{colname}: expected {exp_type}, got {act_base}" ) - # compile expected type against this dialect - try: - compiled = col.type.compile( - dialect=dialect - ) # e.g. VARCHAR(13), DATETIME - except Exception: - compiled = str(col.type) - expected_base = _base_type_of(compiled) - actual_type = actual_cols[col.name] - actual_base = _base_type_of(actual_type) if actual_type else "" - # accept either direction (some dialect vs sqlite naming differences) - assert (expected_base in actual_base) or ( - actual_base in expected_base - ), ( - f"type mismatch for {tbl_name}.{col.name}: expected {expected_base}, got {actual_base}" + assert act_notnull == exp_notnull, ( + f"notnull mismatch {tbl_name}.{colname}: expected {exp_notnull}, got {act_notnull}" ) + assert act_pk == exp_pk, ( + f"pk mismatch {tbl_name}.{colname}: expected {exp_pk}, got {act_pk}" + ) + for tbl_name, fks in expected_fks.items(): + fk_rows = conn.execute( + text(f"PRAGMA foreign_key_list('{tbl_name}')") + ).fetchall() + # fk_rows columns: (id, seq, table, from, to, on_update, on_delete, match) + actual_fk_list = [ + {"table": r[2], "from": r[3], "to": r[4]} for r in fk_rows + ] + for efk in fks: + assert efk in actual_fk_list, f"missing FK on {tbl_name}: {efk}" diff --git a/backend/trading_journal/crud.py b/backend/trading_journal/crud.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/backend/trading_journal/crud.py @@ -0,0 +1 @@ + diff --git a/backend/trading_journal/db.py b/backend/trading_journal/db.py index 229284a..d09a53d 100644 --- a/backend/trading_journal/db.py +++ b/backend/trading_journal/db.py @@ -15,22 +15,43 @@ if TYPE_CHECKING: class Database: - def __init__(self, database_url: str | None = None, *, echo: bool = False, connect_args: dict | None = None) -> None: + def __init__( + self, + database_url: str | None = None, + *, + echo: bool = False, + connect_args: dict | None = None, + ) -> None: self._database_url = database_url or "sqlite:///:memory:" - default_connect = {"check_same_thread": False, "timeout": 30} if self._database_url.startswith("sqlite") else {} + default_connect = ( + {"check_same_thread": False, "timeout": 30} + if self._database_url.startswith("sqlite") + else {} + ) merged_connect = {**default_connect, **(connect_args or {})} if self._database_url == "sqlite:///:memory:": logger = logging.getLogger(__name__) - logger.warning("Using in-memory SQLite database; all data will be lost when the application stops.") - self._engine = create_engine(self._database_url, echo=echo, connect_args=merged_connect, poolclass=StaticPool) + logger.warning( + "Using in-memory SQLite database; all data will be lost when the application stops." + ) + self._engine = create_engine( + self._database_url, + echo=echo, + connect_args=merged_connect, + poolclass=StaticPool, + ) else: - self._engine = create_engine(self._database_url, echo=echo, connect_args=merged_connect) + self._engine = create_engine( + self._database_url, echo=echo, connect_args=merged_connect + ) if self._database_url.startswith("sqlite"): - def _enable_sqlite_pragmas(dbapi_conn: DBAPIConnection, _connection_record: object) -> None: + def _enable_sqlite_pragmas( + dbapi_conn: DBAPIConnection, _connection_record: object + ) -> None: try: cur = dbapi_conn.cursor() cur.execute("PRAGMA journal_mode=WAL;") @@ -62,5 +83,10 @@ class Database: self._engine.dispose() -def create_database(database_url: str | None = None, *, echo: bool = False, connect_args: dict | None = None) -> Database: +def create_database( + database_url: str | None = None, + *, + echo: bool = False, + connect_args: dict | None = None, +) -> Database: return Database(database_url, echo=echo, connect_args=connect_args) diff --git a/backend/trading_journal/db_migration.py b/backend/trading_journal/db_migration.py index fe9b0ba..c59e3b0 100644 --- a/backend/trading_journal/db_migration.py +++ b/backend/trading_journal/db_migration.py @@ -18,9 +18,16 @@ def _mig_0_1(engine: Engine) -> None: """ # Ensure all models are imported before this is called (import side-effect registers tables) # e.g. trading_journal.models is imported in the caller / app startup. - from trading_journal import models_v1 # noqa: PLC0415, F401 + from trading_journal import models_v1 - SQLModel.metadata.create_all(bind=engine) + SQLModel.metadata.create_all( + bind=engine, + tables=[ + models_v1.Trades.__table__, + models_v1.Cycles.__table__, + models_v1.Users.__table__, + ], + ) # map current_version -> function that migrates from current_version -> current_version+1 @@ -51,7 +58,9 @@ def run_migrations(engine: Engine, target_version: int | None = None) -> int: while cur_version < target: fn = MIGRATIONS.get(cur_version) if fn is None: - raise RuntimeError(f"No migration from {cur_version} -> {cur_version + 1}") + raise RuntimeError( + f"No migration from {cur_version} -> {cur_version + 1}" + ) # call migration with Engine (fn should use transactions) fn(engine) _set_sqlite_user_version(conn, cur_version + 1) diff --git a/backend/trading_journal/models.py b/backend/trading_journal/models.py index 7382741..745bbbd 100644 --- a/backend/trading_journal/models.py +++ b/backend/trading_journal/models.py @@ -3,8 +3,8 @@ from __future__ import annotations from datetime import date, datetime # noqa: TC003 from enum import Enum +from sqlalchemy import Date, Text, UniqueConstraint from sqlmodel import Column, DateTime, Field, Relationship, SQLModel -from sqlmodel import Enum as SQLEnum class TradeType(str, Enum): @@ -12,6 +12,18 @@ class TradeType(str, Enum): ASSIGNMENT = "ASSIGNMENT" SELL_CALL = "SELL_CALL" EXERCISE_CALL = "EXERCISE_CALL" + LONG_SPOT = "LONG_SPOT" + CLOSE_LONG_SPOT = "CLOSE_LONG_SPOT" + SHORT_SPOT = "SHORT_SPOT" + CLOSE_SHORT_SPOT = "CLOSE_SHORT_SPOT" + LONG_CFD = "LONG_CFD" + CLOSE_LONG_CFD = "CLOSE_LONG_CFD" + SHORT_CFD = "SHORT_CFD" + CLOSE_SHORT_CFD = "CLOSE_SHORT_CFD" + LONG_OTHER = "LONG_OTHER" + CLOSE_LONG_OTHER = "CLOSE_LONG_OTHER" + SHORT_OTHER = "SHORT_OTHER" + CLOSE_SHORT_OTHER = "CLOSE_SHORT_OTHER" class TradeStrategy(str, Enum): @@ -34,13 +46,25 @@ class FundingSource(str, Enum): class Trades(SQLModel, table=True): __tablename__ = "trades" - id: str | None = Field(default=None, primary_key=True) - user_id: str - symbol: str - underlying_currency: str - trade_type: TradeType = Field(sa_column=Column(SQLEnum(TradeType, name="trade_type_enum"), nullable=False)) - trade_strategy: TradeStrategy = Field(sa_column=Column(SQLEnum(TradeStrategy, name="trade_strategy_enum"), nullable=False)) - trade_time_utc: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False)) + __table_args__ = ( + UniqueConstraint( + "user_id", "friendly_name", name="uq_trades_user_friendly_name" + ), + ) + + id: int | None = Field(default=None, primary_key=True) + user_id: int = Field(foreign_key="users.id", nullable=False, index=True) + # allow null while user may omit friendly_name; uniqueness enforced per-user by constraint + friendly_name: str | None = Field( + default=None, sa_column=Column(Text, nullable=True) + ) + symbol: str = Field(sa_column=Column(Text, nullable=False)) + underlying_currency: str = Field(sa_column=Column(Text, nullable=False)) + trade_type: TradeType = Field(sa_column=Column(Text, nullable=False)) + trade_strategy: TradeStrategy = Field(sa_column=Column(Text, nullable=False)) + trade_time_utc: datetime = Field( + sa_column=Column(DateTime(timezone=True), nullable=False) + ) expiry_date: date | None = Field(default=None, nullable=True) strike_price_cents: int | None = Field(default=None, nullable=True) quantity: int @@ -48,21 +72,41 @@ class Trades(SQLModel, table=True): gross_cash_flow_cents: int commission_cents: int net_cash_flow_cents: int - cycle_id: str | None = Field(default=None, foreign_key="cycles.id", nullable=True) + cycle_id: int | None = Field( + default=None, foreign_key="cycles.id", nullable=True, index=True + ) cycle: Cycles | None = Relationship(back_populates="trades") class Cycles(SQLModel, table=True): __tablename__ = "cycles" - id: str | None = Field(default=None, primary_key=True) - user_id: str - symbol: str - underlying_currency: str - start_date: date - end_date: date | None = Field(default=None, nullable=True) - status: CycleStatus = Field(sa_column=Column(SQLEnum(CycleStatus, name="cycle_status_enum"), nullable=False)) - funding_source: FundingSource = Field(sa_column=Column(SQLEnum(FundingSource, name="funding_source_enum"), nullable=False)) + __table_args__ = ( + UniqueConstraint( + "user_id", "friendly_name", name="uq_cycles_user_friendly_name" + ), + ) + + id: int | None = Field(default=None, primary_key=True) + user_id: int = Field(foreign_key="users.id", nullable=False, index=True) + friendly_name: str | None = Field( + default=None, sa_column=Column(Text, nullable=True) + ) + symbol: str = Field(sa_column=Column(Text, nullable=False)) + underlying_currency: str = Field(sa_column=Column(Text, nullable=False)) + status: CycleStatus = Field(sa_column=Column(Text, nullable=False)) + funding_source: FundingSource = Field(sa_column=Column(Text, nullable=False)) capital_exposure_cents: int loan_amount_cents: int | None = Field(default=None, nullable=True) loan_interest_rate_bps: int | None = Field(default=None, nullable=True) + start_date: date = Field(sa_column=Column(Date, nullable=False)) + end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True)) trades: list[Trades] = Relationship(back_populates="cycle") + + +class Users(SQLModel, table=True): + __tablename__ = "users" + id: int | None = Field(default=None, primary_key=True) + # unique=True already creates an index; no need to also set index=True + username: str = Field(sa_column=Column(Text, nullable=False, unique=True)) + password_hash: str = Field(sa_column=Column(Text, nullable=False)) + is_active: bool = Field(default=True, nullable=False) diff --git a/backend/trading_journal/models_v1.py b/backend/trading_journal/models_v1.py index 7382741..745bbbd 100644 --- a/backend/trading_journal/models_v1.py +++ b/backend/trading_journal/models_v1.py @@ -3,8 +3,8 @@ from __future__ import annotations from datetime import date, datetime # noqa: TC003 from enum import Enum +from sqlalchemy import Date, Text, UniqueConstraint from sqlmodel import Column, DateTime, Field, Relationship, SQLModel -from sqlmodel import Enum as SQLEnum class TradeType(str, Enum): @@ -12,6 +12,18 @@ class TradeType(str, Enum): ASSIGNMENT = "ASSIGNMENT" SELL_CALL = "SELL_CALL" EXERCISE_CALL = "EXERCISE_CALL" + LONG_SPOT = "LONG_SPOT" + CLOSE_LONG_SPOT = "CLOSE_LONG_SPOT" + SHORT_SPOT = "SHORT_SPOT" + CLOSE_SHORT_SPOT = "CLOSE_SHORT_SPOT" + LONG_CFD = "LONG_CFD" + CLOSE_LONG_CFD = "CLOSE_LONG_CFD" + SHORT_CFD = "SHORT_CFD" + CLOSE_SHORT_CFD = "CLOSE_SHORT_CFD" + LONG_OTHER = "LONG_OTHER" + CLOSE_LONG_OTHER = "CLOSE_LONG_OTHER" + SHORT_OTHER = "SHORT_OTHER" + CLOSE_SHORT_OTHER = "CLOSE_SHORT_OTHER" class TradeStrategy(str, Enum): @@ -34,13 +46,25 @@ class FundingSource(str, Enum): class Trades(SQLModel, table=True): __tablename__ = "trades" - id: str | None = Field(default=None, primary_key=True) - user_id: str - symbol: str - underlying_currency: str - trade_type: TradeType = Field(sa_column=Column(SQLEnum(TradeType, name="trade_type_enum"), nullable=False)) - trade_strategy: TradeStrategy = Field(sa_column=Column(SQLEnum(TradeStrategy, name="trade_strategy_enum"), nullable=False)) - trade_time_utc: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False)) + __table_args__ = ( + UniqueConstraint( + "user_id", "friendly_name", name="uq_trades_user_friendly_name" + ), + ) + + id: int | None = Field(default=None, primary_key=True) + user_id: int = Field(foreign_key="users.id", nullable=False, index=True) + # allow null while user may omit friendly_name; uniqueness enforced per-user by constraint + friendly_name: str | None = Field( + default=None, sa_column=Column(Text, nullable=True) + ) + symbol: str = Field(sa_column=Column(Text, nullable=False)) + underlying_currency: str = Field(sa_column=Column(Text, nullable=False)) + trade_type: TradeType = Field(sa_column=Column(Text, nullable=False)) + trade_strategy: TradeStrategy = Field(sa_column=Column(Text, nullable=False)) + trade_time_utc: datetime = Field( + sa_column=Column(DateTime(timezone=True), nullable=False) + ) expiry_date: date | None = Field(default=None, nullable=True) strike_price_cents: int | None = Field(default=None, nullable=True) quantity: int @@ -48,21 +72,41 @@ class Trades(SQLModel, table=True): gross_cash_flow_cents: int commission_cents: int net_cash_flow_cents: int - cycle_id: str | None = Field(default=None, foreign_key="cycles.id", nullable=True) + cycle_id: int | None = Field( + default=None, foreign_key="cycles.id", nullable=True, index=True + ) cycle: Cycles | None = Relationship(back_populates="trades") class Cycles(SQLModel, table=True): __tablename__ = "cycles" - id: str | None = Field(default=None, primary_key=True) - user_id: str - symbol: str - underlying_currency: str - start_date: date - end_date: date | None = Field(default=None, nullable=True) - status: CycleStatus = Field(sa_column=Column(SQLEnum(CycleStatus, name="cycle_status_enum"), nullable=False)) - funding_source: FundingSource = Field(sa_column=Column(SQLEnum(FundingSource, name="funding_source_enum"), nullable=False)) + __table_args__ = ( + UniqueConstraint( + "user_id", "friendly_name", name="uq_cycles_user_friendly_name" + ), + ) + + id: int | None = Field(default=None, primary_key=True) + user_id: int = Field(foreign_key="users.id", nullable=False, index=True) + friendly_name: str | None = Field( + default=None, sa_column=Column(Text, nullable=True) + ) + symbol: str = Field(sa_column=Column(Text, nullable=False)) + underlying_currency: str = Field(sa_column=Column(Text, nullable=False)) + status: CycleStatus = Field(sa_column=Column(Text, nullable=False)) + funding_source: FundingSource = Field(sa_column=Column(Text, nullable=False)) capital_exposure_cents: int loan_amount_cents: int | None = Field(default=None, nullable=True) loan_interest_rate_bps: int | None = Field(default=None, nullable=True) + start_date: date = Field(sa_column=Column(Date, nullable=False)) + end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True)) trades: list[Trades] = Relationship(back_populates="cycle") + + +class Users(SQLModel, table=True): + __tablename__ = "users" + id: int | None = Field(default=None, primary_key=True) + # unique=True already creates an index; no need to also set index=True + username: str = Field(sa_column=Column(Text, nullable=False, unique=True)) + password_hash: str = Field(sa_column=Column(Text, nullable=False)) + is_active: bool = Field(default=True, nullable=False)