From 2c22f20b487f7294fa4437ad883c01295a6dd2d8 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Sun, 14 Sep 2025 21:01:12 +0200 Subject: [PATCH] continue on crud --- backend/tests/test_crud.py | 264 ++++++++++++++++++++++++++- backend/trading_journal/crud.py | 128 ++++++++++++- backend/trading_journal/models.py | 44 ++++- backend/trading_journal/models_v1.py | 21 ++- 4 files changed, 435 insertions(+), 22 deletions(-) diff --git a/backend/tests/test_crud.py b/backend/tests/test_crud.py index 199d493..8fb3c49 100644 --- a/backend/tests/test_crud.py +++ b/backend/tests/test_crud.py @@ -9,6 +9,8 @@ from sqlmodel import Session, SQLModel from trading_journal import crud, models +# TODO: If needed, add failing flow tests, but now only add happy flow. + @pytest.fixture def engine() -> Generator[Engine, None, None]: @@ -45,7 +47,7 @@ def make_cycle(session, user_id: int, friendly_name: str = "Test Cycle") -> int: user_id=user_id, friendly_name=friendly_name, symbol="AAPL", - underlying_currency="USD", + underlying_currency=models.UnderlyingCurrency.USD, status=models.CycleStatus.OPEN, start_date=datetime.now().date(), ) @@ -55,7 +57,40 @@ def make_cycle(session, user_id: int, friendly_name: str = "Test Cycle") -> int: return cycle.id -def test_create_trade_success(session: Session): +def make_trade( + session, user_id: int, cycle_id: int, friendly_name: str = "Test Trade" +) -> int: + trade = models.Trades( + user_id=user_id, + friendly_name=friendly_name, + symbol="AAPL", + underlying_currency=models.UnderlyingCurrency.USD, + trade_type=models.TradeType.LONG_SPOT, + trade_strategy=models.TradeStrategy.SPOT, + trade_date=datetime.now().date(), + trade_time_utc=datetime.now(), + quantity=10, + price_cents=15000, + gross_cash_flow_cents=-150000, + commission_cents=500, + net_cash_flow_cents=-150500, + cycle_id=cycle_id, + ) + session.add(trade) + session.commit() + session.refresh(trade) + return trade.id + + +def make_trade_by_trade_data(session, trade_data: dict) -> int: + trade = models.Trades(**trade_data) + session.add(trade) + session.commit() + session.refresh(trade) + return trade.id + + +def test_create_trade_success_with_cycle(session: Session): user_id = make_user(session) cycle_id = make_cycle(session, user_id) @@ -79,3 +114,228 @@ def test_create_trade_success(session: Session): assert trade.id is not None assert trade.user_id == user_id assert trade.cycle_id == cycle_id + session.refresh(trade) + + actual_trade = session.get(models.Trades, trade.id) + assert actual_trade is not None + assert actual_trade.friendly_name == trade_data["friendly_name"] + assert actual_trade.symbol == trade_data["symbol"] + assert actual_trade.underlying_currency == trade_data["underlying_currency"] + assert actual_trade.trade_type == trade_data["trade_type"] + assert actual_trade.trade_strategy == trade_data["trade_strategy"] + assert actual_trade.quantity == trade_data["quantity"] + assert actual_trade.price_cents == trade_data["price_cents"] + assert actual_trade.gross_cash_flow_cents == trade_data["gross_cash_flow_cents"] + assert actual_trade.commission_cents == trade_data["commission_cents"] + assert actual_trade.net_cash_flow_cents == trade_data["net_cash_flow_cents"] + assert actual_trade.cycle_id == trade_data["cycle_id"] + + +def test_create_trade_with_auto_created_cycle(session: Session): + user_id = make_user(session) + + trade_data = { + "user_id": user_id, + "friendly_name": "Test Trade with Auto Cycle", + "symbol": "AAPL", + "underlying_currency": "USD", + "trade_type": "LONG_SPOT", + "trade_strategy": "SPOT", + "trade_time_utc": datetime.now(), + "quantity": 5, + "price_cents": 15500, + } + + trade = crud.create_trade(session, trade_data) + assert trade.id is not None + assert trade.user_id == user_id + assert trade.cycle_id is not None + session.refresh(trade) + + actual_trade = session.get(models.Trades, trade.id) + assert actual_trade is not None + assert actual_trade.friendly_name == trade_data["friendly_name"] + assert actual_trade.symbol == trade_data["symbol"] + assert actual_trade.underlying_currency == trade_data["underlying_currency"] + assert actual_trade.trade_type == trade_data["trade_type"] + assert actual_trade.trade_strategy == trade_data["trade_strategy"] + assert actual_trade.quantity == trade_data["quantity"] + assert actual_trade.price_cents == trade_data["price_cents"] + assert actual_trade.cycle_id == trade.cycle_id + + # Verify the auto-created cycle + auto_cycle = session.get(models.Cycles, trade.cycle_id) + assert auto_cycle is not None + assert auto_cycle.user_id == user_id + assert auto_cycle.symbol == trade_data["symbol"] + assert auto_cycle.underlying_currency == trade_data["underlying_currency"] + assert auto_cycle.status == models.CycleStatus.OPEN + assert auto_cycle.friendly_name.startswith("Auto-created Cycle by trade") + + +def test_create_trade_missing_required_fields(session: Session): + user_id = make_user(session) + + base_trade_data = { + "user_id": user_id, + "friendly_name": "Incomplete Trade", + "symbol": "AAPL", + "underlying_currency": "USD", + "trade_type": "LONG_SPOT", + "trade_strategy": "SPOT", + "trade_time_utc": datetime.now(), + "quantity": 10, + "price_cents": 15000, + } + + # Missing symbol + trade_data = base_trade_data.copy() + trade_data.pop("symbol", None) + with pytest.raises(ValueError) as excinfo: + crud.create_trade(session, trade_data) + assert "symbol is required" in str(excinfo.value) + + # Missing underlying_currency + trade_data = base_trade_data.copy() + trade_data.pop("underlying_currency", None) + with pytest.raises(ValueError) as excinfo: + crud.create_trade(session, trade_data) + assert "underlying_currency is required" in str(excinfo.value) + + # Missing trade_type + trade_data = base_trade_data.copy() + trade_data.pop("trade_type", None) + with pytest.raises(ValueError) as excinfo: + crud.create_trade(session, trade_data) + assert "trade_type is required" in str(excinfo.value) + + # Missing trade_strategy + trade_data = base_trade_data.copy() + trade_data.pop("trade_strategy", None) + with pytest.raises(ValueError) as excinfo: + crud.create_trade(session, trade_data) + assert "trade_strategy is required" in str(excinfo.value) + + # Missing quantity + trade_data = base_trade_data.copy() + trade_data.pop("quantity", None) + with pytest.raises(ValueError) as excinfo: + crud.create_trade(session, trade_data) + assert "quantity is required" in str(excinfo.value) + + # Missing price_cents + trade_data = base_trade_data.copy() + trade_data.pop("price_cents", None) + with pytest.raises(ValueError) as excinfo: + crud.create_trade(session, trade_data) + assert "price_cents is required" in str(excinfo.value) + + +def test_get_trade_by_id(session: Session): + user_id = make_user(session) + cycle_id = make_cycle(session, user_id) + trade_data = { + "user_id": user_id, + "friendly_name": "Test Trade for Get", + "symbol": "AAPL", + "underlying_currency": models.UnderlyingCurrency.USD, + "trade_type": models.TradeType.LONG_SPOT, + "trade_strategy": models.TradeStrategy.SPOT, + "trade_date": datetime.now().date(), + "trade_time_utc": datetime.now(), + "quantity": 10, + "price_cents": 15000, + "gross_cash_flow_cents": -150000, + "commission_cents": 500, + "net_cash_flow_cents": -150500, + "cycle_id": cycle_id, + } + trade_id = make_trade_by_trade_data(session, trade_data) + trade = crud.get_trade_by_id(session, trade_id) + assert trade is not None + assert trade.id == trade_id + assert trade.friendly_name == trade_data["friendly_name"] + assert trade.symbol == trade_data["symbol"] + assert trade.underlying_currency == trade_data["underlying_currency"] + assert trade.trade_type == trade_data["trade_type"] + assert trade.trade_strategy == trade_data["trade_strategy"] + assert trade.quantity == trade_data["quantity"] + assert trade.price_cents == trade_data["price_cents"] + assert trade.gross_cash_flow_cents == trade_data["gross_cash_flow_cents"] + assert trade.commission_cents == trade_data["commission_cents"] + assert trade.net_cash_flow_cents == trade_data["net_cash_flow_cents"] + assert trade.cycle_id == trade_data["cycle_id"] + assert trade.trade_date == trade_data["trade_date"] + + +def test_get_trade_by_user_id_and_friendly_name(session: Session): + user_id = make_user(session) + cycle_id = make_cycle(session, user_id) + friendly_name = "Unique Trade Name" + trade_data = { + "user_id": user_id, + "friendly_name": friendly_name, + "symbol": "AAPL", + "underlying_currency": models.UnderlyingCurrency.USD, + "trade_type": models.TradeType.LONG_SPOT, + "trade_strategy": models.TradeStrategy.SPOT, + "trade_date": datetime.now().date(), + "trade_time_utc": datetime.now(), + "quantity": 10, + "price_cents": 15000, + "gross_cash_flow_cents": -150000, + "commission_cents": 500, + "net_cash_flow_cents": -150500, + "cycle_id": cycle_id, + } + make_trade_by_trade_data(session, trade_data) + trade = crud.get_trade_by_user_id_and_friendly_name(session, user_id, friendly_name) + assert trade is not None + assert trade.friendly_name == friendly_name + assert trade.user_id == user_id + + +def test_create_cycle(session: Session): + user_id = make_user(session) + cycle_data = { + "user_id": user_id, + "friendly_name": "My First Cycle", + "symbol": "GOOGL", + "underlying_currency": "USD", + "status": models.CycleStatus.OPEN, + "start_date": datetime.now().date(), + } + cycle = crud.create_cycle(session, cycle_data) + assert cycle.id is not None + assert cycle.user_id == user_id + assert cycle.friendly_name == cycle_data["friendly_name"] + assert cycle.symbol == cycle_data["symbol"] + assert cycle.underlying_currency == cycle_data["underlying_currency"] + assert cycle.status == cycle_data["status"] + assert cycle.start_date == cycle_data["start_date"] + + session.refresh(cycle) + actual_cycle = session.get(models.Cycles, cycle.id) + assert actual_cycle is not None + assert actual_cycle.friendly_name == cycle_data["friendly_name"] + assert actual_cycle.symbol == cycle_data["symbol"] + assert actual_cycle.underlying_currency == cycle_data["underlying_currency"] + assert actual_cycle.status == cycle_data["status"] + assert actual_cycle.start_date == cycle_data["start_date"] + + +def test_create_user(session: Session): + user_data = { + "username": "newuser", + "password_hash": "newhashedpassword", + } + user = crud.create_user(session, user_data) + assert user.id is not None + assert user.username == user_data["username"] + assert user.password_hash == user_data["password_hash"] + + session.refresh(user) + actual_user = session.get(models.Users, user.id) + assert actual_user is not None + assert actual_user.username == user_data["username"] + assert actual_user.password_hash == user_data["password_hash"] diff --git a/backend/trading_journal/crud.py b/backend/trading_journal/crud.py index 99a4b8e..3515c8d 100644 --- a/backend/trading_journal/crud.py +++ b/backend/trading_journal/crud.py @@ -1,12 +1,13 @@ +from datetime import datetime, timezone from typing import Mapping from sqlalchemy.exc import IntegrityError -from sqlmodel import Session +from sqlmodel import Session, select from trading_journal import models -def _coerce_enum(enum_cls, value, field_name: str): +def _check_enum(enum_cls, value, field_name: str): if value is None: raise ValueError(f"{field_name} is required") # already an enum member @@ -29,29 +30,66 @@ def create_trade(session: Session, trade_data: Mapping) -> models.Trades: data = dict(trade_data) allowed = {c.name for c in models.Trades.__table__.columns} payload = {k: v for k, v in data.items() if k in allowed} + if "symbol" not in payload: + raise ValueError("symbol is required") + if "underlying_currency" not in payload: + raise ValueError("underlying_currency is required") + payload["underlying_currency"] = _check_enum( + models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency" + ) if "trade_type" not in payload: raise ValueError("trade_type is required") - payload["trade_type"] = _coerce_enum( + payload["trade_type"] = _check_enum( models.TradeType, payload["trade_type"], "trade_type" ) - if "trade_strategy" not in payload: raise ValueError("trade_strategy is required") - payload["trade_strategy"] = _coerce_enum( + payload["trade_strategy"] = _check_enum( models.TradeStrategy, payload["trade_strategy"], "trade_strategy" ) + now = datetime.now(timezone.utc) + payload.pop("trade_time_utc", None) + payload["trade_time_utc"] = now + if "trade_date" not in payload or payload.get("trade_date") is None: + payload["trade_date"] = payload["trade_time_utc"].date() cycle_id = payload.get("cycle_id") user_id = payload.get("user_id") - + if "quantity" not in payload: + raise ValueError("quantity is required") + if "price_cents" not in payload: + raise ValueError("price_cents is required") + if "commission_cents" not in payload: + payload["commission_cents"] = 0 + quantity: int = payload["quantity"] + price_cents: int = payload["price_cents"] + commission_cents: int = payload["commission_cents"] + if "gross_cash_flow_cents" not in payload: + payload["gross_cash_flow_cents"] = -quantity * price_cents + if "net_cash_flow_cents" not in payload: + payload["net_cash_flow_cents"] = ( + payload["gross_cash_flow_cents"] - commission_cents + ) + if cycle_id is None: + cycle_id = create_cycle( + session, + { + "user_id": user_id, + "symbol": payload["symbol"], + "underlying_currency": payload["underlying_currency"], + "friendly_name": "Auto-created Cycle by trade " + + payload.get("friendly_name", ""), + "status": models.CycleStatus.OPEN, + "start_date": payload["trade_date"], + }, + ).id + payload["cycle_id"] = cycle_id if cycle_id is not None: cycle = session.get(models.Cycles, cycle_id) if cycle is None: - pass # TODO: create a cycle with basic info here + raise ValueError("cycle_id does not exist") else: if cycle.user_id != user_id: raise ValueError("cycle.user_id does not match trade.user_id") - else: - raise ValueError("trade must have a cycle_id.") t = models.Trades(**payload) session.add(t) try: @@ -61,3 +99,75 @@ def create_trade(session: Session, trade_data: Mapping) -> models.Trades: raise ValueError("create_trade integrity error") from e session.refresh(t) return t + + +def get_trade_by_id(session: Session, trade_id: int) -> models.Trades | None: + return session.get(models.Trades, trade_id) + + +def get_trade_by_user_id_and_friendly_name( + session: Session, user_id: int, friendly_name: str +) -> models.Trades | None: + statement = select(models.Trades).where( + models.Trades.user_id == user_id, + models.Trades.friendly_name == friendly_name, + ) + return session.exec(statement).first() + + +# Cycles +def create_cycle(session: Session, cycle_data: Mapping) -> models.Cycles: + if hasattr(cycle_data, "dict"): + data = cycle_data.dict(exclude_unset=True) + else: + data = dict(cycle_data) + allowed = {c.name for c in models.Cycles.__table__.columns} + payload = {k: v for k, v in data.items() if k in allowed} + if "user_id" not in payload: + raise ValueError("user_id is required") + if "symbol" not in payload: + raise ValueError("symbol is required") + if "underlying_currency" not in payload: + raise ValueError("underlying_currency is required") + payload["underlying_currency"] = _check_enum( + models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency" + ) + if "status" not in payload: + raise ValueError("status is required") + payload["status"] = _check_enum(models.CycleStatus, payload["status"], "status") + if "start_date" not in payload: + raise ValueError("start_date is required") + + c = models.Cycles(**payload) + session.add(c) + try: + session.flush() + except IntegrityError as e: + session.rollback() + raise ValueError("create_cycle integrity error") from e + session.refresh(c) + return c + + +# Users +def create_user(session: Session, user_data: Mapping) -> models.Users: + if hasattr(user_data, "dict"): + data = user_data.dict(exclude_unset=True) + else: + data = dict(user_data) + allowed = {c.name for c in models.Users.__table__.columns} + payload = {k: v for k, v in data.items() if k in allowed} + if "username" not in payload: + raise ValueError("username is required") + if "password_hash" not in payload: + raise ValueError("password_hash is required") + + u = models.Users(**payload) + session.add(u) + try: + session.flush() + except IntegrityError as e: + session.rollback() + raise ValueError("create_user integrity error") from e + session.refresh(u) + return u diff --git a/backend/trading_journal/models.py b/backend/trading_journal/models.py index ee0e18b..b916e1b 100644 --- a/backend/trading_journal/models.py +++ b/backend/trading_journal/models.py @@ -1,8 +1,17 @@ from datetime import date, datetime # noqa: TC003 from enum import Enum -from sqlalchemy import Date, Text, UniqueConstraint -from sqlmodel import Column, DateTime, Field, Relationship, SQLModel +from sqlmodel import ( + Column, + Date, + DateTime, + Field, + Integer, + Relationship, + SQLModel, + Text, + UniqueConstraint, +) class TradeType(str, Enum): @@ -36,6 +45,18 @@ class CycleStatus(str, Enum): CLOSED = "CLOSED" +class UnderlyingCurrency(str, Enum): + EUR = "EUR" + USD = "USD" + GBP = "GBP" + JPY = "JPY" + AUD = "AUD" + CAD = "CAD" + CHF = "CHF" + NZD = "NZD" + CNY = "CNY" + + class FundingSource(str, Enum): CASH = "CASH" MARGIN = "MARGIN" @@ -57,19 +78,22 @@ class Trades(SQLModel, table=True): default=None, sa_column=Column(Text, nullable=True) ) symbol: str = Field(sa_column=Column(Text, nullable=False)) - underlying_currency: str = Field(sa_column=Column(Text, nullable=False)) + underlying_currency: UnderlyingCurrency = Field( + sa_column=Column(Text, nullable=False) + ) trade_type: TradeType = Field(sa_column=Column(Text, nullable=False)) trade_strategy: TradeStrategy = Field(sa_column=Column(Text, nullable=False)) + trade_date: date = Field(sa_column=Column(Date, nullable=False)) trade_time_utc: datetime = Field( sa_column=Column(DateTime(timezone=True), nullable=False) ) expiry_date: date | None = Field(default=None, nullable=True) strike_price_cents: int | None = Field(default=None, nullable=True) - quantity: int - price_cents: int - gross_cash_flow_cents: int - commission_cents: int - net_cash_flow_cents: int + quantity: int = Field(sa_column=Column(Integer, nullable=False)) + price_cents: int = Field(sa_column=Column(Integer, nullable=False)) + gross_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False)) + commission_cents: int = Field(sa_column=Column(Integer, nullable=False)) + net_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False)) cycle_id: int | None = Field( default=None, foreign_key="cycles.id", nullable=True, index=True ) @@ -90,7 +114,9 @@ class Cycles(SQLModel, table=True): default=None, sa_column=Column(Text, nullable=True) ) symbol: str = Field(sa_column=Column(Text, nullable=False)) - underlying_currency: str = Field(sa_column=Column(Text, nullable=False)) + underlying_currency: UnderlyingCurrency = Field( + sa_column=Column(Text, nullable=False) + ) status: CycleStatus = Field(sa_column=Column(Text, nullable=False)) funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True)) capital_exposure_cents: int | None = Field(default=None, nullable=True) diff --git a/backend/trading_journal/models_v1.py b/backend/trading_journal/models_v1.py index ee0e18b..b7f3696 100644 --- a/backend/trading_journal/models_v1.py +++ b/backend/trading_journal/models_v1.py @@ -36,6 +36,18 @@ class CycleStatus(str, Enum): CLOSED = "CLOSED" +class UnderlyingCurrency(str, Enum): + EUR = "EUR" + USD = "USD" + GBP = "GBP" + JPY = "JPY" + AUD = "AUD" + CAD = "CAD" + CHF = "CHF" + NZD = "NZD" + CNY = "CNY" + + class FundingSource(str, Enum): CASH = "CASH" MARGIN = "MARGIN" @@ -57,9 +69,12 @@ class Trades(SQLModel, table=True): default=None, sa_column=Column(Text, nullable=True) ) symbol: str = Field(sa_column=Column(Text, nullable=False)) - underlying_currency: str = Field(sa_column=Column(Text, nullable=False)) + underlying_currency: UnderlyingCurrency = Field( + sa_column=Column(Text, nullable=False) + ) trade_type: TradeType = Field(sa_column=Column(Text, nullable=False)) trade_strategy: TradeStrategy = Field(sa_column=Column(Text, nullable=False)) + trade_date: date = Field(sa_column=Column(Date, nullable=False)) trade_time_utc: datetime = Field( sa_column=Column(DateTime(timezone=True), nullable=False) ) @@ -90,7 +105,9 @@ class Cycles(SQLModel, table=True): default=None, sa_column=Column(Text, nullable=True) ) symbol: str = Field(sa_column=Column(Text, nullable=False)) - underlying_currency: str = Field(sa_column=Column(Text, nullable=False)) + underlying_currency: UnderlyingCurrency = Field( + sa_column=Column(Text, nullable=False) + ) status: CycleStatus = Field(sa_column=Column(Text, nullable=False)) funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True)) capital_exposure_cents: int | None = Field(default=None, nullable=True)