This commit is contained in:
76
backend/tests/test_crud.py
Normal file
76
backend/tests/test_crud.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
from collections.abc import Generator
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
from sqlmodel import Session, SQLModel
|
||||||
|
|
||||||
|
from trading_journal import crud, models
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def engine() -> Engine:
|
||||||
|
e = create_engine(
|
||||||
|
"sqlite:///:memory:",
|
||||||
|
connect_args={"check_same_thread": False},
|
||||||
|
poolclass=StaticPool,
|
||||||
|
)
|
||||||
|
SQLModel.metadata.create_all(e)
|
||||||
|
return e
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def session(engine: Engine) -> Generator[Session, None, None]:
|
||||||
|
with Session(engine) as s:
|
||||||
|
yield s
|
||||||
|
|
||||||
|
|
||||||
|
def make_user(session: Session, username: str = "testuser") -> int:
|
||||||
|
user = models.Users(username=username, password_hash="hashedpassword")
|
||||||
|
session.add(user)
|
||||||
|
session.commit()
|
||||||
|
session.refresh(user)
|
||||||
|
return user.id
|
||||||
|
|
||||||
|
|
||||||
|
def make_cycle(session, user_id: int, friendly_name: str = "Test Cycle") -> int:
|
||||||
|
cycle = models.Cycles(
|
||||||
|
user_id=user_id,
|
||||||
|
friendly_name=friendly_name,
|
||||||
|
symbol="AAPL",
|
||||||
|
underlying_currency="USD",
|
||||||
|
status="open",
|
||||||
|
start_date=datetime.now().date(),
|
||||||
|
)
|
||||||
|
session.add(cycle)
|
||||||
|
session.commit()
|
||||||
|
session.refresh(cycle)
|
||||||
|
return cycle.id
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_trade_success(session: Session):
|
||||||
|
user_id = make_user(session)
|
||||||
|
cycle_id = make_cycle(session, user_id)
|
||||||
|
|
||||||
|
trade_data = {
|
||||||
|
"user_id": user_id,
|
||||||
|
"friendly_name": "Test Trade",
|
||||||
|
"symbol": "AAPL",
|
||||||
|
"underlying_currency": "USD",
|
||||||
|
"trade_type": "LONG_SPOT",
|
||||||
|
"trade_strategy": "SPOT",
|
||||||
|
"trade_time_utc": datetime.now(),
|
||||||
|
"quantity": 10,
|
||||||
|
"price_cents": 15000,
|
||||||
|
"gross_cash_flow_cents": -150000,
|
||||||
|
"commission_cents": 500,
|
||||||
|
"net_cash_flow_cents": -150500,
|
||||||
|
"cycle_id": cycle_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
trade = crud.create_trade(session, trade_data)
|
||||||
|
assert trade.id is not None
|
||||||
|
assert trade.user_id == user_id
|
||||||
|
assert trade.cycle_id == cycle_id
|
||||||
@@ -36,8 +36,8 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
"symbol": ("TEXT", 1, 0),
|
"symbol": ("TEXT", 1, 0),
|
||||||
"underlying_currency": ("TEXT", 1, 0),
|
"underlying_currency": ("TEXT", 1, 0),
|
||||||
"status": ("TEXT", 1, 0),
|
"status": ("TEXT", 1, 0),
|
||||||
"funding_source": ("TEXT", 1, 0),
|
"funding_source": ("TEXT", 0, 0),
|
||||||
"capital_exposure_cents": ("INTEGER", 1, 0),
|
"capital_exposure_cents": ("INTEGER", 0, 0),
|
||||||
"loan_amount_cents": ("INTEGER", 0, 0),
|
"loan_amount_cents": ("INTEGER", 0, 0),
|
||||||
"loan_interest_rate_bps": ("INTEGER", 0, 0),
|
"loan_interest_rate_bps": ("INTEGER", 0, 0),
|
||||||
"start_date": ("DATE", 1, 0),
|
"start_date": ("DATE", 1, 0),
|
||||||
|
|||||||
@@ -1 +1,63 @@
|
|||||||
|
from typing import Mapping
|
||||||
|
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from sqlmodel import Session
|
||||||
|
|
||||||
|
from trading_journal import models
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_enum(enum_cls, value, field_name: str):
|
||||||
|
if value is None:
|
||||||
|
raise ValueError(f"{field_name} is required")
|
||||||
|
# already an enum member
|
||||||
|
if isinstance(value, enum_cls):
|
||||||
|
return value
|
||||||
|
# strict string match: must match exactly enum name or enum value (case-sensitive)
|
||||||
|
if isinstance(value, str):
|
||||||
|
for m in enum_cls:
|
||||||
|
if m.name == value or str(m.value) == value:
|
||||||
|
return m
|
||||||
|
allowed = [m.name for m in enum_cls]
|
||||||
|
raise ValueError(f"Invalid {field_name!s}: {value!r}. Allowed: {allowed}")
|
||||||
|
|
||||||
|
|
||||||
|
# Trades
|
||||||
|
def create_trade(session: Session, trade_data: Mapping) -> models.Trades:
|
||||||
|
if hasattr(trade_data, "dict"):
|
||||||
|
data = trade_data.dict(exclude_unset=True)
|
||||||
|
else:
|
||||||
|
data = dict(trade_data)
|
||||||
|
allowed = {c.name for c in models.Trades.__table__.columns}
|
||||||
|
payload = {k: v for k, v in data.items() if k in allowed}
|
||||||
|
if "trade_type" not in payload:
|
||||||
|
raise ValueError("trade_type is required")
|
||||||
|
payload["trade_type"] = _coerce_enum(
|
||||||
|
models.TradeType, payload["trade_type"], "trade_type"
|
||||||
|
)
|
||||||
|
|
||||||
|
if "trade_strategy" not in payload:
|
||||||
|
raise ValueError("trade_strategy is required")
|
||||||
|
payload["trade_strategy"] = _coerce_enum(
|
||||||
|
models.TradeStrategy, payload["trade_strategy"], "trade_strategy"
|
||||||
|
)
|
||||||
|
cycle_id = payload.get("cycle_id")
|
||||||
|
user_id = payload.get("user_id")
|
||||||
|
|
||||||
|
if cycle_id is not None:
|
||||||
|
cycle = session.get(models.Cycles, cycle_id)
|
||||||
|
if cycle is None:
|
||||||
|
pass # TODO: create a cycle with basic info here
|
||||||
|
else:
|
||||||
|
if cycle.user_id != user_id:
|
||||||
|
raise ValueError("cycle.user_id does not match trade.user_id")
|
||||||
|
else:
|
||||||
|
raise ValueError("trade must have a cycle_id.")
|
||||||
|
t = models.Trades(**payload)
|
||||||
|
session.add(t)
|
||||||
|
try:
|
||||||
|
session.flush()
|
||||||
|
except IntegrityError as e:
|
||||||
|
session.rollback()
|
||||||
|
raise ValueError("create_trade integrity error") from e
|
||||||
|
session.refresh(t)
|
||||||
|
return t
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from datetime import date, datetime # noqa: TC003
|
from datetime import date, datetime # noqa: TC003
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
@@ -75,7 +73,7 @@ class Trades(SQLModel, table=True):
|
|||||||
cycle_id: int | None = Field(
|
cycle_id: int | None = Field(
|
||||||
default=None, foreign_key="cycles.id", nullable=True, index=True
|
default=None, foreign_key="cycles.id", nullable=True, index=True
|
||||||
)
|
)
|
||||||
cycle: Cycles | None = Relationship(back_populates="trades")
|
cycle: "Cycles" = Relationship(back_populates="trades")
|
||||||
|
|
||||||
|
|
||||||
class Cycles(SQLModel, table=True):
|
class Cycles(SQLModel, table=True):
|
||||||
@@ -94,13 +92,13 @@ class Cycles(SQLModel, table=True):
|
|||||||
symbol: str = Field(sa_column=Column(Text, nullable=False))
|
symbol: str = Field(sa_column=Column(Text, nullable=False))
|
||||||
underlying_currency: str = Field(sa_column=Column(Text, nullable=False))
|
underlying_currency: str = Field(sa_column=Column(Text, nullable=False))
|
||||||
status: CycleStatus = Field(sa_column=Column(Text, nullable=False))
|
status: CycleStatus = Field(sa_column=Column(Text, nullable=False))
|
||||||
funding_source: FundingSource = Field(sa_column=Column(Text, nullable=False))
|
funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True))
|
||||||
capital_exposure_cents: int
|
capital_exposure_cents: int | None = Field(default=None, nullable=True)
|
||||||
loan_amount_cents: int | None = Field(default=None, nullable=True)
|
loan_amount_cents: int | None = Field(default=None, nullable=True)
|
||||||
loan_interest_rate_bps: int | None = Field(default=None, nullable=True)
|
loan_interest_rate_bps: int | None = Field(default=None, nullable=True)
|
||||||
start_date: date = Field(sa_column=Column(Date, nullable=False))
|
start_date: date = Field(sa_column=Column(Date, nullable=False))
|
||||||
end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True))
|
end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True))
|
||||||
trades: list[Trades] = Relationship(back_populates="cycle")
|
trades: list["Trades"] = Relationship(back_populates="cycle")
|
||||||
|
|
||||||
|
|
||||||
class Users(SQLModel, table=True):
|
class Users(SQLModel, table=True):
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from datetime import date, datetime # noqa: TC003
|
from datetime import date, datetime # noqa: TC003
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
@@ -75,7 +73,7 @@ class Trades(SQLModel, table=True):
|
|||||||
cycle_id: int | None = Field(
|
cycle_id: int | None = Field(
|
||||||
default=None, foreign_key="cycles.id", nullable=True, index=True
|
default=None, foreign_key="cycles.id", nullable=True, index=True
|
||||||
)
|
)
|
||||||
cycle: Cycles | None = Relationship(back_populates="trades")
|
cycle: "Cycles" = Relationship(back_populates="trades")
|
||||||
|
|
||||||
|
|
||||||
class Cycles(SQLModel, table=True):
|
class Cycles(SQLModel, table=True):
|
||||||
@@ -94,13 +92,13 @@ class Cycles(SQLModel, table=True):
|
|||||||
symbol: str = Field(sa_column=Column(Text, nullable=False))
|
symbol: str = Field(sa_column=Column(Text, nullable=False))
|
||||||
underlying_currency: str = Field(sa_column=Column(Text, nullable=False))
|
underlying_currency: str = Field(sa_column=Column(Text, nullable=False))
|
||||||
status: CycleStatus = Field(sa_column=Column(Text, nullable=False))
|
status: CycleStatus = Field(sa_column=Column(Text, nullable=False))
|
||||||
funding_source: FundingSource = Field(sa_column=Column(Text, nullable=False))
|
funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True))
|
||||||
capital_exposure_cents: int
|
capital_exposure_cents: int | None = Field(default=None, nullable=True)
|
||||||
loan_amount_cents: int | None = Field(default=None, nullable=True)
|
loan_amount_cents: int | None = Field(default=None, nullable=True)
|
||||||
loan_interest_rate_bps: int | None = Field(default=None, nullable=True)
|
loan_interest_rate_bps: int | None = Field(default=None, nullable=True)
|
||||||
start_date: date = Field(sa_column=Column(Date, nullable=False))
|
start_date: date = Field(sa_column=Column(Date, nullable=False))
|
||||||
end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True))
|
end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True))
|
||||||
trades: list[Trades] = Relationship(back_populates="cycle")
|
trades: list["Trades"] = Relationship(back_populates="cycle")
|
||||||
|
|
||||||
|
|
||||||
class Users(SQLModel, table=True):
|
class Users(SQLModel, table=True):
|
||||||
|
|||||||
Reference in New Issue
Block a user