Files
trading-journal/backend/trading_journal/crud.py
Tianyu Liu 1d215c8032
Some checks failed
Backend CI / unit-test (push) Failing after 24s
wip crud
2025-09-14 17:03:39 +02:00

64 lines
2.1 KiB
Python

from typing import Mapping
from sqlalchemy.exc import IntegrityError
from sqlmodel import Session
from trading_journal import models
def _coerce_enum(enum_cls, value, field_name: str):
if value is None:
raise ValueError(f"{field_name} is required")
# already an enum member
if isinstance(value, enum_cls):
return value
# strict string match: must match exactly enum name or enum value (case-sensitive)
if isinstance(value, str):
for m in enum_cls:
if m.name == value or str(m.value) == value:
return m
allowed = [m.name for m in enum_cls]
raise ValueError(f"Invalid {field_name!s}: {value!r}. Allowed: {allowed}")
# Trades
def create_trade(session: Session, trade_data: Mapping) -> models.Trades:
    """Validate ``trade_data`` and stage a new ``Trades`` row in ``session``.

    The row is added to the session and flushed, then refreshed; this
    function does not commit.

    Args:
        session: The SQLModel session used for the lookup and the insert.
        trade_data: Either a mapping of column names to values, or an object
            exposing ``model_dump()`` / ``dict()`` (only explicitly-set
            fields are used). Keys that are not columns of ``models.Trades``
            are dropped.

    Returns:
        The flushed and refreshed ``models.Trades`` instance.

    Raises:
        ValueError: If ``trade_type`` or ``trade_strategy`` is missing or
            invalid, if ``cycle_id`` is missing, if the referenced cycle
            belongs to a different ``user_id``, or if the flush raises an
            ``IntegrityError`` (the session is rolled back first).
    """
    # Prefer Pydantic v2's ``model_dump``; ``dict`` is the deprecated v1
    # spelling and is kept only as a fallback for older models.
    if hasattr(trade_data, "model_dump"):
        data = trade_data.model_dump(exclude_unset=True)
    elif hasattr(trade_data, "dict"):
        data = trade_data.dict(exclude_unset=True)
    else:
        data = dict(trade_data)

    # Ignore anything that is not an actual column of the Trades table.
    column_names = {column.name for column in models.Trades.__table__.columns}
    payload = {key: val for key, val in data.items() if key in column_names}

    # Both enum-typed fields are mandatory; validated in this fixed order.
    required_enums = (
        ("trade_type", models.TradeType),
        ("trade_strategy", models.TradeStrategy),
    )
    for field_name, enum_cls in required_enums:
        if field_name not in payload:
            raise ValueError(f"{field_name} is required")
        payload[field_name] = _coerce_enum(
            enum_cls, payload[field_name], field_name
        )

    cycle_id = payload.get("cycle_id")
    if cycle_id is None:
        raise ValueError("trade must have a cycle_id.")

    cycle = session.get(models.Cycles, cycle_id)
    # TODO: create a cycle with basic info here when none exists yet.
    # NOTE(review): until then an unknown cycle_id is let through unchecked;
    # whether the flush below rejects it depends on the database enforcing
    # the foreign key -- confirm.
    if cycle is not None and cycle.user_id != payload.get("user_id"):
        raise ValueError("cycle.user_id does not match trade.user_id")

    trade = models.Trades(**payload)
    session.add(trade)
    try:
        session.flush()
    except IntegrityError as exc:
        session.rollback()
        raise ValueError("create_trade integrity error") from exc
    session.refresh(trade)
    return trade