from collections.abc import Generator
from functools import lru_cache

from sqlalchemy import create_engine, event
from sqlalchemy.engine import Engine
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker

from app.config import get_settings

# Every engine built by ``_get_engine`` is recorded here so that
# ``reset_db_caches`` can dispose its connection pool. Clearing the lru_cache
# alone would orphan the engine with its pooled connections still open.
_created_engines: list[Engine] = []


class Base(DeclarativeBase):
    """Declarative base class shared by the application's ORM models."""

    pass


def _build_connect_args(database_url: str) -> dict[str, object]:
    """Return the DBAPI ``connect_args`` to use for *database_url*.

    For SQLite URLs, ``check_same_thread`` is disabled so a sqlite3
    connection may be used from a thread other than the one that created it.
    Any other URL gets an empty dict.
    """
    connect_args: dict[str, object] = {}
    if database_url.startswith("sqlite"):
        connect_args["check_same_thread"] = False
    return connect_args


@lru_cache
def _get_engine(database_url: str) -> Engine:
    """Create and return the SQLAlchemy engine for *database_url*.

    The result is memoised per URL by ``lru_cache``. Each newly created engine
    is also appended to ``_created_engines`` so ``reset_db_caches`` can dispose
    it. For SQLite URLs, a ``connect`` listener puts every new DBAPI
    connection into WAL journal mode.
    """
    engine = create_engine(database_url, connect_args=_build_connect_args(database_url))

    if database_url.startswith("sqlite"):

        @event.listens_for(engine, "connect")
        def _enable_sqlite_wal(dbapi_connection, _connection_record):
            # WAL mode persists in a database file, but issuing the pragma on
            # every new connection keeps it enforced. SQLite leaves in-memory
            # databases in "memory" mode regardless.
            cursor = dbapi_connection.cursor()
            try:
                cursor.execute("PRAGMA journal_mode=WAL")
            finally:
                # Close the cursor even if the pragma fails.
                cursor.close()

    _created_engines.append(engine)
    return engine


@lru_cache
def _get_session_local(database_url: str) -> sessionmaker:
    """Return the cached session factory bound to the engine for *database_url*.

    Sessions produced by the factory are ``Session`` instances with autoflush
    and autocommit disabled.
    """
    engine = _get_engine(database_url)
    return sessionmaker(bind=engine, autoflush=False, autocommit=False, class_=Session)


def get_engine() -> Engine:
    """Return the engine for the currently configured ``app_database_url``."""
    return _get_engine(get_settings().app_database_url)


def get_session_local() -> sessionmaker:
    """Return the session factory for the currently configured ``app_database_url``."""
    return _get_session_local(get_settings().app_database_url)


def reset_db_caches() -> None:
    """Drop cached session factories and engines, and dispose their pools.

    After this call, the next ``get_engine``/``get_session_local`` builds fresh
    objects. Every engine created so far has ``dispose()`` called on it, which
    closes its pooled (checked-in) connections. An engine that is still
    referenced elsewhere remains usable, because SQLAlchemy gives it a new
    pool.
    """
    _get_session_local.cache_clear()
    _get_engine.cache_clear()
    # Pop one at a time so the registry is emptied even if a dispose() raises
    # part-way through.
    while _created_engines:
        _created_engines.pop().dispose()


def get_db_session() -> Generator[Session, None, None]:
    """Yield a new ``Session`` from the configured factory, closing it afterwards.

    The ``finally`` guarantees ``close()`` runs when the generator is resumed,
    closed, or has an exception thrown into it.
    NOTE(review): presumably consumed as a dependency-injection provider;
    confirm with callers.
    """
    session_local = get_session_local()
    session = session_local()
    try:
        yield session
    finally:
        session.close()