"""Database engine, session factory, and schema bootstrap helpers."""

from pathlib import Path
from typing import Generator

from sqlalchemy import create_engine, event, text
from sqlalchemy.engine import Engine, make_url
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker

from app.config import get_settings

# Current module-level engine; None until configure_database() has run.
engine: Engine | None = None
# Session factory; configure_database() binds it to the current engine.
SessionLocal = sessionmaker(autocommit=False, autoflush=False)

# Image columns that _sync_sqlite_image_columns() ensures exist on each table
# in _IMAGE_TABLES (column name -> SQLite column type).
_IMAGE_COLUMNS: dict[str, str] = {
    "image_blob": "BLOB",
    "image_mime_type": "VARCHAR(50)",
    "image_width": "INTEGER",
    "image_height": "INTEGER",
}
_IMAGE_TABLES: tuple[str, ...] = ("boxes", "items", "subitems")


class Base(DeclarativeBase):
    """Declarative base class whose metadata init_db() creates tables from."""


def _is_sqlite(database_url: str) -> bool:
    """Return True when *database_url* targets SQLite (any driver variant)."""
    return database_url.startswith("sqlite")


def _build_engine(database_url: str) -> Engine:
    """Create an Engine for *database_url*.

    For SQLite URLs this also creates the database file's parent directory,
    disables the sqlite3 same-thread check, and enables foreign-key
    enforcement on every new DBAPI connection.

    Args:
        database_url: SQLAlchemy database URL.

    Returns:
        The newly created Engine.
    """
    is_sqlite = _is_sqlite(database_url)
    _ensure_sqlite_directory(database_url)
    # sqlite3 rejects using a connection from a thread other than the one
    # that created it unless check_same_thread is disabled.
    connect_args = {"check_same_thread": False} if is_sqlite else {}
    created_engine = create_engine(database_url, connect_args=connect_args)

    if is_sqlite:

        @event.listens_for(created_engine, "connect")
        def set_sqlite_pragma(dbapi_connection, connection_record):
            # SQLite only enforces FOREIGN KEY constraints when this pragma is
            # enabled, and the setting is per-connection.
            cursor = dbapi_connection.cursor()
            try:
                cursor.execute("PRAGMA foreign_keys=ON")
            finally:
                cursor.close()

    return created_engine


def _ensure_sqlite_directory(database_url: str) -> None:
    """Create the parent directory of a file-backed SQLite database.

    No-op for non-SQLite URLs, for URLs without a database path
    (e.g. ``sqlite://``), and for ``:memory:`` databases.

    Args:
        database_url: SQLAlchemy database URL.
    """
    if not _is_sqlite(database_url):
        return
    database_path = make_url(database_url).database
    if not database_path or database_path == ":memory:":
        return
    Path(database_path).parent.mkdir(parents=True, exist_ok=True)


def configure_database(database_url: str | None = None) -> None:
    """(Re)create the module engine and bind SessionLocal to it.

    Args:
        database_url: URL to connect to. When None or empty, the
            ``database_url`` attribute of ``get_settings()`` is used.
    """
    global engine
    # Only consult settings when no explicit URL was supplied.
    resolved_database_url = database_url or get_settings().database_url
    # Build the replacement first, so a bad URL leaves the current engine untouched.
    new_engine = _build_engine(resolved_database_url)
    previous_engine = engine
    engine = new_engine
    SessionLocal.configure(bind=engine)
    # Close pooled connections held by the engine being replaced.
    if previous_engine is not None:
        previous_engine.dispose()


def get_db() -> Generator[Session, None, None]:
    """Yield a new Session from SessionLocal and close it afterwards.

    The session is closed in ``finally`` once the generator is resumed
    past the yield, closed, or garbage-collected.
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


def init_db(database_url: str | None = None) -> None:
    """Create all ORM tables and add any missing SQLite image columns.

    The engine is (re)configured first when none exists yet or when an
    explicit *database_url* is given.

    Args:
        database_url: Optional URL forwarded to configure_database().
    """
    # The name is unused here; presumably imported so model classes register
    # their tables on Base.metadata before create_all() runs.
    from app import models  # noqa: F401

    if engine is None or database_url is not None:
        configure_database(database_url)
    Base.metadata.create_all(bind=engine)
    _sync_sqlite_image_columns()


def _sync_sqlite_image_columns() -> None:
    """Add missing image columns to existing SQLite tables.

    ``Base.metadata.create_all()`` only creates tables that do not exist; it
    never adds columns to existing ones. This function therefore issues
    ``ALTER TABLE ... ADD COLUMN`` for each column in ``_IMAGE_COLUMNS`` that
    is absent from a table in ``_IMAGE_TABLES``. Tables that do not exist are
    skipped. No-op when no engine is configured or the backend is not SQLite.
    """
    if engine is None or engine.dialect.name != "sqlite":
        return

    with engine.begin() as connection:
        for table_name in _IMAGE_TABLES:
            # Column 1 of each PRAGMA table_info row is the column name.
            existing_columns = {
                row[1]
                for row in connection.execute(text(f"PRAGMA table_info({table_name})"))
            }
            # No rows means the table does not exist; ALTER TABLE would fail.
            if not existing_columns:
                continue
            for column_name, column_type in _IMAGE_COLUMNS.items():
                if column_name not in existing_columns:
                    connection.execute(
                        text(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}")
                    )