diff --git a/backend/tests/test_crud.py b/backend/tests/test_crud.py new file mode 100644 index 0000000..9f5bf1f --- /dev/null +++ b/backend/tests/test_crud.py @@ -0,0 +1,76 @@ +from collections.abc import Generator +from datetime import datetime + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.engine import Engine +from sqlalchemy.pool import StaticPool +from sqlmodel import Session, SQLModel + +from trading_journal import crud, models + + +@pytest.fixture +def engine() -> Engine: + e = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + SQLModel.metadata.create_all(e) + return e + + +@pytest.fixture +def session(engine: Engine) -> Generator[Session, None, None]: + with Session(engine) as s: + yield s + + +def make_user(session: Session, username: str = "testuser") -> int: + user = models.Users(username=username, password_hash="hashedpassword") + session.add(user) + session.commit() + session.refresh(user) + return user.id + + +def make_cycle(session, user_id: int, friendly_name: str = "Test Cycle") -> int: + cycle = models.Cycles( + user_id=user_id, + friendly_name=friendly_name, + symbol="AAPL", + underlying_currency="USD", + status="open", + start_date=datetime.now().date(), + ) + session.add(cycle) + session.commit() + session.refresh(cycle) + return cycle.id + + +def test_create_trade_success(session: Session): + user_id = make_user(session) + cycle_id = make_cycle(session, user_id) + + trade_data = { + "user_id": user_id, + "friendly_name": "Test Trade", + "symbol": "AAPL", + "underlying_currency": "USD", + "trade_type": "LONG_SPOT", + "trade_strategy": "SPOT", + "trade_time_utc": datetime.now(), + "quantity": 10, + "price_cents": 15000, + "gross_cash_flow_cents": -150000, + "commission_cents": 500, + "net_cash_flow_cents": -150500, + "cycle_id": cycle_id, + } + + trade = crud.create_trade(session, trade_data) + assert trade.id is not None + assert trade.user_id == user_id + assert trade.cycle_id == cycle_id diff --git a/backend/tests/test_db_migration.py b/backend/tests/test_db_migration.py index d61768d..2ea7fee 100644 --- a/backend/tests/test_db_migration.py +++ b/backend/tests/test_db_migration.py @@ -36,8 +36,8 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None: "symbol": ("TEXT", 1, 0), "underlying_currency": ("TEXT", 1, 0), "status": ("TEXT", 1, 0), - "funding_source": ("TEXT", 1, 0), - "capital_exposure_cents": ("INTEGER", 1, 0), + "funding_source": ("TEXT", 0, 0), + "capital_exposure_cents": ("INTEGER", 0, 0), "loan_amount_cents": ("INTEGER", 0, 0), "loan_interest_rate_bps": ("INTEGER", 0, 0), "start_date": ("DATE", 1, 0), diff --git a/backend/trading_journal/crud.py b/backend/trading_journal/crud.py index 8b13789..99a4b8e 100644 --- a/backend/trading_journal/crud.py +++ b/backend/trading_journal/crud.py @@ -1 +1,63 @@ +from typing import Mapping +from sqlalchemy.exc import IntegrityError +from sqlmodel import Session + +from trading_journal import models + + +def _coerce_enum(enum_cls, value, field_name: str): + if value is None: + raise ValueError(f"{field_name} is required") + # already an enum member + if isinstance(value, enum_cls): + return value + # strict string match: must match exactly enum name or enum value (case-sensitive) + if isinstance(value, str): + for m in enum_cls: + if m.name == value or str(m.value) == value: + return m + allowed = [m.name for m in enum_cls] + raise ValueError(f"Invalid {field_name!s}: {value!r}. Allowed: {allowed}") + + +# Trades +def create_trade(session: Session, trade_data: Mapping) -> models.Trades: + if hasattr(trade_data, "dict"): + data = trade_data.dict(exclude_unset=True) + else: + data = dict(trade_data) + allowed = {c.name for c in models.Trades.__table__.columns} + payload = {k: v for k, v in data.items() if k in allowed} + if "trade_type" not in payload: + raise ValueError("trade_type is required") + payload["trade_type"] = _coerce_enum( + models.TradeType, payload["trade_type"], "trade_type" + ) + + if "trade_strategy" not in payload: + raise ValueError("trade_strategy is required") + payload["trade_strategy"] = _coerce_enum( + models.TradeStrategy, payload["trade_strategy"], "trade_strategy" + ) + cycle_id = payload.get("cycle_id") + user_id = payload.get("user_id") + + if cycle_id is not None: + cycle = session.get(models.Cycles, cycle_id) + if cycle is None: + pass # TODO: create a cycle with basic info here + else: + if cycle.user_id != user_id: + raise ValueError("cycle.user_id does not match trade.user_id") + else: + raise ValueError("trade must have a cycle_id.") + t = models.Trades(**payload) + session.add(t) + try: + session.flush() + except IntegrityError as e: + session.rollback() + raise ValueError("create_trade integrity error") from e + session.refresh(t) + return t diff --git a/backend/trading_journal/models.py b/backend/trading_journal/models.py index 745bbbd..ee0e18b 100644 --- a/backend/trading_journal/models.py +++ b/backend/trading_journal/models.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from datetime import date, datetime # noqa: TC003 from enum import Enum @@ -75,7 +73,7 @@ class Trades(SQLModel, table=True): cycle_id: int | None = Field( default=None, foreign_key="cycles.id", nullable=True, index=True ) - cycle: Cycles | None = Relationship(back_populates="trades") + cycle: "Cycles" = Relationship(back_populates="trades") class Cycles(SQLModel, table=True): @@ -94,13 +92,13 @@ class Cycles(SQLModel, table=True): symbol: str = Field(sa_column=Column(Text, nullable=False)) underlying_currency: str = Field(sa_column=Column(Text, nullable=False)) status: CycleStatus = Field(sa_column=Column(Text, nullable=False)) - funding_source: FundingSource = Field(sa_column=Column(Text, nullable=False)) - capital_exposure_cents: int + funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True)) + capital_exposure_cents: int | None = Field(default=None, nullable=True) loan_amount_cents: int | None = Field(default=None, nullable=True) loan_interest_rate_bps: int | None = Field(default=None, nullable=True) start_date: date = Field(sa_column=Column(Date, nullable=False)) end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True)) - trades: list[Trades] = Relationship(back_populates="cycle") + trades: list["Trades"] = Relationship(back_populates="cycle") class Users(SQLModel, table=True): diff --git a/backend/trading_journal/models_v1.py b/backend/trading_journal/models_v1.py index 745bbbd..ee0e18b 100644 --- a/backend/trading_journal/models_v1.py +++ b/backend/trading_journal/models_v1.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from datetime import date, datetime # noqa: TC003 from enum import Enum @@ -75,7 +73,7 @@ class Trades(SQLModel, table=True): cycle_id: int | None = Field( default=None, foreign_key="cycles.id", nullable=True, index=True ) - cycle: Cycles | None = Relationship(back_populates="trades") + cycle: "Cycles" = Relationship(back_populates="trades") class Cycles(SQLModel, table=True): @@ -94,13 +92,13 @@ class Cycles(SQLModel, table=True): symbol: str = Field(sa_column=Column(Text, nullable=False)) underlying_currency: str = Field(sa_column=Column(Text, nullable=False)) status: CycleStatus = Field(sa_column=Column(Text, nullable=False)) - funding_source: FundingSource = Field(sa_column=Column(Text, nullable=False)) - capital_exposure_cents: int + funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True)) + capital_exposure_cents: int | None = Field(default=None, nullable=True) loan_amount_cents: int | None = Field(default=None, nullable=True) loan_interest_rate_bps: int | None = Field(default=None, nullable=True) start_date: date = Field(sa_column=Column(Date, nullable=False)) end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True)) - trades: list[Trades] = Relationship(back_populates="cycle") + trades: list["Trades"] = Relationship(back_populates="cycle") class Users(SQLModel, table=True):