From d1f1b3e66c21c75a9dc10684c2b3aeab3debdf7a Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Wed, 17 Sep 2025 16:36:56 +0200 Subject: [PATCH] Add invalidate not yet with tests --- backend/tests/test_crud.py | 44 +++++++++++++++++++++++++++++++ backend/trading_journal/crud.py | 7 +++++ backend/trading_journal/models.py | 8 ++++++ 3 files changed, 59 insertions(+) diff --git a/backend/tests/test_crud.py b/backend/tests/test_crud.py index 6e8700c..325e270 100644 --- a/backend/tests/test_crud.py +++ b/backend/tests/test_crud.py @@ -294,6 +294,50 @@ def test_get_trade_by_user_id_and_friendly_name(session: Session): assert trade.user_id == user_id +def test_get_trades_by_user_id(session: Session): + user_id = make_user(session) + cycle_id = make_cycle(session, user_id) + trade_data_1 = { + "user_id": user_id, + "friendly_name": "Trade One", + "symbol": "AAPL", + "underlying_currency": models.UnderlyingCurrency.USD, + "trade_type": models.TradeType.LONG_SPOT, + "trade_strategy": models.TradeStrategy.SPOT, + "trade_date": datetime.now().date(), + "trade_time_utc": datetime.now(), + "quantity": 10, + "price_cents": 15000, + "gross_cash_flow_cents": -150000, + "commission_cents": 500, + "net_cash_flow_cents": -150500, + "cycle_id": cycle_id, + } + trade_data_2 = { + "user_id": user_id, + "friendly_name": "Trade Two", + "symbol": "GOOGL", + "underlying_currency": models.UnderlyingCurrency.USD, + "trade_type": models.TradeType.SHORT_SPOT, + "trade_strategy": models.TradeStrategy.SPOT, + "trade_date": datetime.now().date(), + "trade_time_utc": datetime.now(), + "quantity": 5, + "price_cents": 280000, + "gross_cash_flow_cents": 1400000, + "commission_cents": 700, + "net_cash_flow_cents": 1399300, + "cycle_id": cycle_id, + } + make_trade_by_trade_data(session, trade_data_1) + make_trade_by_trade_data(session, trade_data_2) + + trades = crud.get_trades_by_user_id(session, user_id) + assert len(trades) == 2 + friendly_names = {trade.friendly_name for trade in trades} + assert friendly_names == {"Trade One", "Trade Two"} + + def test_create_cycle(session: Session): user_id = make_user(session) cycle_data = { diff --git a/backend/trading_journal/crud.py b/backend/trading_journal/crud.py index 3515c8d..b37e065 100644 --- a/backend/trading_journal/crud.py +++ b/backend/trading_journal/crud.py @@ -115,6 +115,13 @@ def get_trade_by_user_id_and_friendly_name( return session.exec(statement).first() +def get_trades_by_user_id(session: Session, user_id: int) -> list[models.Trades]: + statement = select(models.Trades).where( + models.Trades.user_id == user_id, + ) + return session.exec(statement).all() + + # Cycles def create_cycle(session: Session, cycle_data: Mapping) -> models.Cycles: if hasattr(cycle_data, "dict"): diff --git a/backend/trading_journal/models.py b/backend/trading_journal/models.py index d9d8944..0e5857f 100644 --- a/backend/trading_journal/models.py +++ b/backend/trading_journal/models.py @@ -94,6 +94,14 @@ class Trades(SQLModel, table=True): gross_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False)) commission_cents: int = Field(sa_column=Column(Integer, nullable=False)) net_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False)) + is_invalidated: bool = Field(default=False, nullable=False) + invalidated_at: datetime | None = Field( + default=None, sa_column=Column(DateTime(timezone=True), nullable=True) + ) + replaced_by_trade_id: int | None = Field( + default=None, foreign_key="trades.id", nullable=True + ) + notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True)) cycle_id: int | None = Field( default=None, foreign_key="cycles.id", nullable=True, index=True )