diff --git a/alembic_app/env.py b/alembic_app/env.py index 00d6bea..f198417 100644 --- a/alembic_app/env.py +++ b/alembic_app/env.py @@ -3,11 +3,13 @@ from logging.config import fileConfig from alembic import context from sqlalchemy import engine_from_config, pool -from app.auth_db import AuthBase from app.config import get_settings +from app.db import Base from app.models.config import AppConfigEntry # noqa: F401 from app.models.auth import AuthSession, AuthUser # noqa: F401 from app.models.public_ip import PublicIPHistory, PublicIPState # noqa: F401 +from app.models.location import Location # noqa: F401 +from app.models.poo import PooRecord # noqa: F401 config = context.config @@ -19,7 +21,7 @@ configured_url = config.get_main_option("sqlalchemy.url") if not configured_url or configured_url == "sqlite:///./data/app.db": config.set_main_option("sqlalchemy.url", settings.app_database_url) -target_metadata = AuthBase.metadata +target_metadata = Base.metadata def run_migrations_offline() -> None: diff --git a/alembic_location/env.py b/alembic_location/env.py index 5b2d901..05b9217 100644 --- a/alembic_location/env.py +++ b/alembic_location/env.py @@ -5,7 +5,7 @@ from sqlalchemy import engine_from_config, pool from app.config import get_settings from app.models import Location # noqa: F401 -from app.models.base import Base +from app.db import Base config = context.config diff --git a/alembic_poo/env.py b/alembic_poo/env.py index 44cb0b9..98b528f 100644 --- a/alembic_poo/env.py +++ b/alembic_poo/env.py @@ -5,7 +5,7 @@ from sqlalchemy import engine_from_config, pool from app.config import get_settings from app.models.poo import PooRecord # noqa: F401 -from app.poo_db import PooBase +from app.db import Base config = context.config @@ -17,7 +17,7 @@ configured_url = config.get_main_option("sqlalchemy.url") if not configured_url or configured_url == "sqlite:///./data/pooRecorder.db": config.set_main_option("sqlalchemy.url", settings.poo_database_url) -target_metadata = PooBase.metadata +target_metadata = Base.metadata def run_migrations_offline() -> None: diff --git a/app/api/routes/auth.py b/app/api/routes/auth.py index d9603e2..d80846f 100644 --- a/app/api/routes/auth.py +++ b/app/api/routes/auth.py @@ -7,7 +7,7 @@ from fastapi.templating import Jinja2Templates from sqlalchemy.orm import Session from app.config import Settings -from app.dependencies import get_app_settings, get_auth_db, get_current_auth_session +from app.dependencies import get_app_settings, get_db, get_current_auth_session from app.services.auth import ( AuthenticatedSession, authenticate_user, @@ -57,7 +57,7 @@ def login_submit( username: str = Form(), password: str = Form(), csrf_token: str = Form(), - session: Session = Depends(get_auth_db), + session: Session = Depends(get_db), settings: Settings = Depends(get_app_settings), ) -> Response: cookie_csrf_token = request.cookies.get(LOGIN_CSRF_COOKIE_NAME) @@ -102,7 +102,7 @@ def change_password_submit( new_password: str = Form(), confirm_password: str = Form(), csrf_token: str = Form(), - session: Session = Depends(get_auth_db), + session: Session = Depends(get_db), settings: Settings = Depends(get_app_settings), current_auth: AuthenticatedSession | None = Depends(get_current_auth_session), ) -> Response: @@ -151,7 +151,7 @@ def change_password_submit( def logout( request: Request, csrf_token: str = Form(), - session: Session = Depends(get_auth_db), + session: Session = Depends(get_db), settings: Settings = Depends(get_app_settings), current_auth: AuthenticatedSession | None = Depends(get_current_auth_session), ) -> RedirectResponse: diff --git a/app/api/routes/homeassistant.py b/app/api/routes/homeassistant.py index ccee1f8..703df37 100644 --- a/app/api/routes/homeassistant.py +++ b/app/api/routes/homeassistant.py @@ -11,7 +11,6 @@ from app.dependencies import ( get_app_settings, get_db, get_homeassistant_client, - get_poo_db, get_ticktick_client, ) from app.integrations.homeassistant import ( @@ -36,7 +35,6 @@ INTERNAL_SERVER_ERROR_MESSAGE = "internal server error" async def publish_from_homeassistant( request: Request, db: Session = Depends(get_db), - poo_db: Session = Depends(get_poo_db), settings: Settings = Depends(get_app_settings), homeassistant_client: HomeAssistantClient = Depends(get_homeassistant_client), ticktick_client: TickTickClient = Depends(get_ticktick_client), @@ -49,7 +47,7 @@ async def publish_from_homeassistant( db, envelope, ticktick_client=ticktick_client, - poo_session=poo_db, + poo_session=db, settings=settings, homeassistant_client=homeassistant_client, ) diff --git a/app/api/routes/pages.py b/app/api/routes/pages.py index 4b474cb..bbd2594 100644 --- a/app/api/routes/pages.py +++ b/app/api/routes/pages.py @@ -6,7 +6,7 @@ from fastapi.responses import HTMLResponse, RedirectResponse, Response from fastapi.templating import Jinja2Templates from app.config import Settings, get_settings -from app.dependencies import get_app_settings, get_auth_db, get_current_auth_session +from app.dependencies import get_app_settings, get_db, get_current_auth_session from app.services.auth import AuthenticatedSession from app.services.config_page import ( ConfigSaveError, @@ -100,7 +100,7 @@ def admin_redirect( @router.get("/config", response_class=HTMLResponse) def config_page( request: Request, - auth_db_session: Session = Depends(get_auth_db), + auth_db_session: Session = Depends(get_db), settings: Settings = Depends(get_app_settings), current_auth: AuthenticatedSession | None = Depends(get_current_auth_session), ) -> Response: @@ -129,7 +129,7 @@ def config_page( @router.post("/config", response_class=HTMLResponse) async def config_submit( request: Request, - auth_db_session: Session = Depends(get_auth_db), + auth_db_session: Session = Depends(get_db), settings: Settings = Depends(get_app_settings), current_auth: AuthenticatedSession | None = Depends(get_current_auth_session), ) -> Response: @@ -189,7 +189,7 @@ async def config_submit( @router.post("/config/smtp/test", response_class=HTMLResponse) async def smtp_test_submit( request: Request, - auth_db_session: Session = Depends(get_auth_db), + auth_db_session: Session = Depends(get_db), settings: Settings = Depends(get_app_settings), current_auth: AuthenticatedSession | None = Depends(get_current_auth_session), ) -> Response: diff --git a/app/api/routes/poo.py b/app/api/routes/poo.py index 451741d..4473e30 100644 --- a/app/api/routes/poo.py +++ b/app/api/routes/poo.py @@ -7,7 +7,7 @@ from pydantic import ValidationError from sqlalchemy.orm import Session from app.config import Settings -from app.dependencies import get_app_settings, get_homeassistant_client, get_poo_db +from app.dependencies import get_app_settings, get_homeassistant_client, get_db from app.integrations.homeassistant import HomeAssistantClient from app.schemas.poo import PooRecordRequest from app.services.poo import publish_latest_poo_status, record_poo @@ -21,7 +21,7 @@ INTERNAL_SERVER_ERROR_MESSAGE = "internal server error" @router.post("/poo/record") async def create_poo_record( request: Request, - db: Session = Depends(get_poo_db), + db: Session = Depends(get_db), settings: Settings = Depends(get_app_settings), homeassistant_client: HomeAssistantClient = Depends(get_homeassistant_client), ) -> Response: @@ -56,7 +56,7 @@ async def create_poo_record( @router.get("/poo/latest") def notify_latest_poo( - db: Session = Depends(get_poo_db), + db: Session = Depends(get_db), settings: Settings = Depends(get_app_settings), homeassistant_client: HomeAssistantClient = Depends(get_homeassistant_client), ) -> Response: diff --git a/app/api/routes/public_ip.py b/app/api/routes/public_ip.py index 766525f..6ebc0b4 100644 --- a/app/api/routes/public_ip.py +++ b/app/api/routes/public_ip.py @@ -1,7 +1,7 @@ from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session -from app.dependencies import get_auth_db, get_current_auth_session +from app.dependencies import get_db, get_current_auth_session from app.schemas.public_ip import PublicIPCheckResponse from app.config import get_settings from app.services.auth import AuthenticatedSession @@ -12,7 +12,7 @@ router = APIRouter(tags=["public-ip"]) @router.get("/public-ip/check", response_model=PublicIPCheckResponse) def run_public_ip_check( - session: Session = Depends(get_auth_db), + session: Session = Depends(get_db), current_auth: AuthenticatedSession | None = Depends(get_current_auth_session), ) -> PublicIPCheckResponse: if current_auth is None: diff --git a/app/api/routes/ticktick.py b/app/api/routes/ticktick.py index b728108..9a1a417 100644 --- a/app/api/routes/ticktick.py +++ b/app/api/routes/ticktick.py @@ -7,7 +7,7 @@ from sqlalchemy.orm import Session from app.config import Settings from app.dependencies import ( get_app_settings, - get_auth_db, + get_db, get_current_auth_session, get_ticktick_client, ) @@ -39,7 +39,7 @@ def start_ticktick_auth( @router.get("/ticktick/auth/code") def handle_ticktick_auth_code( request: Request, - auth_db_session: Session = Depends(get_auth_db), + auth_db_session: Session = Depends(get_db), settings: Settings = Depends(get_app_settings), ticktick_client: TickTickClient = Depends(get_ticktick_client), ) -> Response: diff --git a/app/auth_db.py b/app/auth_db.py deleted file mode 100644 index 41dcd1f..0000000 --- a/app/auth_db.py +++ /dev/null @@ -1,53 +0,0 @@ -from collections.abc import Generator -from functools import lru_cache - -from sqlalchemy import create_engine -from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker - -from app.config import get_settings - - -class AuthBase(DeclarativeBase): - pass - - -def _build_connect_args(database_url: str) -> dict[str, object]: - connect_args: dict[str, object] = {} - if database_url.startswith("sqlite"): - connect_args["check_same_thread"] = False - return connect_args - - -@lru_cache -def _get_auth_engine(database_url: str): - return create_engine(database_url, connect_args=_build_connect_args(database_url)) - - -@lru_cache -def _get_auth_session_local(database_url: str): - engine = _get_auth_engine(database_url) - return sessionmaker(bind=engine, autoflush=False, autocommit=False, class_=Session) - - -def get_auth_engine(): - settings = get_settings() - return _get_auth_engine(settings.app_database_url) - - -def get_auth_session_local(): - settings = get_settings() - return _get_auth_session_local(settings.app_database_url) - - -def reset_auth_db_caches() -> None: - _get_auth_session_local.cache_clear() - _get_auth_engine.cache_clear() - - -def get_auth_db_session() -> Generator[Session, None, None]: - session_local = get_auth_session_local() - session = session_local() - try: - yield session - finally: - session.close() diff --git a/app/db.py b/app/db.py index c8d94a9..ed8068d 100644 --- a/app/db.py +++ b/app/db.py @@ -1,6 +1,8 @@ from collections.abc import Generator +from functools import lru_cache -from sqlalchemy import create_engine +from sqlalchemy import create_engine, event +from sqlalchemy.engine import Engine from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker from app.config import get_settings @@ -10,18 +12,49 @@ class Base(DeclarativeBase): pass -settings = get_settings() +def _build_connect_args(database_url: str) -> dict[str, object]: + connect_args: dict[str, object] = {} + if database_url.startswith("sqlite"): + connect_args["check_same_thread"] = False + return connect_args -connect_args: dict[str, object] = {} -if settings.location_database_url.startswith("sqlite"): - connect_args["check_same_thread"] = False -engine = create_engine(settings.location_database_url, connect_args=connect_args) -SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False, class_=Session) +@lru_cache +def _get_engine(database_url: str) -> Engine: + engine = create_engine(database_url, connect_args=_build_connect_args(database_url)) + if database_url.startswith("sqlite"): + + @event.listens_for(engine, "connect") + def _enable_sqlite_wal(dbapi_connection, _connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA journal_mode=WAL") + cursor.close() + + return engine + + +@lru_cache +def _get_session_local(database_url: str) -> sessionmaker: + engine = _get_engine(database_url) + return sessionmaker(bind=engine, autoflush=False, autocommit=False, class_=Session) + + +def get_engine() -> Engine: + return _get_engine(get_settings().app_database_url) + + +def get_session_local() -> sessionmaker: + return _get_session_local(get_settings().app_database_url) + + +def reset_db_caches() -> None: + _get_session_local.cache_clear() + _get_engine.cache_clear() def get_db_session() -> Generator[Session, None, None]: - session = SessionLocal() + session_local = get_session_local() + session = session_local() try: yield session finally: diff --git a/app/dependencies.py b/app/dependencies.py index ed4f3f0..570fb2b 100644 --- a/app/dependencies.py +++ b/app/dependencies.py @@ -3,30 +3,20 @@ from collections.abc import Generator from fastapi import Depends, Request from sqlalchemy.orm import Session -from app.auth_db import get_auth_db_session from app.config import Settings, get_settings from app.db import get_db_session from app.integrations.homeassistant import HomeAssistantClient from app.integrations.ticktick import TickTickClient -from app.poo_db import get_poo_db_session from app.services.auth import AuthenticatedSession, get_authenticated_session from app.services.config_page import build_runtime_settings -def get_auth_db() -> Generator[Session, None, None]: - yield from get_auth_db_session() - - -def get_app_settings(session: Session = Depends(get_auth_db)) -> Settings: - return build_runtime_settings(session, get_settings()) - - def get_db() -> Generator[Session, None, None]: yield from get_db_session() -def get_poo_db() -> Generator[Session, None, None]: - yield from get_poo_db_session() +def get_app_settings(session: Session = Depends(get_db)) -> Settings: + return build_runtime_settings(session, get_settings()) def get_homeassistant_client(settings: Settings = Depends(get_app_settings)) -> HomeAssistantClient: @@ -39,7 +29,7 @@ def get_ticktick_client(settings: Settings = Depends(get_app_settings)) -> TickT def get_current_auth_session( request: Request, - session: Session = Depends(get_auth_db), + session: Session = Depends(get_db), settings: Settings = Depends(get_app_settings), ) -> AuthenticatedSession | None: raw_token = request.cookies.get(settings.auth_session_cookie_name) diff --git a/app/main.py b/app/main.py index dd8a9ec..3e0647e 100644 --- a/app/main.py +++ b/app/main.py @@ -10,7 +10,7 @@ from sqlalchemy.orm import Session from app import models # noqa: F401 from app.api.routes.auth import router as auth_router from app.api.routes import pages, status -import app.auth_db as auth_db +from app.db import get_session_local from app.api.routes.homeassistant import router as homeassistant_router from app.api.routes.location import router as location_router from app.api.routes.poo import router as poo_router @@ -26,7 +26,7 @@ from scripts.poo_db_adopt import PooDatabaseAdoptionError, validate_poo_runtime_ def _run_scheduled_public_ip_check() -> None: - session_local = auth_db.get_auth_session_local() + session_local = get_session_local() session: Session = session_local() try: check_public_ipv4_and_notify(session, bootstrap_settings=get_settings()) @@ -35,7 +35,7 @@ def _run_scheduled_public_ip_check() -> None: def ensure_auth_db_ready() -> None: - session_local = auth_db.get_auth_session_local() + session_local = get_session_local() session: Session = session_local() try: validate_app_runtime_db(get_settings().app_database_url) diff --git a/app/models/__init__.py b/app/models/__init__.py index 24d4862..93e8e82 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -3,6 +3,7 @@ from app.models.auth import AuthSession, AuthUser from app.models.config import AppConfigEntry from app.models.location import Location +from app.models.poo import PooRecord from app.models.public_ip import PublicIPHistory, PublicIPState __all__ = [ @@ -10,6 +11,7 @@ __all__ = [ "AuthSession", "AuthUser", "Location", + "PooRecord", "PublicIPHistory", "PublicIPState", ] diff --git a/app/models/auth.py b/app/models/auth.py index 3284913..08da8e5 100644 --- a/app/models/auth.py +++ b/app/models/auth.py @@ -3,10 +3,10 @@ from datetime import datetime from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String from sqlalchemy.orm import Mapped, mapped_column, relationship -from app.auth_db import AuthBase +from app.db import Base -class AuthUser(AuthBase): +class AuthUser(Base): __tablename__ = "auth_users" id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) @@ -19,7 +19,7 @@ class AuthUser(AuthBase): sessions: Mapped[list["AuthSession"]] = relationship(back_populates="user") -class AuthSession(AuthBase): +class AuthSession(Base): __tablename__ = "auth_sessions" id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) diff --git a/app/models/base.py b/app/models/base.py deleted file mode 100644 index b852be4..0000000 --- a/app/models/base.py +++ /dev/null @@ -1,4 +0,0 @@ -from app.db import Base - -__all__ = ["Base"] - diff --git a/app/models/config.py b/app/models/config.py index 31c0dff..ec7d504 100644 --- a/app/models/config.py +++ b/app/models/config.py @@ -3,10 +3,10 @@ from datetime import datetime from sqlalchemy import DateTime, Integer, String from sqlalchemy.orm import Mapped, mapped_column -from app.auth_db import AuthBase +from app.db import Base -class AppConfigEntry(AuthBase): +class AppConfigEntry(Base): __tablename__ = "app_config" id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) diff --git a/app/models/poo.py b/app/models/poo.py index 6476fd8..2083a64 100644 --- a/app/models/poo.py +++ b/app/models/poo.py @@ -1,10 +1,10 @@ from sqlalchemy import Float, String from sqlalchemy.orm import Mapped, mapped_column -from app.poo_db import PooBase +from app.db import Base -class PooRecord(PooBase): +class PooRecord(Base): __tablename__ = "poo_records" timestamp: Mapped[str] = mapped_column(String, primary_key=True) diff --git a/app/models/public_ip.py b/app/models/public_ip.py index a88fd4e..589a237 100644 --- a/app/models/public_ip.py +++ b/app/models/public_ip.py @@ -3,10 +3,10 @@ from datetime import datetime from sqlalchemy import DateTime, Integer, String from sqlalchemy.orm import Mapped, mapped_column -from app.auth_db import AuthBase +from app.db import Base -class PublicIPState(AuthBase): +class PublicIPState(Base): __tablename__ = "public_ip_state" id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) @@ -20,7 +20,7 @@ class PublicIPState(AuthBase): last_provider: Mapped[str | None] = mapped_column(String(64), nullable=True) -class PublicIPHistory(AuthBase): +class PublicIPHistory(Base): __tablename__ = "public_ip_history" id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) diff --git a/app/poo_db.py b/app/poo_db.py deleted file mode 100644 index 3fdda48..0000000 --- a/app/poo_db.py +++ /dev/null @@ -1,28 +0,0 @@ -from collections.abc import Generator - -from sqlalchemy import create_engine -from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker - -from app.config import get_settings - - -class PooBase(DeclarativeBase): - pass - - -settings = get_settings() - -connect_args: dict[str, object] = {} -if settings.poo_database_url.startswith("sqlite"): - connect_args["check_same_thread"] = False - -poo_engine = create_engine(settings.poo_database_url, connect_args=connect_args) -PooSessionLocal = sessionmaker(bind=poo_engine, autoflush=False, autocommit=False, class_=Session) - - -def get_poo_db_session() -> Generator[Session, None, None]: - session = PooSessionLocal() - try: - yield session - finally: - session.close() diff --git a/app/services/config_page.py b/app/services/config_page.py index db7a450..956417f 100644 --- a/app/services/config_page.py +++ b/app/services/config_page.py @@ -7,7 +7,7 @@ from typing import Any from sqlalchemy import select from sqlalchemy.orm import Session -from app.auth_db import reset_auth_db_caches +from app.db import reset_db_caches from app.config import Settings, get_settings from app.models.config import AppConfigEntry @@ -127,7 +127,7 @@ def sync_app_hostname_from_bootstrap(session: Session, bootstrap_settings: Setti current_values["APP_HOSTNAME"] = bootstrap_hostname _persist_config_values(session, current_values) get_settings.cache_clear() - reset_auth_db_caches() + reset_db_caches() def build_runtime_settings(session: Session, bootstrap_settings: Settings) -> Settings: @@ -184,7 +184,7 @@ def save_config_updates(session: Session, form_data: dict[str, str], bootstrap_s _validate_config_values(merged_values, bootstrap_settings) _persist_config_values(session, merged_values) get_settings.cache_clear() - reset_auth_db_caches() + reset_db_caches() def save_config_value( @@ -199,7 +199,7 @@ def save_config_value( _validate_config_values(current_values, bootstrap_settings) _persist_config_values(session, current_values) get_settings.cache_clear() - reset_auth_db_caches() + reset_db_caches() def is_ticktick_oauth_ready(settings: Settings) -> bool: diff --git a/tests/conftest.py b/tests/conftest.py index 948661f..270284b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,10 +5,8 @@ from alembic import command from alembic.config import Config from fastapi.testclient import TestClient from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker -from app.auth_db import reset_auth_db_caches -import app.db as app_db +from app.db import reset_db_caches from app.config import get_settings from app.main import create_app @@ -47,7 +45,7 @@ def test_database_urls(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("AUTH_BOOTSTRAP_PASSWORD", "test-password") monkeypatch.setenv("AUTH_COOKIE_SECURE_OVERRIDE", "false") get_settings.cache_clear() - reset_auth_db_caches() + reset_db_caches() try: yield { @@ -60,7 +58,7 @@ def test_database_urls(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): } finally: get_settings.cache_clear() - reset_auth_db_caches() + reset_db_caches() @pytest.fixture @@ -79,10 +77,10 @@ def ready_poo_database(test_database_urls): def auth_database(test_database_urls, monkeypatch: pytest.MonkeyPatch): database_url = test_database_urls["app_url"] command.upgrade(_make_app_alembic_config(database_url), "head") - reset_auth_db_caches() + reset_db_caches() yield test_database_urls - reset_auth_db_caches() + reset_db_caches() @pytest.fixture @@ -97,46 +95,20 @@ def client(app): @pytest.fixture -def location_client( - ready_location_database, - ready_poo_database, - auth_database, - monkeypatch: pytest.MonkeyPatch, -): - database_url = ready_location_database["location_url"] - - engine = create_engine(database_url, connect_args={"check_same_thread": False}) - session_local = sessionmaker(bind=engine, autoflush=False, autocommit=False) - - monkeypatch.setattr(app_db, "engine", engine) - monkeypatch.setattr(app_db, "SessionLocal", session_local) - +def location_client(ready_location_database, ready_poo_database, auth_database): + app_url = auth_database["app_url"] + engine = create_engine(app_url, connect_args={"check_same_thread": False}) fastapi_app = create_app() with TestClient(fastapi_app) as client: yield client, engine - engine.dispose() @pytest.fixture -def poo_client( - ready_location_database, - ready_poo_database, - auth_database, - monkeypatch: pytest.MonkeyPatch, -): - database_url = ready_poo_database["poo_url"] - - engine = create_engine(database_url, connect_args={"check_same_thread": False}) - session_local = sessionmaker(bind=engine, autoflush=False, autocommit=False) - - import app.poo_db as poo_db - - monkeypatch.setattr(poo_db, "poo_engine", engine) - monkeypatch.setattr(poo_db, "PooSessionLocal", session_local) - +def poo_client(ready_location_database, ready_poo_database, auth_database): + app_url = auth_database["app_url"] + engine = create_engine(app_url, connect_args={"check_same_thread": False}) fastapi_app = create_app() with TestClient(fastapi_app) as client: yield client, engine - engine.dispose() diff --git a/tests/test_app.py b/tests/test_app.py index 05a57dc..2857032 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -5,7 +5,7 @@ import pytest from alembic import command from fastapi.testclient import TestClient -from app.auth_db import reset_auth_db_caches +from app.db import reset_db_caches from app.config import get_settings from app.main import create_app from scripts.app_db_adopt import APP_BASELINE_REVISION, adopt_or_initialize_app_db @@ -49,7 +49,7 @@ def test_app_start_fails_when_app_db_missing(tmp_path, monkeypatch: pytest.Monke monkeypatch.setenv("LOCATION_DATABASE_URL", f"sqlite:///{location_database_path}") monkeypatch.setenv("POO_DATABASE_URL", f"sqlite:///{poo_database_path}") get_settings.cache_clear() - reset_auth_db_caches() + reset_db_caches() app = create_app() with pytest.raises(RuntimeError, match="Run 'python scripts/app_db_adopt.py' first"): @@ -58,7 +58,7 @@ def test_app_start_fails_when_app_db_missing(tmp_path, monkeypatch: pytest.Monke assert not missing_app_path.exists() get_settings.cache_clear() - reset_auth_db_caches() + reset_db_caches() def test_app_db_adoption_initializes_new_database(tmp_path) -> None: @@ -108,7 +108,7 @@ def test_app_start_seeds_missing_config_from_env_without_overwriting_existing_va monkeypatch.setenv("LOCATION_DATABASE_URL", f"sqlite:///{location_database_path}") monkeypatch.setenv("POO_DATABASE_URL", f"sqlite:///{poo_database_path}") get_settings.cache_clear() - reset_auth_db_caches() + reset_db_caches() app = create_app() anyio.run(_run_lifespan, app) @@ -124,7 +124,7 @@ def test_app_start_seeds_missing_config_from_env_without_overwriting_existing_va assert rows["AUTH_SESSION_COOKIE_NAME"] == "home_automation_session" get_settings.cache_clear() - reset_auth_db_caches() + reset_db_caches() def test_app_start_syncs_app_hostname_from_env_even_when_db_has_old_value( @@ -152,7 +152,7 @@ def test_app_start_syncs_app_hostname_from_env_even_when_db_has_old_value( monkeypatch.setenv("LOCATION_DATABASE_URL", f"sqlite:///{location_database_path}") monkeypatch.setenv("POO_DATABASE_URL", f"sqlite:///{poo_database_path}") get_settings.cache_clear() - reset_auth_db_caches() + reset_db_caches() app = create_app() anyio.run(_run_lifespan, app) @@ -166,7 +166,7 @@ def test_app_start_syncs_app_hostname_from_env_even_when_db_has_old_value( assert rows["APP_HOSTNAME"] == "new.example.com" get_settings.cache_clear() - reset_auth_db_caches() + reset_db_caches() def test_app_start_fails_when_location_db_missing( @@ -182,14 +182,14 @@ def test_app_start_fails_when_location_db_missing( monkeypatch.setenv("LOCATION_DATABASE_URL", f"sqlite:///{tmp_path / 'missing.db'}") monkeypatch.setenv("POO_DATABASE_URL", f"sqlite:///{poo_database_path}") get_settings.cache_clear() - reset_auth_db_caches() + reset_db_caches() app = create_app() with pytest.raises(RuntimeError, match="Run 'python scripts/location_db_adopt.py' first"): anyio.run(_run_lifespan, app) get_settings.cache_clear() - reset_auth_db_caches() + reset_db_caches() def test_app_start_fails_when_location_db_exists_but_is_not_adopted( @@ -223,14 +223,14 @@ def test_app_start_fails_when_location_db_exists_but_is_not_adopted( monkeypatch.setenv("LOCATION_DATABASE_URL", f"sqlite:///{database_path}") monkeypatch.setenv("POO_DATABASE_URL", f"sqlite:///{poo_database_path}") get_settings.cache_clear() - reset_auth_db_caches() + reset_db_caches() app = create_app() with pytest.raises(RuntimeError, match="is not yet Alembic-managed"): anyio.run(_run_lifespan, app) get_settings.cache_clear() - reset_auth_db_caches() + reset_db_caches() def test_app_start_fails_when_location_db_revision_mismatches( @@ -254,11 +254,11 @@ def test_app_start_fails_when_location_db_revision_mismatches( monkeypatch.setenv("LOCATION_DATABASE_URL", f"sqlite:///{database_path}") monkeypatch.setenv("POO_DATABASE_URL", f"sqlite:///{poo_database_path}") get_settings.cache_clear() - reset_auth_db_caches() + reset_db_caches() app = create_app() with pytest.raises(RuntimeError, match="Location DB revision mismatch"): anyio.run(_run_lifespan, app) get_settings.cache_clear() - reset_auth_db_caches() + reset_db_caches() diff --git a/tests/test_auth.py b/tests/test_auth.py index b0d8c56..2558a5e 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -4,7 +4,7 @@ from pathlib import Path from fastapi.testclient import TestClient -from app.auth_db import reset_auth_db_caches +from app.db import reset_db_caches from app.config import get_settings from app.main import create_app @@ -205,7 +205,7 @@ def test_config_page_shows_ticktick_oauth_link_when_ticktick_is_configured( monkeypatch.setenv("TICKTICK_CLIENT_ID", "ticktick-client-id") monkeypatch.setenv("TICKTICK_CLIENT_SECRET", "ticktick-client-secret") get_settings.cache_clear() - reset_auth_db_caches() + reset_db_caches() with TestClient(create_app()) as client: login_page = client.get("/login") diff --git a/tests/test_deployment.py b/tests/test_deployment.py index d46691f..f2b4f28 100644 --- a/tests/test_deployment.py +++ b/tests/test_deployment.py @@ -6,7 +6,7 @@ import pytest import yaml from alembic import command -from app.auth_db import reset_auth_db_caches +from app.db import reset_db_caches from app.config import get_settings from app.main import create_app from scripts.app_db_adopt import APP_BASELINE_REVISION @@ -41,7 +41,7 @@ def _configure_database_env(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> monkeypatch.setenv("AUTH_BOOTSTRAP_PASSWORD", "test-password") monkeypatch.setenv("AUTH_COOKIE_SECURE_OVERRIDE", "false") get_settings.cache_clear() - reset_auth_db_caches() + reset_db_caches() return { "app_path": app_path, @@ -165,7 +165,7 @@ def test_migration_runner_initializes_and_is_idempotent( conn.close() get_settings.cache_clear() - reset_auth_db_caches() + reset_db_caches() def test_migration_runner_adopts_legacy_sqlite_without_data_loss( @@ -194,7 +194,7 @@ def test_migration_runner_adopts_legacy_sqlite_without_data_loss( conn.close() get_settings.cache_clear() - reset_auth_db_caches() + reset_db_caches() def test_app_startup_still_fails_closed_without_running_adoption( @@ -212,4 +212,4 @@ def test_app_startup_still_fails_closed_without_running_adoption( assert not Path(missing_app_path).exists() get_settings.cache_clear() - reset_auth_db_caches() + reset_db_caches() diff --git a/tests/test_homeassistant_inbound.py b/tests/test_homeassistant_inbound.py index a3f9753..47407f7 100644 --- a/tests/test_homeassistant_inbound.py +++ b/tests/test_homeassistant_inbound.py @@ -1,7 +1,5 @@ from sqlalchemy import text -import app.db as app_db -import app.poo_db as poo_db from app.config import Settings, get_settings from app.dependencies import get_app_settings, get_homeassistant_client from app.main import create_app @@ -161,42 +159,24 @@ def test_homeassistant_publish_poo_get_latest_publishes_latest_status( ready_location_database, ready_poo_database, auth_database, - monkeypatch, ) -> None: - location_engine = app_db.create_engine( - ready_location_database["location_url"], - connect_args={"check_same_thread": False}, - ) - location_session_local = app_db.sessionmaker( - bind=location_engine, - autoflush=False, - autocommit=False, - ) - poo_engine = poo_db.create_engine( - ready_poo_database["poo_url"], - connect_args={"check_same_thread": False}, - ) - poo_session_local = poo_db.sessionmaker( - bind=poo_engine, - autoflush=False, - autocommit=False, - ) + from fastapi.testclient import TestClient + from sqlalchemy import create_engine + + app_url = auth_database["app_url"] + engine = create_engine(app_url, connect_args={"check_same_thread": False}) + fake_ha = _FakeHomeAssistantClient() settings = Settings( poo_sensor_entity_name="sensor.test_poo_status", poo_sensor_friendly_name="Poo Status", ) - monkeypatch.setattr(app_db, "engine", location_engine) - monkeypatch.setattr(app_db, "SessionLocal", location_session_local) - monkeypatch.setattr(poo_db, "poo_engine", poo_engine) - monkeypatch.setattr(poo_db, "PooSessionLocal", poo_session_local) - test_app = create_app() test_app.dependency_overrides[get_homeassistant_client] = lambda: fake_ha test_app.dependency_overrides[get_app_settings] = lambda: settings - with poo_engine.begin() as conn: + with engine.begin() as conn: conn.execute( text( "INSERT INTO poo_records (timestamp, status, latitude, longitude) " @@ -211,8 +191,6 @@ def test_homeassistant_publish_poo_get_latest_publishes_latest_status( ) try: - from fastapi.testclient import TestClient - with TestClient(test_app) as client: response = client.post( "/homeassistant/publish", @@ -233,52 +211,27 @@ def test_homeassistant_publish_poo_get_latest_publishes_latest_status( finally: test_app.dependency_overrides.clear() get_settings.cache_clear() - location_engine.dispose() - poo_engine.dispose() + engine.dispose() def test_homeassistant_publish_returns_internal_error_for_unknown_poo_action( ready_location_database, ready_poo_database, auth_database, - monkeypatch, ) -> None: - location_engine = app_db.create_engine( - ready_location_database["location_url"], - connect_args={"check_same_thread": False}, - ) - location_session_local = app_db.sessionmaker( - bind=location_engine, - autoflush=False, - autocommit=False, - ) - poo_engine = poo_db.create_engine( - ready_poo_database["poo_url"], - connect_args={"check_same_thread": False}, - ) - poo_session_local = poo_db.sessionmaker( - bind=poo_engine, - autoflush=False, - autocommit=False, - ) + from fastapi.testclient import TestClient + fake_ha = _FakeHomeAssistantClient() settings = Settings( poo_sensor_entity_name="sensor.test_poo_status", poo_sensor_friendly_name="Poo Status", ) - monkeypatch.setattr(app_db, "engine", location_engine) - monkeypatch.setattr(app_db, "SessionLocal", location_session_local) - monkeypatch.setattr(poo_db, "poo_engine", poo_engine) - monkeypatch.setattr(poo_db, "PooSessionLocal", poo_session_local) - test_app = create_app() test_app.dependency_overrides[get_homeassistant_client] = lambda: fake_ha test_app.dependency_overrides[get_app_settings] = lambda: settings try: - from fastapi.testclient import TestClient - with TestClient(test_app) as client: response = client.post( "/homeassistant/publish", @@ -295,8 +248,6 @@ def test_homeassistant_publish_returns_internal_error_for_unknown_poo_action( finally: test_app.dependency_overrides.clear() get_settings.cache_clear() - location_engine.dispose() - poo_engine.dispose() def test_homeassistant_publish_returns_not_implemented_for_unknown_location_action( diff --git a/tests/test_location.py b/tests/test_location.py index 3af7c6e..f510245 100644 --- a/tests/test_location.py +++ b/tests/test_location.py @@ -5,18 +5,14 @@ import sqlite3 import pytest from alembic import command from alembic.config import Config -from sqlalchemy import create_engine, text -from sqlalchemy.orm import sessionmaker +from sqlalchemy import text -import app.db as app_db -from app.main import create_app from scripts.location_db_adopt import ( EXPECTED_USER_VERSION, LOCATION_BASELINE_REVISION, LocationDatabaseAdoptionError, adopt_or_initialize_location_db, ) -from tests.conftest import _make_app_alembic_config, _make_poo_alembic_config def _make_alembic_config(database_url: str) -> Config: @@ -197,66 +193,6 @@ def test_location_record_endpoint_defaults_invalid_altitude_to_zero(location_cli assert row.altitude == pytest.approx(0.0) -def test_legacy_style_location_db_can_be_stamped_and_adopted( - test_database_urls, monkeypatch: pytest.MonkeyPatch -) -> None: - app_database_url = test_database_urls["app_url"] - database_path = test_database_urls["location_path"] - database_url = test_database_urls["location_url"] - poo_database_url = test_database_urls["poo_url"] - - conn = sqlite3.connect(database_path) - conn.execute( - """ - CREATE TABLE location ( - person TEXT NOT NULL, - datetime TEXT NOT NULL, - latitude REAL NOT NULL, - longitude REAL NOT NULL, - altitude REAL, - PRIMARY KEY (person, datetime) - ) - """ - ) - conn.execute("PRAGMA user_version = 2") - conn.commit() - conn.close() - - command.upgrade(_make_app_alembic_config(app_database_url), "head") - command.stamp(_make_alembic_config(database_url), LOCATION_BASELINE_REVISION) - command.upgrade(_make_poo_alembic_config(poo_database_url), "head") - - engine = create_engine(database_url, connect_args={"check_same_thread": False}) - session_local = sessionmaker(bind=engine, autoflush=False, autocommit=False) - monkeypatch.setattr(app_db, "engine", engine) - monkeypatch.setattr(app_db, "SessionLocal", session_local) - - from fastapi.testclient import TestClient - - fastapi_app = create_app() - with TestClient(fastapi_app) as client: - response = client.post( - "/location/record", - json={ - "person": "legacy-user", - "latitude": "12.3", - "longitude": "45.6", - "altitude": "7.8", - }, - ) - - assert response.status_code == 200 - - with engine.connect() as db_conn: - revision = db_conn.execute(text("SELECT version_num FROM alembic_version")).scalar_one() - row_count = db_conn.execute(text("SELECT COUNT(*) FROM location")).scalar_one() - - assert revision == LOCATION_BASELINE_REVISION - assert row_count == 1 - - engine.dispose() - - def test_location_db_adoption_initializes_new_db(tmp_path: Path) -> None: database_path = tmp_path / "new_location.db" result = adopt_or_initialize_location_db(f"sqlite:///{database_path}") diff --git a/tests/test_ticktick.py b/tests/test_ticktick.py index 32ea1a9..036a486 100644 --- a/tests/test_ticktick.py +++ b/tests/test_ticktick.py @@ -6,7 +6,7 @@ from urllib.parse import parse_qs, urlparse import pytest from fastapi.testclient import TestClient -from app.auth_db import reset_auth_db_caches +from app.db import reset_db_caches from app.config import Settings, get_settings from app.integrations.ticktick import ( AUTH_SCOPE, @@ -221,7 +221,7 @@ def test_homeassistant_publish_creates_ticktick_action_task( monkeypatch.setenv("TICKTICK_TOKEN", "ticktick-access-token") monkeypatch.setenv("HOME_ASSISTANT_ACTION_TASK_PROJECT_ID", "project-123") get_settings.cache_clear() - reset_auth_db_caches() + reset_db_caches() captured = {"calls": []} @@ -265,7 +265,7 @@ def test_ticktick_auth_start_redirects_authenticated_user( monkeypatch.setenv("TICKTICK_CLIENT_ID", "ticktick-client-id") monkeypatch.setenv("TICKTICK_CLIENT_SECRET", "ticktick-client-secret") get_settings.cache_clear() - reset_auth_db_caches() + reset_db_caches() monkeypatch.setattr("app.integrations.ticktick.secrets.token_hex", lambda _: "state-redirect") with TestClient(create_app()) as client: @@ -301,7 +301,7 @@ def test_ticktick_auth_callback_persists_token( monkeypatch.setenv("TICKTICK_CLIENT_ID", "ticktick-client-id") monkeypatch.setenv("TICKTICK_CLIENT_SECRET", "ticktick-client-secret") get_settings.cache_clear() - reset_auth_db_caches() + reset_db_caches() default_auth_state_store.pending_state = "callback-state" def fake_urlopen(req, timeout): @@ -342,7 +342,7 @@ def test_ticktick_auth_callback_redirects_on_invalid_state( monkeypatch.setenv("TICKTICK_CLIENT_ID", "ticktick-client-id") monkeypatch.setenv("TICKTICK_CLIENT_SECRET", "ticktick-client-secret") get_settings.cache_clear() - reset_auth_db_caches() + reset_db_caches() default_auth_state_store.pending_state = "expected-state" with TestClient(create_app()) as client: @@ -366,7 +366,7 @@ def test_ticktick_auth_callback_redirects_when_token_exchange_fails( monkeypatch.setenv("TICKTICK_CLIENT_ID", "ticktick-client-id") monkeypatch.setenv("TICKTICK_CLIENT_SECRET", "ticktick-client-secret") get_settings.cache_clear() - reset_auth_db_caches() + reset_db_caches() default_auth_state_store.pending_state = "callback-state" def fake_urlopen(req, timeout):