feature/db #2

Merged
tliu93 merged 16 commits from feature/db into main 2025-09-18 14:28:18 +02:00
7 changed files with 253 additions and 78 deletions
Showing only changes of commit 479d5cd230 - Show all commits

4
backend/.gitignore vendored
View File

@@ -11,3 +11,7 @@ venv.bak/
__pycache__/ __pycache__/
.pytest_cache/ .pytest_cache/
*.db
*.db-shm
*.db-wal

View File

@@ -18,20 +18,59 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None:
connect_args={"check_same_thread": False}, connect_args={"check_same_thread": False},
poolclass=StaticPool, poolclass=StaticPool,
) )
# ensure target is the LATEST_VERSION we expect for the test
monkeypatch.setattr(db_migration, "LATEST_VERSION", 1) monkeypatch.setattr(db_migration, "LATEST_VERSION", 1)
# run real migrations (will import trading_journal.models_v1 inside _mig_0_1)
final_version = db_migration.run_migrations(engine) final_version = db_migration.run_migrations(engine)
assert final_version == 1 assert final_version == 1
# import snapshot models to validate schema expected_schema = {
from trading_journal import models_v1 "users": {
"id": ("INTEGER", 1, 1),
"username": ("TEXT", 1, 0),
"password_hash": ("TEXT", 1, 0),
"is_active": ("BOOLEAN", 1, 0),
},
"cycles": {
"id": ("INTEGER", 1, 1),
"user_id": ("INTEGER", 1, 0),
"friendly_name": ("TEXT", 0, 0),
"symbol": ("TEXT", 1, 0),
"underlying_currency": ("TEXT", 1, 0),
"status": ("TEXT", 1, 0),
"funding_source": ("TEXT", 1, 0),
"capital_exposure_cents": ("INTEGER", 1, 0),
"loan_amount_cents": ("INTEGER", 0, 0),
"loan_interest_rate_bps": ("INTEGER", 0, 0),
"start_date": ("DATE", 1, 0),
"end_date": ("DATE", 0, 0),
},
"trades": {
"id": ("INTEGER", 1, 1),
"user_id": ("INTEGER", 1, 0),
"friendly_name": ("TEXT", 0, 0),
"symbol": ("TEXT", 1, 0),
"underlying_currency": ("TEXT", 1, 0),
"trade_type": ("TEXT", 1, 0),
"trade_strategy": ("TEXT", 1, 0),
"trade_time_utc": ("DATETIME", 1, 0),
"expiry_date": ("DATE", 0, 0),
"strike_price_cents": ("INTEGER", 0, 0),
"quantity": ("INTEGER", 1, 0),
"price_cents": ("INTEGER", 1, 0),
"gross_cash_flow_cents": ("INTEGER", 1, 0),
"commission_cents": ("INTEGER", 1, 0),
"net_cash_flow_cents": ("INTEGER", 1, 0),
"cycle_id": ("INTEGER", 0, 0),
},
}
expected_tables = { expected_fks = {
"trades": models_v1.Trades.__table__, "trades": [
"cycles": models_v1.Cycles.__table__, {"table": "cycles", "from": "cycle_id", "to": "id"},
{"table": "users", "from": "user_id", "to": "id"},
],
"cycles": [
{"table": "users", "from": "user_id", "to": "id"},
],
} }
with engine.connect() as conn: with engine.connect() as conn:
@@ -40,8 +79,8 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None:
text("SELECT name FROM sqlite_master WHERE type='table'") text("SELECT name FROM sqlite_master WHERE type='table'")
).fetchall() ).fetchall()
found_tables = {r[0] for r in rows} found_tables = {r[0] for r in rows}
assert set(expected_tables.keys()).issubset(found_tables), ( assert set(expected_schema.keys()).issubset(found_tables), (
f"missing tables: {set(expected_tables.keys()) - found_tables}" f"missing tables: {set(expected_schema.keys()) - found_tables}"
) )
# check user_version # check user_version
@@ -49,29 +88,37 @@ def test_run_migrations_0_to_1(monkeypatch: pytest.MonkeyPatch) -> None:
assert uv is not None assert uv is not None
assert int(uv[0]) == 1 assert int(uv[0]) == 1
# validate columns and (base) types for each expected table # validate each table columns
dialect = conn.dialect for tbl_name, cols in expected_schema.items():
for tbl_name, table in expected_tables.items():
info_rows = conn.execute(text(f"PRAGMA table_info({tbl_name})")).fetchall() info_rows = conn.execute(text(f"PRAGMA table_info({tbl_name})")).fetchall()
# build mapping: column name -> declared type (upper) # map: name -> (type, notnull, pk)
actual_cols = {r[1]: (r[2] or "").upper() for r in info_rows} actual = {
for col in table.columns: r[1]: ((r[2] or "").upper(), int(r[3]), int(r[5])) for r in info_rows
assert col.name in actual_cols, ( }
f"column {col.name} missing in table {tbl_name}" for colname, (exp_type, exp_notnull, exp_pk) in cols.items():
assert colname in actual, f"{tbl_name}: missing column {colname}"
act_type, act_notnull, act_pk = actual[colname]
# compare base type (e.g. VARCHAR(13) -> VARCHAR)
if act_type:
act_base = _base_type_of(act_type)
else:
act_base = ""
assert exp_type in act_base or act_base in exp_type, (
f"type mismatch {tbl_name}.{colname}: expected {exp_type}, got {act_base}"
) )
# compile expected type against this dialect assert act_notnull == exp_notnull, (
try: f"notnull mismatch {tbl_name}.{colname}: expected {exp_notnull}, got {act_notnull}"
compiled = col.type.compile(
dialect=dialect
) # e.g. VARCHAR(13), DATETIME
except Exception:
compiled = str(col.type)
expected_base = _base_type_of(compiled)
actual_type = actual_cols[col.name]
actual_base = _base_type_of(actual_type) if actual_type else ""
# accept either direction (some dialect vs sqlite naming differences)
assert (expected_base in actual_base) or (
actual_base in expected_base
), (
f"type mismatch for {tbl_name}.{col.name}: expected {expected_base}, got {actual_base}"
) )
assert act_pk == exp_pk, (
f"pk mismatch {tbl_name}.{colname}: expected {exp_pk}, got {act_pk}"
)
for tbl_name, fks in expected_fks.items():
fk_rows = conn.execute(
text(f"PRAGMA foreign_key_list('{tbl_name}')")
).fetchall()
# fk_rows columns: (id, seq, table, from, to, on_update, on_delete, match)
actual_fk_list = [
{"table": r[2], "from": r[3], "to": r[4]} for r in fk_rows
]
for efk in fks:
assert efk in actual_fk_list, f"missing FK on {tbl_name}: {efk}"

View File

@@ -0,0 +1 @@

View File

@@ -15,22 +15,43 @@ if TYPE_CHECKING:
class Database: class Database:
def __init__(self, database_url: str | None = None, *, echo: bool = False, connect_args: dict | None = None) -> None: def __init__(
self,
database_url: str | None = None,
*,
echo: bool = False,
connect_args: dict | None = None,
) -> None:
self._database_url = database_url or "sqlite:///:memory:" self._database_url = database_url or "sqlite:///:memory:"
default_connect = {"check_same_thread": False, "timeout": 30} if self._database_url.startswith("sqlite") else {} default_connect = (
{"check_same_thread": False, "timeout": 30}
if self._database_url.startswith("sqlite")
else {}
)
merged_connect = {**default_connect, **(connect_args or {})} merged_connect = {**default_connect, **(connect_args or {})}
if self._database_url == "sqlite:///:memory:": if self._database_url == "sqlite:///:memory:":
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.warning("Using in-memory SQLite database; all data will be lost when the application stops.") logger.warning(
self._engine = create_engine(self._database_url, echo=echo, connect_args=merged_connect, poolclass=StaticPool) "Using in-memory SQLite database; all data will be lost when the application stops."
)
self._engine = create_engine(
self._database_url,
echo=echo,
connect_args=merged_connect,
poolclass=StaticPool,
)
else: else:
self._engine = create_engine(self._database_url, echo=echo, connect_args=merged_connect) self._engine = create_engine(
self._database_url, echo=echo, connect_args=merged_connect
)
if self._database_url.startswith("sqlite"): if self._database_url.startswith("sqlite"):
def _enable_sqlite_pragmas(dbapi_conn: DBAPIConnection, _connection_record: object) -> None: def _enable_sqlite_pragmas(
dbapi_conn: DBAPIConnection, _connection_record: object
) -> None:
try: try:
cur = dbapi_conn.cursor() cur = dbapi_conn.cursor()
cur.execute("PRAGMA journal_mode=WAL;") cur.execute("PRAGMA journal_mode=WAL;")
@@ -62,5 +83,10 @@ class Database:
self._engine.dispose() self._engine.dispose()
def create_database(database_url: str | None = None, *, echo: bool = False, connect_args: dict | None = None) -> Database: def create_database(
database_url: str | None = None,
*,
echo: bool = False,
connect_args: dict | None = None,
) -> Database:
return Database(database_url, echo=echo, connect_args=connect_args) return Database(database_url, echo=echo, connect_args=connect_args)

View File

@@ -18,9 +18,16 @@ def _mig_0_1(engine: Engine) -> None:
""" """
# Ensure all models are imported before this is called (import side-effect registers tables) # Ensure all models are imported before this is called (import side-effect registers tables)
# e.g. trading_journal.models is imported in the caller / app startup. # e.g. trading_journal.models is imported in the caller / app startup.
from trading_journal import models_v1 # noqa: PLC0415, F401 from trading_journal import models_v1
SQLModel.metadata.create_all(bind=engine) SQLModel.metadata.create_all(
bind=engine,
tables=[
models_v1.Trades.__table__,
models_v1.Cycles.__table__,
models_v1.Users.__table__,
],
)
# map current_version -> function that migrates from current_version -> current_version+1 # map current_version -> function that migrates from current_version -> current_version+1
@@ -51,7 +58,9 @@ def run_migrations(engine: Engine, target_version: int | None = None) -> int:
while cur_version < target: while cur_version < target:
fn = MIGRATIONS.get(cur_version) fn = MIGRATIONS.get(cur_version)
if fn is None: if fn is None:
raise RuntimeError(f"No migration from {cur_version} -> {cur_version + 1}") raise RuntimeError(
f"No migration from {cur_version} -> {cur_version + 1}"
)
# call migration with Engine (fn should use transactions) # call migration with Engine (fn should use transactions)
fn(engine) fn(engine)
_set_sqlite_user_version(conn, cur_version + 1) _set_sqlite_user_version(conn, cur_version + 1)

View File

@@ -3,8 +3,8 @@ from __future__ import annotations
from datetime import date, datetime # noqa: TC003 from datetime import date, datetime # noqa: TC003
from enum import Enum from enum import Enum
from sqlalchemy import Date, Text, UniqueConstraint
from sqlmodel import Column, DateTime, Field, Relationship, SQLModel from sqlmodel import Column, DateTime, Field, Relationship, SQLModel
from sqlmodel import Enum as SQLEnum
class TradeType(str, Enum): class TradeType(str, Enum):
@@ -12,6 +12,18 @@ class TradeType(str, Enum):
ASSIGNMENT = "ASSIGNMENT" ASSIGNMENT = "ASSIGNMENT"
SELL_CALL = "SELL_CALL" SELL_CALL = "SELL_CALL"
EXERCISE_CALL = "EXERCISE_CALL" EXERCISE_CALL = "EXERCISE_CALL"
LONG_SPOT = "LONG_SPOT"
CLOSE_LONG_SPOT = "CLOSE_LONG_SPOT"
SHORT_SPOT = "SHORT_SPOT"
CLOSE_SHORT_SPOT = "CLOSE_SHORT_SPOT"
LONG_CFD = "LONG_CFD"
CLOSE_LONG_CFD = "CLOSE_LONG_CFD"
SHORT_CFD = "SHORT_CFD"
CLOSE_SHORT_CFD = "CLOSE_SHORT_CFD"
LONG_OTHER = "LONG_OTHER"
CLOSE_LONG_OTHER = "CLOSE_LONG_OTHER"
SHORT_OTHER = "SHORT_OTHER"
CLOSE_SHORT_OTHER = "CLOSE_SHORT_OTHER"
class TradeStrategy(str, Enum): class TradeStrategy(str, Enum):
@@ -34,13 +46,25 @@ class FundingSource(str, Enum):
class Trades(SQLModel, table=True): class Trades(SQLModel, table=True):
__tablename__ = "trades" __tablename__ = "trades"
id: str | None = Field(default=None, primary_key=True) __table_args__ = (
user_id: str UniqueConstraint(
symbol: str "user_id", "friendly_name", name="uq_trades_user_friendly_name"
underlying_currency: str ),
trade_type: TradeType = Field(sa_column=Column(SQLEnum(TradeType, name="trade_type_enum"), nullable=False)) )
trade_strategy: TradeStrategy = Field(sa_column=Column(SQLEnum(TradeStrategy, name="trade_strategy_enum"), nullable=False))
trade_time_utc: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False)) id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
# allow null while user may omit friendly_name; uniqueness enforced per-user by constraint
friendly_name: str | None = Field(
default=None, sa_column=Column(Text, nullable=True)
)
symbol: str = Field(sa_column=Column(Text, nullable=False))
underlying_currency: str = Field(sa_column=Column(Text, nullable=False))
trade_type: TradeType = Field(sa_column=Column(Text, nullable=False))
trade_strategy: TradeStrategy = Field(sa_column=Column(Text, nullable=False))
trade_time_utc: datetime = Field(
sa_column=Column(DateTime(timezone=True), nullable=False)
)
expiry_date: date | None = Field(default=None, nullable=True) expiry_date: date | None = Field(default=None, nullable=True)
strike_price_cents: int | None = Field(default=None, nullable=True) strike_price_cents: int | None = Field(default=None, nullable=True)
quantity: int quantity: int
@@ -48,21 +72,41 @@ class Trades(SQLModel, table=True):
gross_cash_flow_cents: int gross_cash_flow_cents: int
commission_cents: int commission_cents: int
net_cash_flow_cents: int net_cash_flow_cents: int
cycle_id: str | None = Field(default=None, foreign_key="cycles.id", nullable=True) cycle_id: int | None = Field(
default=None, foreign_key="cycles.id", nullable=True, index=True
)
cycle: Cycles | None = Relationship(back_populates="trades") cycle: Cycles | None = Relationship(back_populates="trades")
class Cycles(SQLModel, table=True): class Cycles(SQLModel, table=True):
__tablename__ = "cycles" __tablename__ = "cycles"
id: str | None = Field(default=None, primary_key=True) __table_args__ = (
user_id: str UniqueConstraint(
symbol: str "user_id", "friendly_name", name="uq_cycles_user_friendly_name"
underlying_currency: str ),
start_date: date )
end_date: date | None = Field(default=None, nullable=True)
status: CycleStatus = Field(sa_column=Column(SQLEnum(CycleStatus, name="cycle_status_enum"), nullable=False)) id: int | None = Field(default=None, primary_key=True)
funding_source: FundingSource = Field(sa_column=Column(SQLEnum(FundingSource, name="funding_source_enum"), nullable=False)) user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
friendly_name: str | None = Field(
default=None, sa_column=Column(Text, nullable=True)
)
symbol: str = Field(sa_column=Column(Text, nullable=False))
underlying_currency: str = Field(sa_column=Column(Text, nullable=False))
status: CycleStatus = Field(sa_column=Column(Text, nullable=False))
funding_source: FundingSource = Field(sa_column=Column(Text, nullable=False))
capital_exposure_cents: int capital_exposure_cents: int
loan_amount_cents: int | None = Field(default=None, nullable=True) loan_amount_cents: int | None = Field(default=None, nullable=True)
loan_interest_rate_bps: int | None = Field(default=None, nullable=True) loan_interest_rate_bps: int | None = Field(default=None, nullable=True)
start_date: date = Field(sa_column=Column(Date, nullable=False))
end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True))
trades: list[Trades] = Relationship(back_populates="cycle") trades: list[Trades] = Relationship(back_populates="cycle")
class Users(SQLModel, table=True):
__tablename__ = "users"
id: int | None = Field(default=None, primary_key=True)
# unique=True already creates an index; no need to also set index=True
username: str = Field(sa_column=Column(Text, nullable=False, unique=True))
password_hash: str = Field(sa_column=Column(Text, nullable=False))
is_active: bool = Field(default=True, nullable=False)

View File

@@ -3,8 +3,8 @@ from __future__ import annotations
from datetime import date, datetime # noqa: TC003 from datetime import date, datetime # noqa: TC003
from enum import Enum from enum import Enum
from sqlalchemy import Date, Text, UniqueConstraint
from sqlmodel import Column, DateTime, Field, Relationship, SQLModel from sqlmodel import Column, DateTime, Field, Relationship, SQLModel
from sqlmodel import Enum as SQLEnum
class TradeType(str, Enum): class TradeType(str, Enum):
@@ -12,6 +12,18 @@ class TradeType(str, Enum):
ASSIGNMENT = "ASSIGNMENT" ASSIGNMENT = "ASSIGNMENT"
SELL_CALL = "SELL_CALL" SELL_CALL = "SELL_CALL"
EXERCISE_CALL = "EXERCISE_CALL" EXERCISE_CALL = "EXERCISE_CALL"
LONG_SPOT = "LONG_SPOT"
CLOSE_LONG_SPOT = "CLOSE_LONG_SPOT"
SHORT_SPOT = "SHORT_SPOT"
CLOSE_SHORT_SPOT = "CLOSE_SHORT_SPOT"
LONG_CFD = "LONG_CFD"
CLOSE_LONG_CFD = "CLOSE_LONG_CFD"
SHORT_CFD = "SHORT_CFD"
CLOSE_SHORT_CFD = "CLOSE_SHORT_CFD"
LONG_OTHER = "LONG_OTHER"
CLOSE_LONG_OTHER = "CLOSE_LONG_OTHER"
SHORT_OTHER = "SHORT_OTHER"
CLOSE_SHORT_OTHER = "CLOSE_SHORT_OTHER"
class TradeStrategy(str, Enum): class TradeStrategy(str, Enum):
@@ -34,13 +46,25 @@ class FundingSource(str, Enum):
class Trades(SQLModel, table=True): class Trades(SQLModel, table=True):
__tablename__ = "trades" __tablename__ = "trades"
id: str | None = Field(default=None, primary_key=True) __table_args__ = (
user_id: str UniqueConstraint(
symbol: str "user_id", "friendly_name", name="uq_trades_user_friendly_name"
underlying_currency: str ),
trade_type: TradeType = Field(sa_column=Column(SQLEnum(TradeType, name="trade_type_enum"), nullable=False)) )
trade_strategy: TradeStrategy = Field(sa_column=Column(SQLEnum(TradeStrategy, name="trade_strategy_enum"), nullable=False))
trade_time_utc: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False)) id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
# allow null while user may omit friendly_name; uniqueness enforced per-user by constraint
friendly_name: str | None = Field(
default=None, sa_column=Column(Text, nullable=True)
)
symbol: str = Field(sa_column=Column(Text, nullable=False))
underlying_currency: str = Field(sa_column=Column(Text, nullable=False))
trade_type: TradeType = Field(sa_column=Column(Text, nullable=False))
trade_strategy: TradeStrategy = Field(sa_column=Column(Text, nullable=False))
trade_time_utc: datetime = Field(
sa_column=Column(DateTime(timezone=True), nullable=False)
)
expiry_date: date | None = Field(default=None, nullable=True) expiry_date: date | None = Field(default=None, nullable=True)
strike_price_cents: int | None = Field(default=None, nullable=True) strike_price_cents: int | None = Field(default=None, nullable=True)
quantity: int quantity: int
@@ -48,21 +72,41 @@ class Trades(SQLModel, table=True):
gross_cash_flow_cents: int gross_cash_flow_cents: int
commission_cents: int commission_cents: int
net_cash_flow_cents: int net_cash_flow_cents: int
cycle_id: str | None = Field(default=None, foreign_key="cycles.id", nullable=True) cycle_id: int | None = Field(
default=None, foreign_key="cycles.id", nullable=True, index=True
)
cycle: Cycles | None = Relationship(back_populates="trades") cycle: Cycles | None = Relationship(back_populates="trades")
class Cycles(SQLModel, table=True): class Cycles(SQLModel, table=True):
__tablename__ = "cycles" __tablename__ = "cycles"
id: str | None = Field(default=None, primary_key=True) __table_args__ = (
user_id: str UniqueConstraint(
symbol: str "user_id", "friendly_name", name="uq_cycles_user_friendly_name"
underlying_currency: str ),
start_date: date )
end_date: date | None = Field(default=None, nullable=True)
status: CycleStatus = Field(sa_column=Column(SQLEnum(CycleStatus, name="cycle_status_enum"), nullable=False)) id: int | None = Field(default=None, primary_key=True)
funding_source: FundingSource = Field(sa_column=Column(SQLEnum(FundingSource, name="funding_source_enum"), nullable=False)) user_id: int = Field(foreign_key="users.id", nullable=False, index=True)
friendly_name: str | None = Field(
default=None, sa_column=Column(Text, nullable=True)
)
symbol: str = Field(sa_column=Column(Text, nullable=False))
underlying_currency: str = Field(sa_column=Column(Text, nullable=False))
status: CycleStatus = Field(sa_column=Column(Text, nullable=False))
funding_source: FundingSource = Field(sa_column=Column(Text, nullable=False))
capital_exposure_cents: int capital_exposure_cents: int
loan_amount_cents: int | None = Field(default=None, nullable=True) loan_amount_cents: int | None = Field(default=None, nullable=True)
loan_interest_rate_bps: int | None = Field(default=None, nullable=True) loan_interest_rate_bps: int | None = Field(default=None, nullable=True)
start_date: date = Field(sa_column=Column(Date, nullable=False))
end_date: date | None = Field(default=None, sa_column=Column(Date, nullable=True))
trades: list[Trades] = Relationship(back_populates="cycle") trades: list[Trades] = Relationship(back_populates="cycle")
class Users(SQLModel, table=True):
__tablename__ = "users"
id: int | None = Field(default=None, primary_key=True)
# unique=True already creates an index; no need to also set index=True
username: str = Field(sa_column=Column(Text, nullable=False, unique=True))
password_hash: str = Field(sa_column=Column(Text, nullable=False))
is_active: bool = Field(default=True, nullable=False)