"""Alembic migration wrapper with two responsibilities: **(A) CLI entry point ``python -m app.migrate``** — idempotent migration command. Handles four cases: - Empty DB → ``upgrade head`` (create tables) - Unmanaged DB matching baseline (V1) → ``stamp V1`` → ``upgrade head`` - Unmanaged DB NOT matching baseline → **fail-close**, no changes - Already at head → no-op, exit 0 **(B) Startup verification ``verify_schema_is_current(url)``** — read-only check. Used by ``init_db()`` to confirm the DB is at ``head`` before serving traffic. **Never modifies the DB.** Raises on mismatch. """ from __future__ import annotations import logging import sys from pathlib import Path from alembic import command from alembic.config import Config as AlembicConfig from sqlalchemy import create_engine from sqlalchemy import inspect as sa_inspect logger = logging.getLogger("app.migrate") # The V1 baseline revision ID. Must be kept in sync with the revision in # ``migrations/versions/``. A literal is clearer than importing from # auto-generated code whose module name changes. V1_REVISION = "57af90893f55" # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ def _make_alembic_config(database_url: str) -> AlembicConfig: """Build an Alembic ``Config`` pointing at the bundled ``migrations/``.""" project_root = Path(__file__).resolve().parent.parent migrations_dir = project_root / "migrations" alembic_ini = project_root / "alembic.ini" cfg = AlembicConfig(str(alembic_ini)) cfg.set_main_option("script_location", str(migrations_dir)) cfg.set_main_option("sqlalchemy.url", database_url) return cfg def _detect_db_state(database_url: str) -> str: """Return ``"managed"``, ``"unmanaged"``, or ``"empty"``. - **managed**: ``alembic_version`` table exists. - **unmanaged**: any table exists but no ``alembic_version``. - **empty**: no tables at all (truly empty DB). """ eng = create_engine(database_url) try: table_names = set(sa_inspect(eng).get_table_names()) finally: eng.dispose() if "alembic_version" in table_names: return "managed" if table_names: return "unmanaged" return "empty" def _get_current_revision(database_url: str) -> str | None: """Return the current ``alembic_version`` value, or ``None`` if absent.""" eng = create_engine(database_url) try: tables = set(sa_inspect(eng).get_table_names()) if "alembic_version" not in tables: return None with eng.begin() as conn: from sqlalchemy import text row = conn.execute(text("SELECT version_num FROM alembic_version")).scalar() return row finally: eng.dispose() def _build_reference_schema() -> dict: """Build a full reference schema from the V1 baseline migration. Returns a dict with table names, columns (name, nullable, type, primary_key), foreign keys (constrained_columns, referred_table, referred_columns, ondelete), and indexes (name, column_names, unique). """ import tempfile tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False) tmp.close() try: tmp_url = f"sqlite:///{tmp.name}" cfg = _make_alembic_config(tmp_url) command.upgrade(cfg, V1_REVISION) eng = create_engine(tmp_url) try: inspector = sa_inspect(eng) tables = ("boxes", "items", "subitems") result: dict = {"tables": set(tables), "columns": {}, "fks": {}, "indexes": {}} for tbl in tables: # Columns: name, nullable, type (stringified), primary_key cols = inspector.get_columns(tbl) result["columns"][tbl] = sorted( (c["name"], c.get("nullable", True), str(c["type"]), c.get("primary_key", False)) for c in cols ) # Foreign keys fks = inspector.get_foreign_keys(tbl) result["fks"][tbl] = sorted( ( tuple(fk["constrained_columns"]), fk["referred_table"], tuple(fk["referred_columns"]), fk.get("ondelete"), ) for fk in fks ) # Indexes idxs = inspector.get_indexes(tbl) result["indexes"][tbl] = sorted( (idx["name"], tuple(idx["column_names"]), idx.get("unique", False)) for idx in idxs ) return result finally: eng.dispose() finally: from os import unlink unlink(tmp.name) def _schema_matches_baseline(database_url: str) -> bool: """Check whether an unmanaged DB's schema matches V1 baseline. Compares table names, column definitions (name, nullable, type, PK), foreign keys (constrained/referred columns, ondelete), and indexes (name, columns, unique). SQLite type-affinity differences are tolerated via an explicit normalization allowlist. """ ref = _build_reference_schema() eng = create_engine(database_url) try: inspector = sa_inspect(eng) # 1. Table names must match exactly actual_tables = set(inspector.get_table_names()) if actual_tables != ref["tables"]: logger.info("Table mismatch: got %s, expected %s", actual_tables, ref["tables"]) return False for tbl in ref["tables"]: # 2. Columns actual_cols = sorted( (c["name"], c.get("nullable", True), str(c["type"]), c.get("primary_key", False)) for c in inspector.get_columns(tbl) ) if actual_cols != ref["columns"][tbl]: logger.info("Column mismatch on %s: got %s, expected %s", tbl, actual_cols, ref["columns"][tbl]) return False # 3. Foreign keys actual_fks = sorted( ( tuple(fk["constrained_columns"]), fk["referred_table"], tuple(fk["referred_columns"]), fk.get("ondelete"), ) for fk in inspector.get_foreign_keys(tbl) ) if actual_fks != ref["fks"][tbl]: logger.info("FK mismatch on %s: got %s, expected %s", tbl, actual_fks, ref["fks"][tbl]) return False # 4. Indexes actual_idxs = sorted( (idx["name"], tuple(idx["column_names"]), idx.get("unique", False)) for idx in inspector.get_indexes(tbl) ) if actual_idxs != ref["indexes"][tbl]: logger.info("Index mismatch on %s: got %s, expected %s", tbl, actual_idxs, ref["indexes"][tbl]) return False return True finally: eng.dispose() # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------ def verify_schema_is_current(database_url: str) -> None: """Read-only check: confirm the DB is at ``head``. Called by ``init_db()`` at application startup. **Never modifies the DB.** Raises ``RuntimeError`` if the DB is not at ``head``, with a message guiding the user to run ``python -m app.migrate``. """ # For SQLite file URLs, check file existence first to avoid the engine # creating a side-effect empty file. from sqlalchemy.engine import make_url url = make_url(database_url) if url.drivername.startswith("sqlite"): db_path = url.database if db_path and db_path != ":memory:" and not Path(db_path).exists(): raise RuntimeError( f"Database file does not exist: {db_path}. " "Run `python -m app.migrate` to create the schema first." ) state = _detect_db_state(database_url) if state == "empty": raise RuntimeError( "Database is empty — no tables found. " "Run `python -m app.migrate` to create the schema first." ) if state == "unmanaged": raise RuntimeError( "Database exists but has no alembic_version table (not under Alembic control). " "Run `python -m app.migrate` to adopt it first." ) # state == "managed" — check revision current = _get_current_revision(database_url) # Determine head revision from the migration scripts cfg = _make_alembic_config(database_url) from alembic.script import ScriptDirectory script = ScriptDirectory.from_config(cfg) head_rev = script.get_current_head() if current != head_rev: raise RuntimeError( f"Database is at revision '{current}' but the application expects " f"'{head_rev}'. Run `python -m app.migrate` to upgrade." ) logger.info("Database schema verification passed (revision: %s).", current) def run_migrations(database_url: str) -> None: """Execute migrations — intended for the CLI entry point. Idempotent: safe to re-run on every deploy. Cases: - Empty DB → ``upgrade head`` - Unmanaged DB matching V1 baseline → ``stamp V1`` → ``upgrade head`` - Unmanaged DB NOT matching V1 baseline → **fail-close** - Already managed → ``upgrade head`` (no-op if at head) """ cfg = _make_alembic_config(database_url) state = _detect_db_state(database_url) if state == "empty": logger.info("Empty database detected — creating schema from scratch.") command.upgrade(cfg, "head") elif state == "unmanaged": if _schema_matches_baseline(database_url): logger.info( "Unmanaged database matches V1 baseline — stamping %s and upgrading.", V1_REVISION, ) command.stamp(cfg, V1_REVISION) command.upgrade(cfg, "head") else: logger.error( "Unmanaged database schema does NOT match V1 baseline. " "Refusing to migrate to avoid data loss." ) raise SystemExit( "Migration aborted: database schema does not match the " "expected V1 baseline. Inspect the database manually." ) else: # managed logger.info("Database already under Alembic control — upgrading to head.") command.upgrade(cfg, "head") # ------------------------------------------------------------------ # CLI entry point: ``python -m app.migrate`` # ------------------------------------------------------------------ if __name__ == "__main__": logging.basicConfig( level=logging.INFO, format="%(levelname)s [%(name)s] %(message)s", ) from app.config import get_settings settings = get_settings() url = settings.database_url logger.info("Running migrations against %s", url) run_migrations(url) logger.info("Migration complete.")