316 lines
11 KiB
Python
316 lines
11 KiB
Python
"""Alembic migration wrapper with two responsibilities.

**(A) CLI entry point ``python -m app.migrate``** — idempotent migration command.
Handles four cases:

- Empty DB → ``upgrade head`` (create tables)
- Unmanaged DB matching baseline (V1) → ``stamp V1`` → ``upgrade head``
- Unmanaged DB NOT matching baseline → **fail-close**, no changes
- Already under Alembic control → ``upgrade head`` (a no-op when already at head), exit 0

**(B) Startup verification ``verify_schema_is_current(url)``** — read-only check.
Used by ``init_db()`` to confirm the DB is at ``head`` before serving traffic.
**Never modifies the DB.** Raises on mismatch.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
from alembic import command
|
|
from alembic.config import Config as AlembicConfig
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy import inspect as sa_inspect
|
|
|
|
# Module logger. A literal name is used so records stay attributed to
# "app.migrate" even when the file runs as ``python -m app.migrate``
# (where ``__name__`` would be "__main__").
logger: logging.Logger = logging.getLogger("app.migrate")

# Revision ID of the V1 baseline migration; keep it in sync with the file in
# ``migrations/versions/``. A literal is used instead of an import because
# auto-generated migration module names are not stable.
V1_REVISION: str = "57af90893f55"
|
|
|
|
# ------------------------------------------------------------------
|
|
# Internal helpers
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
def _make_alembic_config(database_url: str) -> AlembicConfig:
    """Build an Alembic ``Config`` pointing at the bundled ``migrations/``.

    Paths are resolved from this module's location (two directory levels
    up is treated as the project root), so the result does not depend on
    the current working directory.

    Args:
        database_url: SQLAlchemy URL stored as the ``sqlalchemy.url`` option.

    Returns:
        A config whose ``script_location`` and ``sqlalchemy.url`` are set.
    """
    project_root = Path(__file__).resolve().parent.parent
    migrations_dir = project_root / "migrations"
    alembic_ini = project_root / "alembic.ini"

    cfg = AlembicConfig(str(alembic_ini))
    # ``set_main_option`` goes through ConfigParser interpolation, so a bare
    # "%" (e.g. a URL-encoded password such as "p%40ss") raises ValueError.
    # Escaping it as "%%" stores the value safely; reading it back yields
    # the original text.
    cfg.set_main_option("script_location", str(migrations_dir).replace("%", "%%"))
    cfg.set_main_option("sqlalchemy.url", database_url.replace("%", "%%"))
    return cfg
|
|
|
|
|
|
def _detect_db_state(database_url: str) -> str:
    """Classify the database at *database_url* by the tables it contains.

    Note: for file-based SQLite, connecting may create an empty file;
    callers that must avoid that check for the file first.

    Returns:
        ``"managed"`` when an ``alembic_version`` table exists,
        ``"unmanaged"`` when other tables exist but ``alembic_version`` does not,
        ``"empty"`` when there are no tables at all.
    """
    engine = create_engine(database_url)
    try:
        existing = set(sa_inspect(engine).get_table_names())
    finally:
        engine.dispose()

    if not existing:
        return "empty"
    return "managed" if "alembic_version" in existing else "unmanaged"
|
|
|
|
|
|
def _get_current_revision(database_url: str) -> str | None:
    """Read the revision recorded in ``alembic_version``.

    Returns:
        The ``version_num`` from the first row of ``alembic_version``, or
        ``None`` when that table does not exist (``scalar()`` also yields
        ``None`` if the table has no rows).
    """
    from sqlalchemy import text

    engine = create_engine(database_url)
    try:
        if "alembic_version" in sa_inspect(engine).get_table_names():
            with engine.begin() as conn:
                return conn.execute(
                    text("SELECT version_num FROM alembic_version")
                ).scalar()
        return None
    finally:
        engine.dispose()
|
|
|
|
|
|
def _describe_table(inspector, table: str) -> tuple[list, list, list]:
    """Return sorted ``(columns, foreign_keys, indexes)`` descriptors for *table*.

    Shared by the reference builder and the comparison, so both sides are
    always normalized identically.

    - columns: ``(name, nullable, type_as_str, pk_position)`` where
      ``pk_position`` is the 1-based position in the primary key (0 = not a
      PK column)
    - foreign keys: ``(constrained_cols, referred_table, referred_cols, ondelete)``
      where ``ondelete`` is the upper-cased ON DELETE action or ``""``
    - indexes: ``(name, column_names, unique)``
    """
    # Primary-key membership comes from the PK constraint, because
    # ``get_columns()`` does not document a ``primary_key`` key.
    pk_columns = inspector.get_pk_constraint(table).get("constrained_columns") or []
    pk_position = {name: pos for pos, name in enumerate(pk_columns, start=1)}

    columns = sorted(
        (c["name"], c.get("nullable", True), str(c["type"]), pk_position.get(c["name"], 0))
        for c in inspector.get_columns(table)
    )

    # The ON DELETE action is reported inside ``fk["options"]``, not as a
    # top-level key. Normalize it to an upper-case string ("" = none) so that
    # tuples stay sortable.
    fks = sorted(
        (
            tuple(fk["constrained_columns"]),
            fk["referred_table"],
            tuple(fk["referred_columns"]),
            ((fk.get("options") or {}).get("ondelete") or "").upper(),
        )
        for fk in inspector.get_foreign_keys(table)
    )

    indexes = sorted(
        (idx["name"], tuple(idx["column_names"]), idx.get("unique", False))
        for idx in inspector.get_indexes(table)
    )
    return columns, fks, indexes


def _build_reference_schema() -> dict:
    """Build a reference schema by applying the V1 baseline to a scratch DB.

    Runs ``upgrade V1_REVISION`` against a throwaway SQLite file. It then
    reflects every table the migration created (excluding
    ``alembic_version``).

    Returns:
        A dict with keys ``"tables"`` (set of table names), and ``"columns"``,
        ``"fks"``, ``"indexes"`` (each mapping table name to the sorted
        descriptor list produced by ``_describe_table``).
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_url = f"sqlite:///{(Path(tmp_dir) / 'baseline.db').as_posix()}"
        command.upgrade(_make_alembic_config(tmp_url), V1_REVISION)

        eng = create_engine(tmp_url)
        try:
            inspector = sa_inspect(eng)
            # Derive the table set from what V1 actually creates, rather than
            # a hard-coded list, so the reference cannot drift from the migration.
            tables = set(inspector.get_table_names()) - {"alembic_version"}
            result: dict = {"tables": tables, "columns": {}, "fks": {}, "indexes": {}}
            for tbl in sorted(tables):
                cols, fks, idxs = _describe_table(inspector, tbl)
                result["columns"][tbl] = cols
                result["fks"][tbl] = fks
                result["indexes"][tbl] = idxs
            return result
        finally:
            # Release SQLite file handles before the temp directory is removed.
            eng.dispose()


def _schema_matches_baseline(database_url: str) -> bool:
    """Check whether an unmanaged DB's schema matches the V1 baseline.

    Compares the set of table names, then for each table:
    - column descriptors (name, nullable, reflected type string, PK position)
    - foreign keys (constrained/referred columns, ON DELETE action)
    - indexes (name, columns, unique)

    Reflected type strings are compared exactly; no cross-dialect
    normalization is applied. The reference is built on SQLite, so a
    database on another dialect may report different type strings and fail
    to match (the caller then fails closed).

    Returns:
        ``True`` on an exact match. ``False`` otherwise, after logging the
        first difference at INFO level.
    """
    ref = _build_reference_schema()
    eng = create_engine(database_url)
    try:
        inspector = sa_inspect(eng)

        # 1. Table names must match exactly.
        actual_tables = set(inspector.get_table_names())
        if actual_tables != ref["tables"]:
            logger.info("Table mismatch: got %s, expected %s", actual_tables, ref["tables"])
            return False

        # 2-4. Columns, foreign keys and indexes, per table.
        for tbl in sorted(ref["tables"]):
            actual_cols, actual_fks, actual_idxs = _describe_table(inspector, tbl)
            for label, actual, expected in (
                ("Column", actual_cols, ref["columns"][tbl]),
                ("FK", actual_fks, ref["fks"][tbl]),
                ("Index", actual_idxs, ref["indexes"][tbl]),
            ):
                if actual != expected:
                    logger.info(
                        "%s mismatch on %s: got %s, expected %s", label, tbl, actual, expected
                    )
                    return False

        return True
    finally:
        eng.dispose()
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
# Public API
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
def verify_schema_is_current(database_url: str) -> None:
    """Confirm, without migrating anything, that the database is at ``head``.

    Intended to be called from ``init_db()`` at application startup. It
    never modifies the DB. Every failure raises ``RuntimeError`` with a hint
    to run ``python -m app.migrate``.

    Raises:
        RuntimeError: the SQLite database file is missing; the database has
            no tables; it is not under Alembic control; or its recorded
            revision differs from the migration scripts' head.
    """
    from alembic.script import ScriptDirectory
    from sqlalchemy.engine import make_url

    # Check SQLite file paths up front. Opening an engine on a missing file
    # would create an empty one as a side effect.
    parsed = make_url(database_url)
    sqlite_path = parsed.database if parsed.drivername.startswith("sqlite") else None
    if sqlite_path and sqlite_path != ":memory:" and not Path(sqlite_path).exists():
        raise RuntimeError(
            f"Database file does not exist: {sqlite_path}. "
            "Run `python -m app.migrate` to create the schema first."
        )

    # States that cannot be served, mapped to their error messages.
    state_errors = {
        "empty": (
            "Database is empty — no tables found. "
            "Run `python -m app.migrate` to create the schema first."
        ),
        "unmanaged": (
            "Database exists but has no alembic_version table (not under Alembic control). "
            "Run `python -m app.migrate` to adopt it first."
        ),
    }
    state = _detect_db_state(database_url)
    if state in state_errors:
        raise RuntimeError(state_errors[state])

    # Managed: compare the stored revision with the head of the migration scripts.
    current = _get_current_revision(database_url)
    scripts = ScriptDirectory.from_config(_make_alembic_config(database_url))
    head_rev = scripts.get_current_head()

    if current != head_rev:
        raise RuntimeError(
            f"Database is at revision '{current}' but the application expects "
            f"'{head_rev}'. Run `python -m app.migrate` to upgrade."
        )

    logger.info("Database schema verification passed (revision: %s).", current)
|
|
|
|
|
|
def run_migrations(database_url: str) -> None:
    """Bring the database up to ``head``; meant for the CLI entry point.

    Idempotent, so it is safe to re-run on every deploy. Behaviour by
    detected state:

    - ``empty``: ``upgrade head``.
    - ``unmanaged`` and matching the V1 baseline: ``stamp V1``, then
      ``upgrade head``.
    - ``unmanaged`` and NOT matching: log an error and raise ``SystemExit``
      without migrating (fail-close).
    - ``managed``: ``upgrade head`` (a no-op when already at head).
    """
    cfg = _make_alembic_config(database_url)
    state = _detect_db_state(database_url)

    if state == "unmanaged":
        if not _schema_matches_baseline(database_url):
            logger.error(
                "Unmanaged database schema does NOT match V1 baseline. "
                "Refusing to migrate to avoid data loss."
            )
            raise SystemExit(
                "Migration aborted: database schema does not match the "
                "expected V1 baseline. Inspect the database manually."
            )
        logger.info(
            "Unmanaged database matches V1 baseline — stamping %s and upgrading.",
            V1_REVISION,
        )
        # Record V1 as applied so Alembic only runs the later revisions.
        command.stamp(cfg, V1_REVISION)
        command.upgrade(cfg, "head")
        return

    if state == "empty":
        logger.info("Empty database detected — creating schema from scratch.")
    else:
        logger.info("Database already under Alembic control — upgrading to head.")
    command.upgrade(cfg, "head")
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
# CLI entry point: ``python -m app.migrate``
|
|
# ------------------------------------------------------------------
|
|
|
|
if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format="%(levelname)s [%(name)s] %(message)s",
    )

    from sqlalchemy.engine import make_url

    from app.config import get_settings

    settings = get_settings()
    url = settings.database_url
    # Never write credentials to logs. repr() of a SQLAlchemy URL renders it
    # with the password masked; str() first in case the setting is not a
    # plain string.
    logger.info("Running migrations against %s", repr(make_url(str(url))))
    run_migrations(url)
    logger.info("Migration complete.")
|