another wip
This commit is contained in:
65
backend/trading_journal/db.py
Normal file
65
backend/trading_journal/db.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
|
||||
from sqlalchemy.engine import Connection
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self, database_url: str | None = None, *, echo: bool = False, connect_args: dict | None = None) -> None:
|
||||
self._database_url = database_url or "sqlite:///:memory:"
|
||||
|
||||
default_connect = {"check_same_thread": False, "timeout": 30} if self._database_url.startswith("sqlite") else {}
|
||||
merged_connect = {**default_connect, **(connect_args or {})}
|
||||
|
||||
if self._database_url == "sqlite:///:memory:":
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning("Using in-memory SQLite database; all data will be lost when the application stops.")
|
||||
self._engine = create_engine(self._database_url, echo=echo, connect_args=merged_connect, poolclass=StaticPool)
|
||||
else:
|
||||
self._engine = create_engine(self._database_url, echo=echo, connect_args=merged_connect)
|
||||
|
||||
if self._database_url.startswith("sqlite"):
|
||||
|
||||
def _enable_sqlite_pragmas(dbapi_conn: Connection, _connection_record: object) -> None:
|
||||
try:
|
||||
cur = dbapi_conn.cursor()
|
||||
cur.execute("PRAGMA journal_mode=WAL;")
|
||||
cur.execute("PRAGMA synchronous=NORMAL;")
|
||||
cur.execute("PRAGMA foreign_keys=ON;")
|
||||
cur.execute("PRAGMA busy_timeout=30000;")
|
||||
cur.close()
|
||||
except Exception:
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.exception("Failed to set sqlite pragmas on new connection: ")
|
||||
|
||||
event.listen(self._engine, "connect", _enable_sqlite_pragmas)
|
||||
|
||||
def init_db(self) -> None:
|
||||
pass
|
||||
|
||||
def get_session(self) -> Generator[Session, None, None]:
|
||||
session = Session(self._engine)
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def dispose(self) -> None:
|
||||
self._engine.dispose()
|
||||
|
||||
|
||||
def create_database(database_url: str | None = None, *, echo: bool = False, connect_args: dict | None = None) -> Database:
|
||||
return Database(database_url, echo=echo, connect_args=connect_args)
|
||||
76
backend/trading_journal/db_migration.py
Normal file
76
backend/trading_journal/db_migration.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
# 最新 schema 版本号
|
||||
LATEST_VERSION = 1
|
||||
|
||||
|
||||
def _mig_0_1(engine: Engine) -> None:
|
||||
"""
|
||||
Initial schema: create all tables from SQLModel models.
|
||||
Safe to call on an empty DB; idempotent for missing tables.
|
||||
"""
|
||||
# Ensure all models are imported before this is called (import side-effect registers tables)
|
||||
# e.g. trading_journal.models is imported in the caller / app startup.
|
||||
SQLModel.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
# map current_version -> function that migrates from current_version -> current_version+1
|
||||
MIGRATIONS: dict[int, Callable[[Engine], None]] = {
|
||||
0: _mig_0_1,
|
||||
}
|
||||
|
||||
|
||||
def _get_sqlite_user_version(conn) -> int:
|
||||
row = conn.execute(text("PRAGMA user_version")).fetchone()
|
||||
return int(row[0]) if row and row[0] is not None else 0
|
||||
|
||||
|
||||
def _set_sqlite_user_version(conn, v: int) -> None:
|
||||
conn.execute(text(f"PRAGMA user_version = {int(v)}"))
|
||||
|
||||
|
||||
def run_migrations(engine: Engine, target_version: int | None = None) -> int:
|
||||
"""
|
||||
Run migrations up to target_version (or LATEST_VERSION).
|
||||
Returns final applied version.
|
||||
"""
|
||||
target = target_version or LATEST_VERSION
|
||||
with engine.begin() as conn:
|
||||
driver = conn.engine.name.lower()
|
||||
if driver == "sqlite":
|
||||
cur_version = _get_sqlite_user_version(conn)
|
||||
while cur_version < target:
|
||||
fn = MIGRATIONS.get(cur_version)
|
||||
if fn is None:
|
||||
raise RuntimeError(f"No migration from {cur_version} -> {cur_version + 1}")
|
||||
# call migration with Engine (fn should use transactions)
|
||||
fn(engine)
|
||||
_set_sqlite_user_version(conn, cur_version + 1)
|
||||
cur_version += 1
|
||||
return cur_version
|
||||
else:
|
||||
# generic migrations table for non-sqlite
|
||||
conn.execute(
|
||||
text("""
|
||||
CREATE TABLE IF NOT EXISTS migrations (
|
||||
version INTEGER PRIMARY KEY,
|
||||
applied_at TEXT DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
)
|
||||
row = conn.execute(text("SELECT MAX(version) FROM migrations")).fetchone()
|
||||
cur_version = int(row[0]) if row and row[0] is not None else 0
|
||||
while cur_version < target:
|
||||
fn = MIGRATIONS.get(cur_version)
|
||||
if fn is None:
|
||||
raise RuntimeError(f"No migration from {cur_version} -> {cur_version + 1}")
|
||||
fn(engine)
|
||||
conn.execute(text("INSERT INTO migrations(version) VALUES (:v)"), {"v": cur_version + 1})
|
||||
cur_version += 1
|
||||
return cur_version
|
||||
@@ -1,14 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime # noqa: TC003
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlmodel import Column, DateTime, Field, Relationship, SQLModel
|
||||
from sqlmodel import Enum as SQLEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import date, datetime
|
||||
|
||||
|
||||
class TradeType(str, Enum):
|
||||
SELL_PUT = "SELL_PUT"
|
||||
@@ -30,12 +27,12 @@ class FundingSource(str, Enum):
|
||||
|
||||
class Trades(SQLModel, table=True):
|
||||
__tablename__ = "trades"
|
||||
id: str = Field(default=None, primary_key=True)
|
||||
id: str | None = Field(default=None, primary_key=True)
|
||||
user_id: str
|
||||
symbol: str
|
||||
underlying_currency: str
|
||||
trade_type: TradeType = Field(sa_column=Column(SQLEnum(TradeType, name="trade_type_enum")), nullable=False)
|
||||
trade_time_utc: datetime = Field(sa_column=Column(DateTime(timezone=True)), nullable=False)
|
||||
trade_type: TradeType = Field(sa_column=Column(SQLEnum(TradeType, name="trade_type_enum"), nullable=False))
|
||||
trade_time_utc: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
|
||||
expiry_date: date | None = Field(default=None, nullable=True)
|
||||
strike_price_cents: int | None = Field(default=None, nullable=True)
|
||||
quantity: int
|
||||
@@ -49,14 +46,14 @@ class Trades(SQLModel, table=True):
|
||||
|
||||
class Cycles(SQLModel, table=True):
|
||||
__tablename__ = "cycles"
|
||||
id: str = Field(default=None, primary_key=True)
|
||||
id: str | None = Field(default=None, primary_key=True)
|
||||
user_id: str
|
||||
symbol: str
|
||||
underlying_currency: str
|
||||
start_date: date
|
||||
end_date: date | None = Field(default=None, nullable=True)
|
||||
status: CycleStatus = Field(sa_column=Column(SQLEnum(CycleStatus, name="cycle_status_enum")), nullable=False)
|
||||
funding_source: FundingSource = Field(sa_column=Column(SQLEnum(FundingSource, name="funding_source_enum")), nullable=False)
|
||||
status: CycleStatus = Field(sa_column=Column(SQLEnum(CycleStatus, name="cycle_status_enum"), nullable=False))
|
||||
funding_source: FundingSource = Field(sa_column=Column(SQLEnum(FundingSource, name="funding_source_enum"), nullable=False))
|
||||
capital_exposure_cents: int
|
||||
loan_amount_cents: int | None = Field(default=None, nullable=True)
|
||||
loan_interest_rate_bps: int | None = Field(default=None, nullable=True)
|
||||
|
||||
Reference in New Issue
Block a user