feature/db #2
@@ -294,6 +294,50 @@ def test_get_trade_by_user_id_and_friendly_name(session: Session):
|
||||
assert trade.user_id == user_id
|
||||
|
||||
|
||||
def test_get_trades_by_user_id(session: Session):
|
||||
user_id = make_user(session)
|
||||
cycle_id = make_cycle(session, user_id)
|
||||
trade_data_1 = {
|
||||
"user_id": user_id,
|
||||
"friendly_name": "Trade One",
|
||||
"symbol": "AAPL",
|
||||
"underlying_currency": models.UnderlyingCurrency.USD,
|
||||
"trade_type": models.TradeType.LONG_SPOT,
|
||||
"trade_strategy": models.TradeStrategy.SPOT,
|
||||
"trade_date": datetime.now().date(),
|
||||
"trade_time_utc": datetime.now(),
|
||||
"quantity": 10,
|
||||
"price_cents": 15000,
|
||||
"gross_cash_flow_cents": -150000,
|
||||
"commission_cents": 500,
|
||||
"net_cash_flow_cents": -150500,
|
||||
"cycle_id": cycle_id,
|
||||
}
|
||||
trade_data_2 = {
|
||||
"user_id": user_id,
|
||||
"friendly_name": "Trade Two",
|
||||
"symbol": "GOOGL",
|
||||
"underlying_currency": models.UnderlyingCurrency.USD,
|
||||
"trade_type": models.TradeType.SHORT_SPOT,
|
||||
"trade_strategy": models.TradeStrategy.SPOT,
|
||||
"trade_date": datetime.now().date(),
|
||||
"trade_time_utc": datetime.now(),
|
||||
"quantity": 5,
|
||||
"price_cents": 280000,
|
||||
"gross_cash_flow_cents": 1400000,
|
||||
"commission_cents": 700,
|
||||
"net_cash_flow_cents": 1399300,
|
||||
"cycle_id": cycle_id,
|
||||
}
|
||||
make_trade_by_trade_data(session, trade_data_1)
|
||||
make_trade_by_trade_data(session, trade_data_2)
|
||||
|
||||
trades = crud.get_trades_by_user_id(session, user_id)
|
||||
assert len(trades) == 2
|
||||
friendly_names = {trade.friendly_name for trade in trades}
|
||||
assert friendly_names == {"Trade One", "Trade Two"}
|
||||
|
||||
|
||||
def test_create_cycle(session: Session):
|
||||
user_id = make_user(session)
|
||||
cycle_data = {
|
||||
|
||||
@@ -115,6 +115,13 @@ def get_trade_by_user_id_and_friendly_name(
|
||||
return session.exec(statement).first()
|
||||
|
||||
|
||||
def get_trades_by_user_id(session: Session, user_id: int) -> list[models.Trades]:
|
||||
statement = select(models.Trades).where(
|
||||
models.Trades.user_id == user_id,
|
||||
)
|
||||
return session.exec(statement).all()
|
||||
|
||||
|
||||
# Cycles
|
||||
def create_cycle(session: Session, cycle_data: Mapping) -> models.Cycles:
|
||||
if hasattr(cycle_data, "dict"):
|
||||
|
||||
@@ -94,6 +94,14 @@ class Trades(SQLModel, table=True):
|
||||
gross_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||
commission_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||
net_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False))
|
||||
is_invalidated: bool = Field(default=False, nullable=False)
|
||||
invalidated_at: datetime | None = Field(
|
||||
default=None, sa_column=Column(DateTime(timezone=True), nullable=True)
|
||||
)
|
||||
replaced_by_trade_id: int | None = Field(
|
||||
default=None, foreign_key="trades.id", nullable=True
|
||||
)
|
||||
notes: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
cycle_id: int | None = Field(
|
||||
default=None, foreign_key="cycles.id", nullable=True, index=True
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user