Add invalidate not yet with tests
All checks were successful
Backend CI / unit-test (push) Successful in 27s

This commit is contained in:
2025-09-17 16:36:56 +02:00
parent a0898fa29e
commit d1f1b3e66c
3 changed files with 59 additions and 0 deletions

View File

@@ -294,6 +294,50 @@ def test_get_trade_by_user_id_and_friendly_name(session: Session):
assert trade.user_id == user_id
def test_get_trades_by_user_id(session: Session):
user_id = make_user(session)
cycle_id = make_cycle(session, user_id)
trade_data_1 = {
"user_id": user_id,
"friendly_name": "Trade One",
"symbol": "AAPL",
"underlying_currency": models.UnderlyingCurrency.USD,
"trade_type": models.TradeType.LONG_SPOT,
"trade_strategy": models.TradeStrategy.SPOT,
"trade_date": datetime.now().date(),
"trade_time_utc": datetime.now(),
"quantity": 10,
"price_cents": 15000,
"gross_cash_flow_cents": -150000,
"commission_cents": 500,
"net_cash_flow_cents": -150500,
"cycle_id": cycle_id,
}
trade_data_2 = {
"user_id": user_id,
"friendly_name": "Trade Two",
"symbol": "GOOGL",
"underlying_currency": models.UnderlyingCurrency.USD,
"trade_type": models.TradeType.SHORT_SPOT,
"trade_strategy": models.TradeStrategy.SPOT,
"trade_date": datetime.now().date(),
"trade_time_utc": datetime.now(),
"quantity": 5,
"price_cents": 280000,
"gross_cash_flow_cents": 1400000,
"commission_cents": 700,
"net_cash_flow_cents": 1399300,
"cycle_id": cycle_id,
}
make_trade_by_trade_data(session, trade_data_1)
make_trade_by_trade_data(session, trade_data_2)
trades = crud.get_trades_by_user_id(session, user_id)
assert len(trades) == 2
friendly_names = {trade.friendly_name for trade in trades}
assert friendly_names == {"Trade One", "Trade Two"}
def test_create_cycle(session: Session):
user_id = make_user(session)
cycle_data = {

View File

@@ -115,6 +115,13 @@ def get_trade_by_user_id_and_friendly_name(
return session.exec(statement).first()
def get_trades_by_user_id(session: Session, user_id: int) -> list[models.Trades]:
statement = select(models.Trades).where(
models.Trades.user_id == user_id,
)
return session.exec(statement).all()
# Cycles
def create_cycle(session: Session, cycle_data: Mapping) -> models.Cycles:
if hasattr(cycle_data, "dict"):

View File

@@ -94,6 +94,14 @@ class Trades(SQLModel, table=True):
gross_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False))
commission_cents: int = Field(sa_column=Column(Integer, nullable=False))
net_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False))
is_invalidated: bool = Field(default=False, nullable=False)
invalidated_at: datetime | None = Field(
default=None, sa_column=Column(DateTime(timezone=True), nullable=True)
)
replaced_by_trade_id: int | None = Field(
default=None, foreign_key="trades.id", nullable=True
)
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
cycle_id: int | None = Field(
default=None, foreign_key="cycles.id", nullable=True, index=True
)