54 lines
1.3 KiB
Python
54 lines
1.3 KiB
Python
from collections.abc import Generator
|
|
from functools import lru_cache
|
|
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
|
|
|
|
from app.config import get_settings
|
|
|
|
|
|
class AuthBase(DeclarativeBase):
    """Declarative base class for ORM models registered against it.

    Adds nothing on top of SQLAlchemy's ``DeclarativeBase``; it exists so
    that models deriving from it share their own metadata/registry.
    """
|
|
|
|
|
|
def _build_connect_args(database_url: str) -> dict[str, object]:
|
|
connect_args: dict[str, object] = {}
|
|
if database_url.startswith("sqlite"):
|
|
connect_args["check_same_thread"] = False
|
|
return connect_args
|
|
|
|
|
|
@lru_cache
def _get_auth_engine(database_url: str):
    """Create the SQLAlchemy engine for *database_url*, memoised per URL.

    Because of ``lru_cache``, repeated calls with the same URL return the
    same engine object until the cache is cleared.

    Args:
        database_url: Database URL passed straight to ``create_engine``.

    Returns:
        The value returned by ``create_engine`` for this URL.
    """
    driver_args = _build_connect_args(database_url)
    return create_engine(database_url, connect_args=driver_args)
|
|
|
|
|
|
@lru_cache
def _get_auth_session_local(database_url: str):
    """Build the session factory bound to the engine for *database_url*.

    Memoised per URL via ``lru_cache``, so the same ``sessionmaker`` is
    reused for a given URL until the cache is cleared.

    Args:
        database_url: Database URL used to look up the cached engine.

    Returns:
        A ``sessionmaker`` producing ``Session`` objects with autoflush and
        autocommit both disabled.
    """
    return sessionmaker(
        bind=_get_auth_engine(database_url),
        class_=Session,
        autoflush=False,
        autocommit=False,
    )
|
|
|
|
|
|
def get_auth_engine():
    """Return the engine for the URL in ``get_settings().app_database_url``.

    The settings are read on every call; the engine itself comes from the
    per-URL cache behind ``_get_auth_engine``.
    """
    database_url = get_settings().app_database_url
    return _get_auth_engine(database_url)
|
|
|
|
|
|
def get_auth_session_local():
    """Return the session factory for ``get_settings().app_database_url``.

    The settings are read on every call; the factory itself comes from the
    per-URL cache behind ``_get_auth_session_local``.
    """
    database_url = get_settings().app_database_url
    return _get_auth_session_local(database_url)
|
|
|
|
|
|
def reset_auth_db_caches() -> None:
    """Clear the memoised session factories and engines.

    After this call the next lookup for any URL builds a fresh engine and
    session factory.
    """
    # Session factories are cleared before the engines they are bound to.
    # NOTE(review): the cached engines are dropped without calling
    # ``dispose()`` - confirm no caller relies on pooled connections being
    # released here.
    for cached_factory in (_get_auth_session_local, _get_auth_engine):
        cached_factory.cache_clear()
|
|
|
|
|
|
def get_auth_db_session() -> Generator[Session, None, None]:
    """Yield one database session and close it when the generator finishes.

    Yields:
        A session created from the factory returned by
        ``get_auth_session_local()``.
    """
    db_session = get_auth_session_local()()
    try:
        yield db_session
    finally:
        # Runs on normal exhaustion, on an exception thrown into the
        # generator, and when the generator is closed early.
        db_session.close()
|