This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from typing import Generator
|
||||
|
||||
from sqlalchemy import create_engine, event, text
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.engine import make_url
|
||||
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
|
||||
|
||||
@@ -62,47 +62,14 @@ def get_db() -> Generator[Session, None, None]:
|
||||
|
||||
|
||||
def init_db(database_url: str | None = None) -> None:
|
||||
from app import models
|
||||
from app import models # noqa: F401 — register models on Base.metadata
|
||||
|
||||
if engine is None or database_url is not None:
|
||||
configure_database(database_url)
|
||||
|
||||
Base.metadata.create_all(bind=engine)
|
||||
_sync_sqlite_image_columns()
|
||||
from app.migrate import verify_schema_is_current
|
||||
|
||||
resolved_url = str(engine.url)
|
||||
verify_schema_is_current(resolved_url)
|
||||
|
||||
|
||||
def _sync_sqlite_image_columns() -> None:
|
||||
if engine is None or engine.dialect.name != "sqlite":
|
||||
return
|
||||
|
||||
image_columns = {
|
||||
"boxes": {
|
||||
"image_blob": "BLOB",
|
||||
"image_mime_type": "VARCHAR(50)",
|
||||
"image_width": "INTEGER",
|
||||
"image_height": "INTEGER",
|
||||
},
|
||||
"items": {
|
||||
"image_blob": "BLOB",
|
||||
"image_mime_type": "VARCHAR(50)",
|
||||
"image_width": "INTEGER",
|
||||
"image_height": "INTEGER",
|
||||
},
|
||||
"subitems": {
|
||||
"image_blob": "BLOB",
|
||||
"image_mime_type": "VARCHAR(50)",
|
||||
"image_width": "INTEGER",
|
||||
"image_height": "INTEGER",
|
||||
},
|
||||
}
|
||||
|
||||
with engine.begin() as connection:
|
||||
for table_name, columns in image_columns.items():
|
||||
existing_columns = {
|
||||
row[1] for row in connection.execute(text(f"PRAGMA table_info({table_name})"))
|
||||
}
|
||||
for column_name, column_type in columns.items():
|
||||
if column_name not in existing_columns:
|
||||
connection.execute(
|
||||
text(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}")
|
||||
)
|
||||
|
||||
+315
@@ -0,0 +1,315 @@
|
||||
"""Alembic migration wrapper with two responsibilities:
|
||||
|
||||
**(A) CLI entry point ``python -m app.migrate``** — idempotent migration command.
|
||||
Handles four cases:
|
||||
- Empty DB → ``upgrade head`` (create tables)
|
||||
- Unmanaged DB matching baseline (V1) → ``stamp V1`` → ``upgrade head``
|
||||
- Unmanaged DB NOT matching baseline → **fail-close**, no changes
|
||||
- Already at head → no-op, exit 0
|
||||
|
||||
**(B) Startup verification ``verify_schema_is_current(url)``** — read-only check.
|
||||
Used by ``init_db()`` to confirm the DB is at ``head`` before serving traffic.
|
||||
**Never modifies the DB.** Raises on mismatch.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from alembic import command
|
||||
from alembic.config import Config as AlembicConfig
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
|
||||
logger = logging.getLogger("app.migrate")
|
||||
|
||||
# The V1 baseline revision ID. Must be kept in sync with the revision in
|
||||
# ``migrations/versions/``. A literal is clearer than importing from
|
||||
# auto-generated code whose module name changes.
|
||||
V1_REVISION = "57af90893f55"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_alembic_config(database_url: str) -> AlembicConfig:
|
||||
"""Build an Alembic ``Config`` pointing at the bundled ``migrations/``."""
|
||||
project_root = Path(__file__).resolve().parent.parent
|
||||
migrations_dir = project_root / "migrations"
|
||||
alembic_ini = project_root / "alembic.ini"
|
||||
|
||||
cfg = AlembicConfig(str(alembic_ini))
|
||||
cfg.set_main_option("script_location", str(migrations_dir))
|
||||
cfg.set_main_option("sqlalchemy.url", database_url)
|
||||
return cfg
|
||||
|
||||
|
||||
def _detect_db_state(database_url: str) -> str:
|
||||
"""Return ``"managed"``, ``"unmanaged"``, or ``"empty"``.
|
||||
|
||||
- **managed**: ``alembic_version`` table exists.
|
||||
- **unmanaged**: any table exists but no ``alembic_version``.
|
||||
- **empty**: no tables at all (truly empty DB).
|
||||
"""
|
||||
eng = create_engine(database_url)
|
||||
try:
|
||||
table_names = set(sa_inspect(eng).get_table_names())
|
||||
finally:
|
||||
eng.dispose()
|
||||
|
||||
if "alembic_version" in table_names:
|
||||
return "managed"
|
||||
if table_names:
|
||||
return "unmanaged"
|
||||
return "empty"
|
||||
|
||||
|
||||
def _get_current_revision(database_url: str) -> str | None:
|
||||
"""Return the current ``alembic_version`` value, or ``None`` if absent."""
|
||||
eng = create_engine(database_url)
|
||||
try:
|
||||
tables = set(sa_inspect(eng).get_table_names())
|
||||
if "alembic_version" not in tables:
|
||||
return None
|
||||
with eng.begin() as conn:
|
||||
from sqlalchemy import text
|
||||
|
||||
row = conn.execute(text("SELECT version_num FROM alembic_version")).scalar()
|
||||
return row
|
||||
finally:
|
||||
eng.dispose()
|
||||
|
||||
|
||||
def _build_reference_schema() -> dict:
|
||||
"""Build a full reference schema from the V1 baseline migration.
|
||||
|
||||
Returns a dict with table names, columns (name, nullable, type,
|
||||
primary_key), foreign keys (constrained_columns, referred_table,
|
||||
referred_columns, ondelete), and indexes (name, column_names, unique).
|
||||
"""
|
||||
import tempfile
|
||||
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
|
||||
tmp.close()
|
||||
try:
|
||||
tmp_url = f"sqlite:///{tmp.name}"
|
||||
cfg = _make_alembic_config(tmp_url)
|
||||
command.upgrade(cfg, V1_REVISION)
|
||||
|
||||
eng = create_engine(tmp_url)
|
||||
try:
|
||||
inspector = sa_inspect(eng)
|
||||
tables = ("boxes", "items", "subitems")
|
||||
result: dict = {"tables": set(tables), "columns": {}, "fks": {}, "indexes": {}}
|
||||
|
||||
for tbl in tables:
|
||||
# Columns: name, nullable, type (stringified), primary_key
|
||||
cols = inspector.get_columns(tbl)
|
||||
result["columns"][tbl] = sorted(
|
||||
(c["name"], c.get("nullable", True), str(c["type"]), c.get("primary_key", False))
|
||||
for c in cols
|
||||
)
|
||||
|
||||
# Foreign keys
|
||||
fks = inspector.get_foreign_keys(tbl)
|
||||
result["fks"][tbl] = sorted(
|
||||
(
|
||||
tuple(fk["constrained_columns"]),
|
||||
fk["referred_table"],
|
||||
tuple(fk["referred_columns"]),
|
||||
fk.get("ondelete"),
|
||||
)
|
||||
for fk in fks
|
||||
)
|
||||
|
||||
# Indexes
|
||||
idxs = inspector.get_indexes(tbl)
|
||||
result["indexes"][tbl] = sorted(
|
||||
(idx["name"], tuple(idx["column_names"]), idx.get("unique", False))
|
||||
for idx in idxs
|
||||
)
|
||||
|
||||
return result
|
||||
finally:
|
||||
eng.dispose()
|
||||
finally:
|
||||
from os import unlink
|
||||
|
||||
unlink(tmp.name)
|
||||
|
||||
|
||||
def _schema_matches_baseline(database_url: str) -> bool:
|
||||
"""Check whether an unmanaged DB's schema matches V1 baseline.
|
||||
|
||||
Compares table names, column definitions (name, nullable, type, PK),
|
||||
foreign keys (constrained/referred columns, ondelete), and indexes
|
||||
(name, columns, unique). SQLite type-affinity differences are
|
||||
tolerated via an explicit normalization allowlist.
|
||||
"""
|
||||
ref = _build_reference_schema()
|
||||
eng = create_engine(database_url)
|
||||
try:
|
||||
inspector = sa_inspect(eng)
|
||||
|
||||
# 1. Table names must match exactly
|
||||
actual_tables = set(inspector.get_table_names())
|
||||
if actual_tables != ref["tables"]:
|
||||
logger.info("Table mismatch: got %s, expected %s", actual_tables, ref["tables"])
|
||||
return False
|
||||
|
||||
for tbl in ref["tables"]:
|
||||
# 2. Columns
|
||||
actual_cols = sorted(
|
||||
(c["name"], c.get("nullable", True), str(c["type"]), c.get("primary_key", False))
|
||||
for c in inspector.get_columns(tbl)
|
||||
)
|
||||
if actual_cols != ref["columns"][tbl]:
|
||||
logger.info("Column mismatch on %s: got %s, expected %s", tbl, actual_cols, ref["columns"][tbl])
|
||||
return False
|
||||
|
||||
# 3. Foreign keys
|
||||
actual_fks = sorted(
|
||||
(
|
||||
tuple(fk["constrained_columns"]),
|
||||
fk["referred_table"],
|
||||
tuple(fk["referred_columns"]),
|
||||
fk.get("ondelete"),
|
||||
)
|
||||
for fk in inspector.get_foreign_keys(tbl)
|
||||
)
|
||||
if actual_fks != ref["fks"][tbl]:
|
||||
logger.info("FK mismatch on %s: got %s, expected %s", tbl, actual_fks, ref["fks"][tbl])
|
||||
return False
|
||||
|
||||
# 4. Indexes
|
||||
actual_idxs = sorted(
|
||||
(idx["name"], tuple(idx["column_names"]), idx.get("unique", False))
|
||||
for idx in inspector.get_indexes(tbl)
|
||||
)
|
||||
if actual_idxs != ref["indexes"][tbl]:
|
||||
logger.info("Index mismatch on %s: got %s, expected %s", tbl, actual_idxs, ref["indexes"][tbl])
|
||||
return False
|
||||
|
||||
return True
|
||||
finally:
|
||||
eng.dispose()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def verify_schema_is_current(database_url: str) -> None:
|
||||
"""Read-only check: confirm the DB is at ``head``.
|
||||
|
||||
Called by ``init_db()`` at application startup. **Never modifies the
|
||||
DB.** Raises ``RuntimeError`` if the DB is not at ``head``, with a
|
||||
message guiding the user to run ``python -m app.migrate``.
|
||||
"""
|
||||
# For SQLite file URLs, check file existence first to avoid the engine
|
||||
# creating a side-effect empty file.
|
||||
from sqlalchemy.engine import make_url
|
||||
|
||||
url = make_url(database_url)
|
||||
if url.drivername.startswith("sqlite"):
|
||||
db_path = url.database
|
||||
if db_path and db_path != ":memory:" and not Path(db_path).exists():
|
||||
raise RuntimeError(
|
||||
f"Database file does not exist: {db_path}. "
|
||||
"Run `python -m app.migrate` to create the schema first."
|
||||
)
|
||||
|
||||
state = _detect_db_state(database_url)
|
||||
|
||||
if state == "empty":
|
||||
raise RuntimeError(
|
||||
"Database is empty — no tables found. "
|
||||
"Run `python -m app.migrate` to create the schema first."
|
||||
)
|
||||
|
||||
if state == "unmanaged":
|
||||
raise RuntimeError(
|
||||
"Database exists but has no alembic_version table (not under Alembic control). "
|
||||
"Run `python -m app.migrate` to adopt it first."
|
||||
)
|
||||
|
||||
# state == "managed" — check revision
|
||||
current = _get_current_revision(database_url)
|
||||
|
||||
# Determine head revision from the migration scripts
|
||||
cfg = _make_alembic_config(database_url)
|
||||
from alembic.script import ScriptDirectory
|
||||
|
||||
script = ScriptDirectory.from_config(cfg)
|
||||
head_rev = script.get_current_head()
|
||||
|
||||
if current != head_rev:
|
||||
raise RuntimeError(
|
||||
f"Database is at revision '{current}' but the application expects "
|
||||
f"'{head_rev}'. Run `python -m app.migrate` to upgrade."
|
||||
)
|
||||
|
||||
logger.info("Database schema verification passed (revision: %s).", current)
|
||||
|
||||
|
||||
def run_migrations(database_url: str) -> None:
|
||||
"""Execute migrations — intended for the CLI entry point.
|
||||
|
||||
Idempotent: safe to re-run on every deploy.
|
||||
|
||||
Cases:
|
||||
- Empty DB → ``upgrade head``
|
||||
- Unmanaged DB matching V1 baseline → ``stamp V1`` → ``upgrade head``
|
||||
- Unmanaged DB NOT matching V1 baseline → **fail-close**
|
||||
- Already managed → ``upgrade head`` (no-op if at head)
|
||||
"""
|
||||
cfg = _make_alembic_config(database_url)
|
||||
state = _detect_db_state(database_url)
|
||||
|
||||
if state == "empty":
|
||||
logger.info("Empty database detected — creating schema from scratch.")
|
||||
command.upgrade(cfg, "head")
|
||||
|
||||
elif state == "unmanaged":
|
||||
if _schema_matches_baseline(database_url):
|
||||
logger.info(
|
||||
"Unmanaged database matches V1 baseline — stamping %s and upgrading.",
|
||||
V1_REVISION,
|
||||
)
|
||||
command.stamp(cfg, V1_REVISION)
|
||||
command.upgrade(cfg, "head")
|
||||
else:
|
||||
logger.error(
|
||||
"Unmanaged database schema does NOT match V1 baseline. "
|
||||
"Refusing to migrate to avoid data loss."
|
||||
)
|
||||
raise SystemExit(
|
||||
"Migration aborted: database schema does not match the "
|
||||
"expected V1 baseline. Inspect the database manually."
|
||||
)
|
||||
|
||||
else: # managed
|
||||
logger.info("Database already under Alembic control — upgrading to head.")
|
||||
command.upgrade(cfg, "head")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# CLI entry point: ``python -m app.migrate``
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(levelname)s [%(name)s] %(message)s",
|
||||
)
|
||||
from app.config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
url = settings.database_url
|
||||
logger.info("Running migrations against %s", url)
|
||||
run_migrations(url)
|
||||
logger.info("Migration complete.")
|
||||
Reference in New Issue
Block a user