Files
2026-moving-helper/app/migrate.py
T

316 lines
11 KiB
Python
Raw Normal View History

2026-06-01 16:02:43 +02:00
"""Alembic migration wrapper with two responsibilities:
**(A) CLI entry point ``python -m app.migrate``** — idempotent migration command.
Handles four cases:
- Empty DB → ``upgrade head`` (create tables)
- Unmanaged DB matching baseline (V1) → ``stamp V1`` → ``upgrade head``
- Unmanaged DB NOT matching baseline → **fail-close**, no changes
- Already managed → ``upgrade head`` (a no-op that exits 0 when already at head)
**(B) Startup verification ``verify_schema_is_current(url)``** — read-only check.
Used by ``init_db()`` to confirm the DB is at ``head`` before serving traffic.
**Never modifies the DB.** Raises on mismatch.
"""
from __future__ import annotations
import logging
import sys
from pathlib import Path
from alembic import command
from alembic.config import Config as AlembicConfig
from sqlalchemy import create_engine
from sqlalchemy import inspect as sa_inspect
# The logger name is spelled out as a literal: when this module is executed via
# ``python -m app.migrate`` its ``__name__`` is ``"__main__"``, so the literal
# keeps log records under the same ``app.migrate`` name in both usages.
logger = logging.getLogger("app.migrate")
# The V1 baseline revision ID. Must be kept in sync with the revision in
# ``migrations/versions/``. A literal is clearer than importing from
# auto-generated code whose module name changes.
# Used both as the upgrade target when building the reference schema and as
# the revision stamped onto an adopted (previously unmanaged) database.
V1_REVISION = "57af90893f55"
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _make_alembic_config(database_url: str) -> AlembicConfig:
    """Build an Alembic ``Config`` pointing at the bundled ``migrations/``.

    The project root is the parent of the directory containing this module;
    ``alembic.ini`` and ``migrations/`` are resolved relative to it, so the
    result does not depend on the current working directory.

    Args:
        database_url: SQLAlchemy database URL to store as ``sqlalchemy.url``.
            Pass the raw URL; percent signs are escaped here.

    Returns:
        An ``AlembicConfig`` loaded from ``alembic.ini`` with its
        ``script_location`` and ``sqlalchemy.url`` main options overridden.
    """
    project_root = Path(__file__).resolve().parent.parent
    cfg = AlembicConfig(str(project_root / "alembic.ini"))

    # Alembic keeps main options in a ConfigParser with interpolation enabled,
    # so a literal "%" (e.g. a percent-encoded password such as "p%40ss") must
    # be written as "%%"; it is turned back into "%" when the option is read.
    cfg.set_main_option(
        "script_location", str(project_root / "migrations").replace("%", "%%")
    )
    cfg.set_main_option("sqlalchemy.url", database_url.replace("%", "%%"))
    return cfg
def _detect_db_state(database_url: str) -> str:
    """Classify the database as ``"empty"``, ``"managed"`` or ``"unmanaged"``.

    - **empty**: the inspector reports no tables at all.
    - **managed**: an ``alembic_version`` table is present.
    - **unmanaged**: at least one table exists, but not ``alembic_version``.

    A short-lived engine is created for the inspection and always disposed.
    """
    engine = create_engine(database_url)
    try:
        existing = frozenset(sa_inspect(engine).get_table_names())
    finally:
        engine.dispose()

    if not existing:
        return "empty"
    return "managed" if "alembic_version" in existing else "unmanaged"
def _get_current_revision(database_url: str) -> str | None:
    """Read the revision recorded in ``alembic_version``.

    Returns ``None`` when the ``alembic_version`` table does not exist;
    otherwise returns the scalar result of selecting ``version_num`` (which is
    also ``None`` if the table holds no row). The temporary engine is disposed
    on every exit path.
    """
    from sqlalchemy import text

    version_query = text("SELECT version_num FROM alembic_version")
    engine = create_engine(database_url)
    try:
        if "alembic_version" in set(sa_inspect(engine).get_table_names()):
            with engine.begin() as connection:
                return connection.execute(version_query).scalar()
        return None
    finally:
        engine.dispose()
def _reflect_table_signature(inspector, table_name: str) -> tuple[list, list, list]:
    """Reflect one table into sorted, directly comparable signatures.

    Shared by the reference-schema builder and the live-DB comparison so both
    sides are always derived by exactly the same rules.

    Args:
        inspector: SQLAlchemy inspector bound to the database to reflect.
        table_name: Name of the table to reflect.

    Returns:
        ``(columns, foreign_keys, indexes)`` where

        - ``columns`` is a sorted list of
          ``(name, nullable, str(type), primary_key)``,
        - ``foreign_keys`` is a sorted list of
          ``(constrained_columns, referred_table, referred_columns, ondelete)``,
        - ``indexes`` is a sorted list of ``(name, column_names, unique)``.
    """
    columns = sorted(
        (
            col["name"],
            col.get("nullable", True),
            str(col["type"]),
            col.get("primary_key", False),
        )
        for col in inspector.get_columns(table_name)
    )
    foreign_keys = sorted(
        (
            (
                tuple(fk["constrained_columns"]),
                fk["referred_table"],
                tuple(fk["referred_columns"]),
                # Reflection reports ON DELETE inside the nested "options"
                # dict; a top-level "ondelete" key does not exist, so reading
                # it there always yielded None and the rule was never compared.
                (fk.get("options") or {}).get("ondelete"),
            )
            for fk in inspector.get_foreign_keys(table_name)
        ),
        # ondelete may be None or a string; ordering by repr avoids ever
        # comparing those two directly while staying deterministic.
        key=repr,
    )
    indexes = sorted(
        (idx["name"], tuple(idx["column_names"]), idx.get("unique", False))
        for idx in inspector.get_indexes(table_name)
    )
    return columns, foreign_keys, indexes


def _build_reference_schema() -> dict:
    """Build a reference schema from the V1 baseline migration.

    Upgrades a throwaway SQLite file to ``V1_REVISION`` and reflects the
    ``boxes``, ``items`` and ``subitems`` tables from it.

    Returns:
        A dict with keys ``"tables"`` (set of table names) and ``"columns"``,
        ``"fks"``, ``"indexes"`` (each a dict mapping table name to the sorted
        signature list produced by ``_reflect_table_signature``).
    """
    import os
    import tempfile

    # Only the path is needed: the handle is closed immediately and the file is
    # reopened through the SQLite URL, then removed in the outer ``finally``.
    tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
    tmp.close()
    try:
        tmp_url = f"sqlite:///{tmp.name}"
        command.upgrade(_make_alembic_config(tmp_url), V1_REVISION)
        eng = create_engine(tmp_url)
        try:
            inspector = sa_inspect(eng)
            tables = ("boxes", "items", "subitems")
            result: dict = {"tables": set(tables), "columns": {}, "fks": {}, "indexes": {}}
            for tbl in tables:
                cols, fks, idxs = _reflect_table_signature(inspector, tbl)
                result["columns"][tbl] = cols
                result["fks"][tbl] = fks
                result["indexes"][tbl] = idxs
            return result
        finally:
            # Dispose before unlinking so no connection still holds the file.
            eng.dispose()
    finally:
        os.unlink(tmp.name)


def _schema_matches_baseline(database_url: str) -> bool:
    """Check whether an unmanaged DB's schema matches the V1 baseline.

    Compares, in order: the exact set of table names, then per table the
    column definitions (name, nullable, type, PK), foreign keys
    (constrained/referred columns, ondelete) and indexes (name, columns,
    unique). Column types are compared as the strings SQLAlchemy reflects,
    with no normalization, so the match is exact.

    The first difference found is logged at INFO level.

    Args:
        database_url: SQLAlchemy URL of the database to inspect.

    Returns:
        ``True`` if every compared aspect equals the reference schema,
        ``False`` at the first mismatch.
    """
    ref = _build_reference_schema()
    eng = create_engine(database_url)
    try:
        inspector = sa_inspect(eng)

        # 1. Table names must match exactly
        actual_tables = set(inspector.get_table_names())
        if actual_tables != ref["tables"]:
            logger.info("Table mismatch: got %s, expected %s", actual_tables, ref["tables"])
            return False

        # Sorted iteration makes the reported mismatch deterministic.
        for tbl in sorted(ref["tables"]):
            actual_cols, actual_fks, actual_idxs = _reflect_table_signature(inspector, tbl)

            # 2. Columns
            if actual_cols != ref["columns"][tbl]:
                logger.info(
                    "Column mismatch on %s: got %s, expected %s",
                    tbl, actual_cols, ref["columns"][tbl],
                )
                return False

            # 3. Foreign keys
            if actual_fks != ref["fks"][tbl]:
                logger.info(
                    "FK mismatch on %s: got %s, expected %s",
                    tbl, actual_fks, ref["fks"][tbl],
                )
                return False

            # 4. Indexes
            if actual_idxs != ref["indexes"][tbl]:
                logger.info(
                    "Index mismatch on %s: got %s, expected %s",
                    tbl, actual_idxs, ref["indexes"][tbl],
                )
                return False

        return True
    finally:
        eng.dispose()
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def verify_schema_is_current(database_url: str) -> None:
    """Confirm, without writing anything, that the DB is at ``head``.

    Called by ``init_db()`` at application startup. **Never modifies the
    DB.**

    Raises:
        RuntimeError: if a SQLite database file is missing, the database is
            empty, it is not under Alembic control, or its recorded revision
            differs from the migration scripts' head. Every message points
            the user at ``python -m app.migrate``.
    """
    from alembic.script import ScriptDirectory
    from sqlalchemy.engine import make_url

    # Connecting to a missing SQLite file would create it as a side effect,
    # so a file-backed SQLite URL is checked on disk before any engine exists.
    parsed = make_url(database_url)
    if parsed.drivername.startswith("sqlite"):
        sqlite_file = parsed.database
        file_backed = bool(sqlite_file) and sqlite_file != ":memory:"
        if file_backed and not Path(sqlite_file).exists():
            raise RuntimeError(
                f"Database file does not exist: {sqlite_file}. "
                "Run `python -m app.migrate` to create the schema first."
            )

    # States that can never be "current", mapped to their explanation.
    blocking_states = {
        "empty": (
            "Database is empty — no tables found. "
            "Run `python -m app.migrate` to create the schema first."
        ),
        "unmanaged": (
            "Database exists but has no alembic_version table (not under Alembic control). "
            "Run `python -m app.migrate` to adopt it first."
        ),
    }
    problem = blocking_states.get(_detect_db_state(database_url))
    if problem is not None:
        raise RuntimeError(problem)

    # Managed database: compare the recorded revision with the scripts' head.
    db_revision = _get_current_revision(database_url)
    scripts = ScriptDirectory.from_config(_make_alembic_config(database_url))
    expected_revision = scripts.get_current_head()

    if db_revision == expected_revision:
        logger.info("Database schema verification passed (revision: %s).", db_revision)
        return

    raise RuntimeError(
        f"Database is at revision '{db_revision}' but the application expects "
        f"'{expected_revision}'. Run `python -m app.migrate` to upgrade."
    )
def run_migrations(database_url: str) -> None:
    """Bring the database to ``head`` — intended for the CLI entry point.

    Idempotent: safe to re-run on every deploy.

    Behaviour by detected state:

    - empty → ``upgrade head``
    - managed → ``upgrade head`` (no-op if already at head)
    - unmanaged, matches V1 baseline → ``stamp V1`` then ``upgrade head``
    - unmanaged, does NOT match V1 baseline → **fail-close**

    Raises:
        SystemExit: when an unmanaged database does not match the baseline.
    """
    alembic_cfg = _make_alembic_config(database_url)
    db_state = _detect_db_state(database_url)

    if db_state == "empty":
        logger.info("Empty database detected — creating schema from scratch.")
        command.upgrade(alembic_cfg, "head")
        return

    if db_state != "unmanaged":
        # Anything that is neither empty nor unmanaged is already managed.
        logger.info("Database already under Alembic control — upgrading to head.")
        command.upgrade(alembic_cfg, "head")
        return

    # Unmanaged: adopt only when the schema is exactly the V1 baseline.
    if not _schema_matches_baseline(database_url):
        logger.error(
            "Unmanaged database schema does NOT match V1 baseline. "
            "Refusing to migrate to avoid data loss."
        )
        raise SystemExit(
            "Migration aborted: database schema does not match the "
            "expected V1 baseline. Inspect the database manually."
        )

    logger.info(
        "Unmanaged database matches V1 baseline — stamping %s and upgrading.",
        V1_REVISION,
    )
    command.stamp(alembic_cfg, V1_REVISION)
    command.upgrade(alembic_cfg, "head")
# ------------------------------------------------------------------
# CLI entry point: ``python -m app.migrate``
# ------------------------------------------------------------------
if __name__ == "__main__":
    # CLI flow: configure logging, load settings, run the idempotent migration.
    logging.basicConfig(
        level=logging.INFO,
        format="%(levelname)s [%(name)s] %(message)s",
    )
    from sqlalchemy.engine import make_url

    from app.config import get_settings

    settings = get_settings()
    url = settings.database_url
    # Database URLs may embed credentials; log a rendering with the password
    # masked instead of the raw URL so secrets do not end up in deploy logs.
    logger.info(
        "Running migrations against %s",
        make_url(url).render_as_string(hide_password=True),
    )
    run_migrations(url)
    logger.info("Migration complete.")