feature/db #2

Merged
tliu93 merged 16 commits from feature/db into main 2025-09-18 14:28:18 +02:00
5 changed files with 148 additions and 14 deletions
Showing only changes of commit 1d215c8032 - Show all commits

View File

@@ -0,0 +1,76 @@
from collections.abc import Generator
from datetime import datetime
import pytest
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel
from trading_journal import crud, models
@pytest.fixture
def engine() -> Engine:
e = create_engine(
"sqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(e)
return e
@pytest.fixture
def session(engine: Engine) -> Generator[Session, None, None]:
with Session(engine) as s:
yield s
def make_user(session: Session, username: str = "testuser") -> int:
user = models.Users(username=username, password_hash="hashedpassword")
session.add(user)
session.commit()
session.refresh(user)
return user.id
def make_cycle(session, user_id: int, friendly_name: str = "Test Cycle") -> int:
cycle = models.Cycles(
user_id=user_id,
friendly_name=friendly_name,
symbol="AAPL",
underlying_currency="USD",
status="open",
start_date=datetime.now().date(),
)
session.add(cycle)
session.commit()
session.refresh(cycle)
return cycle.id
def test_create_trade_success(session: Session):
user_id = make_user(session)
cycle_id = make_cycle(session, user_id)
trade_data = {
"user_id": user_id,
"friendly_name": "Test Trade",
"symbol": "AAPL",
"underlying_currency": "USD",
"trade_type": "LONG_SPOT",
"trade_strategy": "SPOT",
"trade_time_utc": datetime.now(),
"quantity": 10,
"price_cents": 15000,
"gross_cash_flow_cents": -150000,
"commission_cents": 500,
"net_cash_flow_cents": -150500,
"cycle_id": cycle_id,
}
trade = crud.create_trade(session, trade_data)
assert trade.id is not None
assert trade.user_id == user_id
assert trade.cycle_id == cycle_id

View File

@@ -36,8 +36,8 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None:
"symbol": ("TEXT", 1, 0), "symbol": ("TEXT", 1, 0),
"underlying_currency": ("TEXT", 1, 0), "underlying_currency": ("TEXT", 1, 0),
"status": ("TEXT", 1, 0), "status": ("TEXT", 1, 0),
"funding_source": ("TEXT", 1, 0), "funding_source": ("TEXT", 0, 0),
"capital_exposure_cents": ("INTEGER", 1, 0), "capital_exposure_cents": ("INTEGER", 0, 0),
"loan_amount_cents": ("INTEGER", 0, 0), "loan_amount_cents": ("INTEGER", 0, 0),
"loan_interest_rate_bps": ("INTEGER", 0, 0), "loan_interest_rate_bps": ("INTEGER", 0, 0),
"start_date": ("DATE", 1, 0), "start_date": ("DATE", 1, 0),

View File

@@ -1 +1,63 @@
from typing import Mapping
from sqlalchemy.exc import IntegrityError
from sqlmodel import Session
from trading_journal import models
def _coerce_enum(enum_cls, value, field_name: str):
if value is None:
raise ValueError(f"{field_name} is required")
# already an enum member
if isinstance(value, enum_cls):
return value
# strict string match: must match exactly enum name or enum value (case-sensitive)
if isinstance(value, str):
for m in enum_cls:
if m.name == value or str(m.value) == value:
return m
allowed = [m.name for m in enum_cls]
raise ValueError(f"Invalid {field_name!s}: {value!r}. Allowed: {allowed}")
# Trades
def create_trade(session: Session, trade_data: Mapping) -> models.Trades:
if hasattr(trade_data, "dict"):
data = trade_data.dict(exclude_unset=True)
else:
data = dict(trade_data)
allowed = {c.name for c in models.Trades.__table__.columns}
payload = {k: v for k, v in data.items() if k in allowed}
if "trade_type" not in payload:
raise ValueError("trade_type is required")
payload["trade_type"] = _coerce_enum(
models.TradeType, payload["trade_type"], "trade_type"
)
if "trade_strategy" not in payload:
raise ValueError("trade_strategy is required")
payload["trade_strategy"] = _coerce_enum(
models.TradeStrategy, payload["trade_strategy"], "trade_strategy"
)
cycle_id = payload.get("cycle_id")
user_id = payload.get("user_id")
if cycle_id is not None:
cycle = session.get(models.Cycles, cycle_id)
if cycle is None:
pass # TODO: create a cycle with basic info here
else:
if cycle.user_id != user_id:
raise ValueError("cycle.user_id does not match trade.user_id")
else:
raise ValueError("trade must have a cycle_id.")
t = models.Trades(**payload)
session.add(t)
try:
session.flush()
except IntegrityError as e:
session.rollback()
raise ValueError("create_trade integrity error") from e
session.refresh(t)
return t

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from datetime import date, datetime # noqa: TC003 from datetime import date, datetime # noqa: TC003
from enum import Enum from enum import Enum
@@ -75,7 +73,7 @@ class Trades(SQLModel, table=True):
cycle_id: int | None = Field( cycle_id: int | None = Field(
default=None, foreign_key="cycles.id", nullable=True, index=True default=None, foreign_key="cycles.id", nullable=True, index=True
) )
cycle: Cycles | None = Relationship(back_populates="trades") cycle: "Cycles" = Relationship(back_populates="trades")
class Cycles(SQLModel, table=True): class Cycles(SQLModel, table=True):
@@ -94,13 +92,13 @@ class Cycles(SQLModel, table=True):
symbol: str = Field(sa_column=Column(Text, nullable=False)) symbol: str = Field(sa_column=Column(Text, nullable=False))
underlying_currency: str = Field(sa_column=Column(Text, nullable=False)) underlying_currency: str = Field(sa_column=Column(Text, nullable=False))
status: CycleStatus = Field(sa_column=Column(Text, nullable=False)) status: CycleStatus = Field(sa_column=Column(Text, nullable=False))
funding_source: FundingSource = Field(sa_column=Column(Text, nullable=False)) funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True))
capital_exposure_cents: int capital_exposure_cents: int | None = Field(default=None, nullable=True)
loan_amount_cents: int | None = Field(default=None, nullable=True) loan_amount_cents: int | None = Field(default=None, nullable=True)
loan_interest_rate_bps: int | None = Field(default=None, nullable=True) loan_interest_rate_bps: int | None = Field(default=None, nullable=True)
start_date: date = Field(sa_column=Column(Date, nullable=False)) start_date: date = Field(sa_column=Column(Date, nullable=False))
end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True)) end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True))
trades: list[Trades] = Relationship(back_populates="cycle") trades: list["Trades"] = Relationship(back_populates="cycle")
class Users(SQLModel, table=True): class Users(SQLModel, table=True):

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from datetime import date, datetime # noqa: TC003 from datetime import date, datetime # noqa: TC003
from enum import Enum from enum import Enum
@@ -75,7 +73,7 @@ class Trades(SQLModel, table=True):
cycle_id: int | None = Field( cycle_id: int | None = Field(
default=None, foreign_key="cycles.id", nullable=True, index=True default=None, foreign_key="cycles.id", nullable=True, index=True
) )
cycle: Cycles | None = Relationship(back_populates="trades") cycle: "Cycles" = Relationship(back_populates="trades")
class Cycles(SQLModel, table=True): class Cycles(SQLModel, table=True):
@@ -94,13 +92,13 @@ class Cycles(SQLModel, table=True):
symbol: str = Field(sa_column=Column(Text, nullable=False)) symbol: str = Field(sa_column=Column(Text, nullable=False))
underlying_currency: str = Field(sa_column=Column(Text, nullable=False)) underlying_currency: str = Field(sa_column=Column(Text, nullable=False))
status: CycleStatus = Field(sa_column=Column(Text, nullable=False)) status: CycleStatus = Field(sa_column=Column(Text, nullable=False))
funding_source: FundingSource = Field(sa_column=Column(Text, nullable=False)) funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True))
capital_exposure_cents: int capital_exposure_cents: int | None = Field(default=None, nullable=True)
loan_amount_cents: int | None = Field(default=None, nullable=True) loan_amount_cents: int | None = Field(default=None, nullable=True)
loan_interest_rate_bps: int | None = Field(default=None, nullable=True) loan_interest_rate_bps: int | None = Field(default=None, nullable=True)
start_date: date = Field(sa_column=Column(Date, nullable=False)) start_date: date = Field(sa_column=Column(Date, nullable=False))
end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True)) end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True))
trades: list[Trades] = Relationship(back_populates="cycle") trades: list["Trades"] = Relationship(back_populates="cycle")
class Users(SQLModel, table=True): class Users(SQLModel, table=True):