2025-09-12 21:19:36 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
2025-09-13 18:46:16 +02:00
|
|
|
from typing import TYPE_CHECKING, Callable
|
2025-09-12 21:19:36 +00:00
|
|
|
|
|
|
|
|
from sqlalchemy import text
|
|
|
|
|
from sqlmodel import SQLModel
|
|
|
|
|
|
2025-09-13 18:46:16 +02:00
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
from sqlalchemy.engine import Connection, Engine
|
|
|
|
|
|
2025-09-12 21:19:36 +00:00
|
|
|
# Schema version this code expects. run_migrations() upgrades to it when no
# explicit target is given, and needs a MIGRATIONS entry for every version
# below this one.
LATEST_VERSION: int = 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _mig_0_1(engine: Engine) -> None:
    """
    Migrate schema version 0 -> 1: create the trades, cycles and users tables.

    The table definitions come from ``trading_journal.models_v1``. Tables that
    already exist are skipped (``create_all`` checks first by default), so this
    is safe on an empty database and idempotent for missing tables.
    """
    # Imported inside the function, so the v1 model module is only loaded when
    # this migration actually runs.
    # NOTE(review): presumably kept local to avoid an import cycle at module
    # load time - confirm.
    from trading_journal import models_v1

    v1_models = (models_v1.Trades, models_v1.Cycles, models_v1.Users)
    v1_tables = [model.__table__ for model in v1_models]
    SQLModel.metadata.create_all(bind=engine, tables=v1_tables)
|
2025-09-12 21:19:36 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
# Single-step upgrade registry, keyed by the schema version a step starts from:
# MIGRATIONS[n] takes a database from version n to version n + 1. run_migrations()
# applies these one at a time and raises if a step is missing, so keys must be
# contiguous from 0 up to LATEST_VERSION - 1.
MIGRATIONS: dict[int, Callable[[Engine], None]] = {0: _mig_0_1}
|
|
|
|
|
|
|
|
|
|
|
2025-09-13 12:58:46 +02:00
|
|
|
def _get_sqlite_user_version(conn: Connection) -> int:
    """Read SQLite's ``PRAGMA user_version`` via ``conn``.

    Returns 0 when the pragma yields no row or a NULL value.
    """
    result_row = conn.execute(text("PRAGMA user_version")).fetchone()
    if not result_row:
        return 0
    raw_value = result_row[0]
    if raw_value is None:
        return 0
    return int(raw_value)
|
|
|
|
|
|
|
|
|
|
|
2025-09-13 18:46:16 +02:00
|
|
|
def _set_sqlite_user_version(conn: Connection, v: int) -> None:
    """Store ``v`` in SQLite's ``PRAGMA user_version`` via ``conn``."""
    # The value is interpolated into the SQL text rather than bound.
    # Coercing through int() first ensures only an integer literal can reach
    # the statement.
    version_literal = int(v)
    statement = text(f"PRAGMA user_version = {version_literal}")
    conn.execute(statement)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_migrations(engine: Engine, target_version: int | None = None) -> int:
    """
    Run migrations up to target_version (or LATEST_VERSION).

    The current schema version is read from SQLite's ``PRAGMA user_version``.
    Each step runs ``MIGRATIONS[cur_version](engine)`` and then bumps
    ``user_version`` by one, until the target is reached.

    Args:
        engine: Engine for the database to migrate.
        target_version: Schema version to migrate up to. ``None`` means
            ``LATEST_VERSION``. An explicit ``0`` is honored and is not
            treated as "latest". If the database is already at or beyond the
            target, nothing is applied.

    Returns:
        The schema version after migrating. Returns ``-1`` when the engine's
        driver is not SQLite; no migrations are run in that case.

    Raises:
        RuntimeError: if no migration is registered for a step between the
            current version and the target. Steps before the missing one have
            already been applied and recorded.
    """
    # Compare against None explicitly. ``target_version or LATEST_VERSION``
    # would silently turn an explicit target of 0 into LATEST_VERSION.
    target = LATEST_VERSION if target_version is None else target_version

    with engine.begin() as conn:
        driver = conn.engine.name.lower()
        if driver != "sqlite":
            return -1  # unknown / unsupported driver; no-op

        cur_version = _get_sqlite_user_version(conn)
        while cur_version < target:
            fn = MIGRATIONS.get(cur_version)
            if fn is None:
                raise RuntimeError(
                    f"No migration from {cur_version} -> {cur_version + 1}"
                )
            # Migrations receive the Engine, not ``conn``, so each one manages
            # its own connection/transaction. The version bump below runs on
            # ``conn``, so it is not atomic with the migration. Migrations
            # should therefore be safe to re-run.
            fn(engine)
            _set_sqlite_user_version(conn, cur_version + 1)
            cur_version += 1
        return cur_version
|