continue on crud
Some checks failed
Backend CI / unit-test (push) Failing after 27s

This commit is contained in:
2025-09-14 21:01:12 +02:00
parent 5753ad3767
commit 2c22f20b48
4 changed files with 435 additions and 22 deletions

View File

@@ -9,6 +9,8 @@ from sqlmodel import Session, SQLModel
from trading_journal import crud, models from trading_journal import crud, models
# TODO: If needed, add failing flow tests, but now only add happy flow.
@pytest.fixture @pytest.fixture
def engine() -> Generator[Engine, None, None]: def engine() -> Generator[Engine, None, None]:
@@ -45,7 +47,7 @@ def make_cycle(session, user_id: int, friendly_name: str = "Test Cycle") -> int:
user_id=user_id, user_id=user_id,
friendly_name=friendly_name, friendly_name=friendly_name,
symbol="AAPL", symbol="AAPL",
underlying_currency="USD", underlying_currency=models.UnderlyingCurrency.USD,
status=models.CycleStatus.OPEN, status=models.CycleStatus.OPEN,
start_date=datetime.now().date(), start_date=datetime.now().date(),
) )
@@ -55,7 +57,40 @@ def make_cycle(session, user_id: int, friendly_name: str = "Test Cycle") -> int:
return cycle.id return cycle.id
def test_create_trade_success(session: Session): def make_trade(
session, user_id: int, cycle_id: int, friendly_name: str = "Test Trade"
) -> int:
trade = models.Trades(
user_id=user_id,
friendly_name=friendly_name,
symbol="AAPL",
underlying_currency=models.UnderlyingCurrency.USD,
trade_type=models.TradeType.LONG_SPOT,
trade_strategy=models.TradeStrategy.SPOT,
trade_date=datetime.now().date(),
trade_time_utc=datetime.now(),
quantity=10,
price_cents=15000,
gross_cash_flow_cents=-150000,
commission_cents=500,
net_cash_flow_cents=-150500,
cycle_id=cycle_id,
)
session.add(trade)
session.commit()
session.refresh(trade)
return trade.id
def make_trade_by_trade_data(session, trade_data: dict) -> int:
trade = models.Trades(**trade_data)
session.add(trade)
session.commit()
session.refresh(trade)
return trade.id
def test_create_trade_success_with_cycle(session: Session):
user_id = make_user(session) user_id = make_user(session)
cycle_id = make_cycle(session, user_id) cycle_id = make_cycle(session, user_id)
@@ -79,3 +114,228 @@ def test_create_trade_success(session: Session):
assert trade.id is not None assert trade.id is not None
assert trade.user_id == user_id assert trade.user_id == user_id
assert trade.cycle_id == cycle_id assert trade.cycle_id == cycle_id
session.refresh(trade)
actual_trade = session.get(models.Trades, trade.id)
assert actual_trade is not None
assert actual_trade.friendly_name == trade_data["friendly_name"]
assert actual_trade.symbol == trade_data["symbol"]
assert actual_trade.underlying_currency == trade_data["underlying_currency"]
assert actual_trade.trade_type == trade_data["trade_type"]
assert actual_trade.trade_strategy == trade_data["trade_strategy"]
assert actual_trade.quantity == trade_data["quantity"]
assert actual_trade.price_cents == trade_data["price_cents"]
assert actual_trade.gross_cash_flow_cents == trade_data["gross_cash_flow_cents"]
assert actual_trade.commission_cents == trade_data["commission_cents"]
assert actual_trade.net_cash_flow_cents == trade_data["net_cash_flow_cents"]
assert actual_trade.cycle_id == trade_data["cycle_id"]
def test_create_trade_with_auto_created_cycle(session: Session):
user_id = make_user(session)
trade_data = {
"user_id": user_id,
"friendly_name": "Test Trade with Auto Cycle",
"symbol": "AAPL",
"underlying_currency": "USD",
"trade_type": "LONG_SPOT",
"trade_strategy": "SPOT",
"trade_time_utc": datetime.now(),
"quantity": 5,
"price_cents": 15500,
}
trade = crud.create_trade(session, trade_data)
assert trade.id is not None
assert trade.user_id == user_id
assert trade.cycle_id is not None
session.refresh(trade)
actual_trade = session.get(models.Trades, trade.id)
assert actual_trade is not None
assert actual_trade.friendly_name == trade_data["friendly_name"]
assert actual_trade.symbol == trade_data["symbol"]
assert actual_trade.underlying_currency == trade_data["underlying_currency"]
assert actual_trade.trade_type == trade_data["trade_type"]
assert actual_trade.trade_strategy == trade_data["trade_strategy"]
assert actual_trade.quantity == trade_data["quantity"]
assert actual_trade.price_cents == trade_data["price_cents"]
assert actual_trade.cycle_id == trade.cycle_id
# Verify the auto-created cycle
auto_cycle = session.get(models.Cycles, trade.cycle_id)
assert auto_cycle is not None
assert auto_cycle.user_id == user_id
assert auto_cycle.symbol == trade_data["symbol"]
assert auto_cycle.underlying_currency == trade_data["underlying_currency"]
assert auto_cycle.status == models.CycleStatus.OPEN
assert auto_cycle.friendly_name.startswith("Auto-created Cycle by trade")
def test_create_trade_missing_required_fields(session: Session):
user_id = make_user(session)
base_trade_data = {
"user_id": user_id,
"friendly_name": "Incomplete Trade",
"symbol": "AAPL",
"underlying_currency": "USD",
"trade_type": "LONG_SPOT",
"trade_strategy": "SPOT",
"trade_time_utc": datetime.now(),
"quantity": 10,
"price_cents": 15000,
}
# Missing symbol
trade_data = base_trade_data.copy()
trade_data.pop("symbol", None)
with pytest.raises(ValueError) as excinfo:
crud.create_trade(session, trade_data)
assert "symbol is required" in str(excinfo.value)
# Missing underlying_currency
trade_data = base_trade_data.copy()
trade_data.pop("underlying_currency", None)
with pytest.raises(ValueError) as excinfo:
crud.create_trade(session, trade_data)
assert "underlying_currency is required" in str(excinfo.value)
# Missing trade_type
trade_data = base_trade_data.copy()
trade_data.pop("trade_type", None)
with pytest.raises(ValueError) as excinfo:
crud.create_trade(session, trade_data)
assert "trade_type is required" in str(excinfo.value)
# Missing trade_strategy
trade_data = base_trade_data.copy()
trade_data.pop("trade_strategy", None)
with pytest.raises(ValueError) as excinfo:
crud.create_trade(session, trade_data)
assert "trade_strategy is required" in str(excinfo.value)
# Missing quantity
trade_data = base_trade_data.copy()
trade_data.pop("quantity", None)
with pytest.raises(ValueError) as excinfo:
crud.create_trade(session, trade_data)
assert "quantity is required" in str(excinfo.value)
# Missing price_cents
trade_data = base_trade_data.copy()
trade_data.pop("price_cents", None)
with pytest.raises(ValueError) as excinfo:
crud.create_trade(session, trade_data)
assert "price_cents is required" in str(excinfo.value)
def test_get_trade_by_id(session: Session):
user_id = make_user(session)
cycle_id = make_cycle(session, user_id)
trade_data = {
"user_id": user_id,
"friendly_name": "Test Trade for Get",
"symbol": "AAPL",
"underlying_currency": models.UnderlyingCurrency.USD,
"trade_type": models.TradeType.LONG_SPOT,
"trade_strategy": models.TradeStrategy.SPOT,
"trade_date": datetime.now().date(),
"trade_time_utc": datetime.now(),
"quantity": 10,
"price_cents": 15000,
"gross_cash_flow_cents": -150000,
"commission_cents": 500,
"net_cash_flow_cents": -150500,
"cycle_id": cycle_id,
}
trade_id = make_trade_by_trade_data(session, trade_data)
trade = crud.get_trade_by_id(session, trade_id)
assert trade is not None
assert trade.id == trade_id
assert trade.friendly_name == trade_data["friendly_name"]
assert trade.symbol == trade_data["symbol"]
assert trade.underlying_currency == trade_data["underlying_currency"]
assert trade.trade_type == trade_data["trade_type"]
assert trade.trade_strategy == trade_data["trade_strategy"]
assert trade.quantity == trade_data["quantity"]
assert trade.price_cents == trade_data["price_cents"]
assert trade.gross_cash_flow_cents == trade_data["gross_cash_flow_cents"]
assert trade.commission_cents == trade_data["commission_cents"]
assert trade.net_cash_flow_cents == trade_data["net_cash_flow_cents"]
assert trade.cycle_id == trade_data["cycle_id"]
assert trade.trade_date == trade_data["trade_date"]
def test_get_trade_by_user_id_and_friendly_name(session: Session):
user_id = make_user(session)
cycle_id = make_cycle(session, user_id)
friendly_name = "Unique Trade Name"
trade_data = {
"user_id": user_id,
"friendly_name": friendly_name,
"symbol": "AAPL",
"underlying_currency": models.UnderlyingCurrency.USD,
"trade_type": models.TradeType.LONG_SPOT,
"trade_strategy": models.TradeStrategy.SPOT,
"trade_date": datetime.now().date(),
"trade_time_utc": datetime.now(),
"quantity": 10,
"price_cents": 15000,
"gross_cash_flow_cents": -150000,
"commission_cents": 500,
"net_cash_flow_cents": -150500,
"cycle_id": cycle_id,
}
make_trade_by_trade_data(session, trade_data)
trade = crud.get_trade_by_user_id_and_friendly_name(session, user_id, friendly_name)
assert trade is not None
assert trade.friendly_name == friendly_name
assert trade.user_id == user_id
def test_create_cycle(session: Session):
user_id = make_user(session)
cycle_data = {
"user_id": user_id,
"friendly_name": "My First Cycle",
"symbol": "GOOGL",
"underlying_currency": "USD",
"status": models.CycleStatus.OPEN,
"start_date": datetime.now().date(),
}
cycle = crud.create_cycle(session, cycle_data)
assert cycle.id is not None
assert cycle.user_id == user_id
assert cycle.friendly_name == cycle_data["friendly_name"]
assert cycle.symbol == cycle_data["symbol"]
assert cycle.underlying_currency == cycle_data["underlying_currency"]
assert cycle.status == cycle_data["status"]
assert cycle.start_date == cycle_data["start_date"]
session.refresh(cycle)
actual_cycle = session.get(models.Cycles, cycle.id)
assert actual_cycle is not None
assert actual_cycle.friendly_name == cycle_data["friendly_name"]
assert actual_cycle.symbol == cycle_data["symbol"]
assert actual_cycle.underlying_currency == cycle_data["underlying_currency"]
assert actual_cycle.status == cycle_data["status"]
assert actual_cycle.start_date == cycle_data["start_date"]
def test_create_user(session: Session):
user_data = {
"username": "newuser",
"password_hash": "newhashedpassword",
}
user = crud.create_user(session, user_data)
assert user.id is not None
assert user.username == user_data["username"]
assert user.password_hash == user_data["password_hash"]
session.refresh(user)
actual_user = session.get(models.Users, user.id)
assert actual_user is not None
assert actual_user.username == user_data["username"]
assert actual_user.password_hash == user_data["password_hash"]

View File

@@ -1,12 +1,13 @@
from datetime import datetime, timezone
from typing import Mapping from typing import Mapping
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlmodel import Session from sqlmodel import Session, select
from trading_journal import models from trading_journal import models
def _coerce_enum(enum_cls, value, field_name: str): def _check_enum(enum_cls, value, field_name: str):
if value is None: if value is None:
raise ValueError(f"{field_name} is required") raise ValueError(f"{field_name} is required")
# already an enum member # already an enum member
@@ -29,29 +30,66 @@ def create_trade(session: Session, trade_data: Mapping) -> models.Trades:
data = dict(trade_data) data = dict(trade_data)
allowed = {c.name for c in models.Trades.__table__.columns} allowed = {c.name for c in models.Trades.__table__.columns}
payload = {k: v for k, v in data.items() if k in allowed} payload = {k: v for k, v in data.items() if k in allowed}
if "symbol" not in payload:
raise ValueError("symbol is required")
if "underlying_currency" not in payload:
raise ValueError("underlying_currency is required")
payload["underlying_currency"] = _check_enum(
models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency"
)
if "trade_type" not in payload: if "trade_type" not in payload:
raise ValueError("trade_type is required") raise ValueError("trade_type is required")
payload["trade_type"] = _coerce_enum( payload["trade_type"] = _check_enum(
models.TradeType, payload["trade_type"], "trade_type" models.TradeType, payload["trade_type"], "trade_type"
) )
if "trade_strategy" not in payload: if "trade_strategy" not in payload:
raise ValueError("trade_strategy is required") raise ValueError("trade_strategy is required")
payload["trade_strategy"] = _coerce_enum( payload["trade_strategy"] = _check_enum(
models.TradeStrategy, payload["trade_strategy"], "trade_strategy" models.TradeStrategy, payload["trade_strategy"], "trade_strategy"
) )
now = datetime.now(timezone.utc)
payload.pop("trade_time_utc", None)
payload["trade_time_utc"] = now
if "trade_date" not in payload or payload.get("trade_date") is None:
payload["trade_date"] = payload["trade_time_utc"].date()
cycle_id = payload.get("cycle_id") cycle_id = payload.get("cycle_id")
user_id = payload.get("user_id") user_id = payload.get("user_id")
if "quantity" not in payload:
raise ValueError("quantity is required")
if "price_cents" not in payload:
raise ValueError("price_cents is required")
if "commission_cents" not in payload:
payload["commission_cents"] = 0
quantity: int = payload["quantity"]
price_cents: int = payload["price_cents"]
commission_cents: int = payload["commission_cents"]
if "gross_cash_flow_cents" not in payload:
payload["gross_cash_flow_cents"] = -quantity * price_cents
if "net_cash_flow_cents" not in payload:
payload["net_cash_flow_cents"] = (
payload["gross_cash_flow_cents"] - commission_cents
)
if cycle_id is None:
cycle_id = create_cycle(
session,
{
"user_id": user_id,
"symbol": payload["symbol"],
"underlying_currency": payload["underlying_currency"],
"friendly_name": "Auto-created Cycle by trade "
+ payload.get("friendly_name", ""),
"status": models.CycleStatus.OPEN,
"start_date": payload["trade_date"],
},
).id
payload["cycle_id"] = cycle_id
if cycle_id is not None: if cycle_id is not None:
cycle = session.get(models.Cycles, cycle_id) cycle = session.get(models.Cycles, cycle_id)
if cycle is None: if cycle is None:
pass # TODO: create a cycle with basic info here raise ValueError("cycle_id does not exist")
else: else:
if cycle.user_id != user_id: if cycle.user_id != user_id:
raise ValueError("cycle.user_id does not match trade.user_id") raise ValueError("cycle.user_id does not match trade.user_id")
else:
raise ValueError("trade must have a cycle_id.")
t = models.Trades(**payload) t = models.Trades(**payload)
session.add(t) session.add(t)
try: try:
@@ -61,3 +99,75 @@ def create_trade(session: Session, trade_data: Mapping) -> models.Trades:
raise ValueError("create_trade integrity error") from e raise ValueError("create_trade integrity error") from e
session.refresh(t) session.refresh(t)
return t return t
def get_trade_by_id(session: Session, trade_id: int) -> models.Trades | None:
return session.get(models.Trades, trade_id)
def get_trade_by_user_id_and_friendly_name(
session: Session, user_id: int, friendly_name: str
) -> models.Trades | None:
statement = select(models.Trades).where(
models.Trades.user_id == user_id,
models.Trades.friendly_name == friendly_name,
)
return session.exec(statement).first()
# Cycles
def create_cycle(session: Session, cycle_data: Mapping) -> models.Cycles:
if hasattr(cycle_data, "dict"):
data = cycle_data.dict(exclude_unset=True)
else:
data = dict(cycle_data)
allowed = {c.name for c in models.Cycles.__table__.columns}
payload = {k: v for k, v in data.items() if k in allowed}
if "user_id" not in payload:
raise ValueError("user_id is required")
if "symbol" not in payload:
raise ValueError("symbol is required")
if "underlying_currency" not in payload:
raise ValueError("underlying_currency is required")
payload["underlying_currency"] = _check_enum(
models.UnderlyingCurrency, payload["underlying_currency"], "underlying_currency"
)
if "status" not in payload:
raise ValueError("status is required")
payload["status"] = _check_enum(models.CycleStatus, payload["status"], "status")
if "start_date" not in payload:
raise ValueError("start_date is required")
c = models.Cycles(**payload)
session.add(c)
try:
session.flush()
except IntegrityError as e:
session.rollback()
raise ValueError("create_cycle integrity error") from e
session.refresh(c)
return c
# Users
def create_user(session: Session, user_data: Mapping) -> models.Users:
if hasattr(user_data, "dict"):
data = user_data.dict(exclude_unset=True)
else:
data = dict(user_data)
allowed = {c.name for c in models.Users.__table__.columns}
payload = {k: v for k, v in data.items() if k in allowed}
if "username" not in payload:
raise ValueError("username is required")
if "password_hash" not in payload:
raise ValueError("password_hash is required")
u = models.Users(**payload)
session.add(u)
try:
session.flush()
except IntegrityError as e:
session.rollback()
raise ValueError("create_user integrity error") from e
session.refresh(u)
return u

View File

@@ -1,8 +1,17 @@
from datetime import date, datetime # noqa: TC003 from datetime import date, datetime # noqa: TC003
from enum import Enum from enum import Enum
from sqlalchemy import Date, Text, UniqueConstraint from sqlmodel import (
from sqlmodel import Column, DateTime, Field, Relationship, SQLModel Column,
Date,
DateTime,
Field,
Integer,
Relationship,
SQLModel,
Text,
UniqueConstraint,
)
class TradeType(str, Enum): class TradeType(str, Enum):
@@ -36,6 +45,18 @@ class CycleStatus(str, Enum):
CLOSED = "CLOSED" CLOSED = "CLOSED"
class UnderlyingCurrency(str, Enum):
EUR = "EUR"
USD = "USD"
GBP = "GBP"
JPY = "JPY"
AUD = "AUD"
CAD = "CAD"
CHF = "CHF"
NZD = "NZD"
CNY = "CNY"
class FundingSource(str, Enum): class FundingSource(str, Enum):
CASH = "CASH" CASH = "CASH"
MARGIN = "MARGIN" MARGIN = "MARGIN"
@@ -57,19 +78,22 @@ class Trades(SQLModel, table=True):
default=None, sa_column=Column(Text, nullable=True) default=None, sa_column=Column(Text, nullable=True)
) )
symbol: str = Field(sa_column=Column(Text, nullable=False)) symbol: str = Field(sa_column=Column(Text, nullable=False))
underlying_currency: str = Field(sa_column=Column(Text, nullable=False)) underlying_currency: UnderlyingCurrency = Field(
sa_column=Column(Text, nullable=False)
)
trade_type: TradeType = Field(sa_column=Column(Text, nullable=False)) trade_type: TradeType = Field(sa_column=Column(Text, nullable=False))
trade_strategy: TradeStrategy = Field(sa_column=Column(Text, nullable=False)) trade_strategy: TradeStrategy = Field(sa_column=Column(Text, nullable=False))
trade_date: date = Field(sa_column=Column(Date, nullable=False))
trade_time_utc: datetime = Field( trade_time_utc: datetime = Field(
sa_column=Column(DateTime(timezone=True), nullable=False) sa_column=Column(DateTime(timezone=True), nullable=False)
) )
expiry_date: date | None = Field(default=None, nullable=True) expiry_date: date | None = Field(default=None, nullable=True)
strike_price_cents: int | None = Field(default=None, nullable=True) strike_price_cents: int | None = Field(default=None, nullable=True)
quantity: int quantity: int = Field(sa_column=Column(Integer, nullable=False))
price_cents: int price_cents: int = Field(sa_column=Column(Integer, nullable=False))
gross_cash_flow_cents: int gross_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False))
commission_cents: int commission_cents: int = Field(sa_column=Column(Integer, nullable=False))
net_cash_flow_cents: int net_cash_flow_cents: int = Field(sa_column=Column(Integer, nullable=False))
cycle_id: int | None = Field( cycle_id: int | None = Field(
default=None, foreign_key="cycles.id", nullable=True, index=True default=None, foreign_key="cycles.id", nullable=True, index=True
) )
@@ -90,7 +114,9 @@ class Cycles(SQLModel, table=True):
default=None, sa_column=Column(Text, nullable=True) default=None, sa_column=Column(Text, nullable=True)
) )
symbol: str = Field(sa_column=Column(Text, nullable=False)) symbol: str = Field(sa_column=Column(Text, nullable=False))
underlying_currency: str = Field(sa_column=Column(Text, nullable=False)) underlying_currency: UnderlyingCurrency = Field(
sa_column=Column(Text, nullable=False)
)
status: CycleStatus = Field(sa_column=Column(Text, nullable=False)) status: CycleStatus = Field(sa_column=Column(Text, nullable=False))
funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True)) funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True))
capital_exposure_cents: int | None = Field(default=None, nullable=True) capital_exposure_cents: int | None = Field(default=None, nullable=True)

View File

@@ -36,6 +36,18 @@ class CycleStatus(str, Enum):
CLOSED = "CLOSED" CLOSED = "CLOSED"
class UnderlyingCurrency(str, Enum):
EUR = "EUR"
USD = "USD"
GBP = "GBP"
JPY = "JPY"
AUD = "AUD"
CAD = "CAD"
CHF = "CHF"
NZD = "NZD"
CNY = "CNY"
class FundingSource(str, Enum): class FundingSource(str, Enum):
CASH = "CASH" CASH = "CASH"
MARGIN = "MARGIN" MARGIN = "MARGIN"
@@ -57,9 +69,12 @@ class Trades(SQLModel, table=True):
default=None, sa_column=Column(Text, nullable=True) default=None, sa_column=Column(Text, nullable=True)
) )
symbol: str = Field(sa_column=Column(Text, nullable=False)) symbol: str = Field(sa_column=Column(Text, nullable=False))
underlying_currency: str = Field(sa_column=Column(Text, nullable=False)) underlying_currency: UnderlyingCurrency = Field(
sa_column=Column(Text, nullable=False)
)
trade_type: TradeType = Field(sa_column=Column(Text, nullable=False)) trade_type: TradeType = Field(sa_column=Column(Text, nullable=False))
trade_strategy: TradeStrategy = Field(sa_column=Column(Text, nullable=False)) trade_strategy: TradeStrategy = Field(sa_column=Column(Text, nullable=False))
trade_date: date = Field(sa_column=Column(Date, nullable=False))
trade_time_utc: datetime = Field( trade_time_utc: datetime = Field(
sa_column=Column(DateTime(timezone=True), nullable=False) sa_column=Column(DateTime(timezone=True), nullable=False)
) )
@@ -90,7 +105,9 @@ class Cycles(SQLModel, table=True):
default=None, sa_column=Column(Text, nullable=True) default=None, sa_column=Column(Text, nullable=True)
) )
symbol: str = Field(sa_column=Column(Text, nullable=False)) symbol: str = Field(sa_column=Column(Text, nullable=False))
underlying_currency: str = Field(sa_column=Column(Text, nullable=False)) underlying_currency: UnderlyingCurrency = Field(
sa_column=Column(Text, nullable=False)
)
status: CycleStatus = Field(sa_column=Column(Text, nullable=False)) status: CycleStatus = Field(sa_column=Column(Text, nullable=False))
funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True)) funding_source: FundingSource = Field(sa_column=Column(Text, nullable=True))
capital_exposure_cents: int | None = Field(default=None, nullable=True) capital_exposure_cents: int | None = Field(default=None, nullable=True)