"""Tests for the ``/location/record`` endpoint and the location DB adoption script."""

from contextlib import closing
from datetime import datetime
from pathlib import Path
import sqlite3

import pytest
from alembic import command
from alembic.config import Config
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker

import app.db
from app.main import create_app
from scripts.location_db_adopt import (
    EXPECTED_USER_VERSION,
    LOCATION_BASELINE_REVISION,
    LocationDatabaseAdoptionError,
    adopt_or_initialize_location_db,
)

# DDL for a ``location`` table created outside Alembic (no ``alembic_version``
# table), matching the columns the adoption tests treat as the valid legacy
# schema.
_LEGACY_LOCATION_TABLE_DDL = """
    CREATE TABLE location (
        person TEXT NOT NULL,
        datetime TEXT NOT NULL,
        latitude REAL NOT NULL,
        longitude REAL NOT NULL,
        altitude REAL,
        PRIMARY KEY (person, datetime)
    )
"""

# Same table without the ``altitude`` column; adoption is expected to reject it.
_LOCATION_TABLE_DDL_MISSING_ALTITUDE = """
    CREATE TABLE location (
        person TEXT NOT NULL,
        datetime TEXT NOT NULL,
        latitude REAL NOT NULL,
        longitude REAL NOT NULL,
        PRIMARY KEY (person, datetime)
    )
"""


def _make_alembic_config(database_url: str) -> Config:
    """Return an Alembic config loaded from ``alembic.ini`` targeting ``database_url``.

    ``"alembic.ini"`` is a relative path, so it is resolved against the current
    working directory (presumably the repository root when running pytest).
    """
    config = Config("alembic.ini")
    config.set_main_option("sqlalchemy.url", database_url)
    return config


def _bind_app_db(database_url: str, monkeypatch: pytest.MonkeyPatch) -> Engine:
    """Create an engine for ``database_url`` and patch ``app.db`` to use it.

    Both ``app.db.engine`` and ``app.db.SessionLocal`` are replaced for the
    duration of the test. The caller owns the returned engine and must call
    ``dispose()`` on it.
    """
    # check_same_thread=False: the TestClient may call into the app from a
    # different thread than the one that created the SQLite connection.
    engine = create_engine(database_url, connect_args={"check_same_thread": False})
    session_local = sessionmaker(bind=engine, autoflush=False, autocommit=False)
    monkeypatch.setattr(app.db, "engine", engine)
    monkeypatch.setattr(app.db, "SessionLocal", session_local)
    return engine


def _create_legacy_location_db(
    database_path: Path,
    *,
    user_version: int,
    table_ddl: str = _LEGACY_LOCATION_TABLE_DDL,
) -> None:
    """Create a SQLite file containing only ``table_ddl`` and a ``user_version``.

    This simulates a database that predates Alembic: there is no
    ``alembic_version`` table.
    """
    with closing(sqlite3.connect(database_path)) as conn:
        conn.execute(table_ddl)
        # PRAGMA values cannot be bound as parameters; int() keeps this literal.
        conn.execute(f"PRAGMA user_version = {int(user_version)}")
        conn.commit()


def _read_alembic_revision(database_path: Path) -> str:
    """Return the ``version_num`` stored in the database's ``alembic_version`` table."""
    with closing(sqlite3.connect(database_path)) as conn:
        row = conn.execute("SELECT version_num FROM alembic_version").fetchone()
    assert row is not None, "alembic_version table has no rows"
    return row[0]


@pytest.fixture
def location_client(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
    """Yield ``(client, engine)`` for an app backed by a freshly migrated temp DB.

    The database is upgraded to Alembic ``head`` before the app starts. The
    engine is disposed on teardown, including when app startup fails.
    """
    # Local import: only the HTTP tests need the FastAPI test client.
    from fastapi.testclient import TestClient

    database_url = f"sqlite:///{tmp_path / 'location_test.db'}"
    command.upgrade(_make_alembic_config(database_url), "head")
    engine = _bind_app_db(database_url, monkeypatch)
    try:
        fastapi_app = create_app()
        with TestClient(fastapi_app) as client:
            yield client, engine
    finally:
        engine.dispose()


def test_location_record_endpoint_writes_row(location_client) -> None:
    client, engine = location_client

    response = client.post(
        "/location/record",
        json={
            "person": "tianyu",
            "latitude": "1.23",
            "longitude": "4.56",
            "altitude": "7.89",
        },
    )

    assert response.status_code == 200
    assert response.text == ""

    with engine.connect() as conn:
        row = conn.execute(
            text(
                "SELECT person, datetime, latitude, longitude, altitude "
                "FROM location ORDER BY datetime DESC LIMIT 1"
            )
        ).one()

    assert row.person == "tianyu"
    assert row.latitude == pytest.approx(1.23)
    assert row.longitude == pytest.approx(4.56)
    assert row.altitude == pytest.approx(7.89)
    # The stored timestamp must parse as ISO-8601. A trailing "Z" is rewritten
    # because datetime.fromisoformat only accepts it natively from Python 3.11.
    datetime.fromisoformat(row.datetime.replace("Z", "+00:00"))


def test_location_record_endpoint_rejects_unknown_fields(location_client) -> None:
    client, _ = location_client

    response = client.post(
        "/location/record",
        json={
            "person": "tianyu",
            "latitude": "1.23",
            "longitude": "4.56",
            "extra": "not-allowed",
        },
    )

    assert response.status_code == 400


def test_location_record_endpoint_keeps_legacy_lenient_number_parsing(
    location_client,
) -> None:
    client, engine = location_client

    # Non-numeric coordinates are accepted and stored as 0.0 rather than rejected.
    response = client.post(
        "/location/record",
        json={
            "person": "tianyu",
            "latitude": "bad-lat",
            "longitude": "bad-long",
            "altitude": "bad-alt",
        },
    )

    assert response.status_code == 200

    with engine.connect() as conn:
        row = conn.execute(
            text(
                "SELECT latitude, longitude, altitude "
                "FROM location ORDER BY datetime DESC LIMIT 1"
            )
        ).one()

    assert row.latitude == pytest.approx(0.0)
    assert row.longitude == pytest.approx(0.0)
    assert row.altitude == pytest.approx(0.0)


def test_legacy_style_location_db_can_be_stamped_and_adopted(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    # Local import: only the HTTP tests need the FastAPI test client.
    from fastapi.testclient import TestClient

    database_path = tmp_path / "legacy_location.db"
    # Use the version the adoption script accepts, so this is a DB it would adopt.
    _create_legacy_location_db(database_path, user_version=EXPECTED_USER_VERSION)

    database_url = f"sqlite:///{database_path}"
    command.stamp(_make_alembic_config(database_url), LOCATION_BASELINE_REVISION)

    engine = _bind_app_db(database_url, monkeypatch)
    try:
        fastapi_app = create_app()
        with TestClient(fastapi_app) as client:
            response = client.post(
                "/location/record",
                json={
                    "person": "legacy-user",
                    "latitude": "12.3",
                    "longitude": "45.6",
                    "altitude": "7.8",
                },
            )
        assert response.status_code == 200

        with engine.connect() as db_conn:
            revision = db_conn.execute(
                text("SELECT version_num FROM alembic_version")
            ).scalar_one()
            row_count = db_conn.execute(
                text("SELECT COUNT(*) FROM location")
            ).scalar_one()

        assert revision == LOCATION_BASELINE_REVISION
        assert row_count == 1
    finally:
        # Release pooled connections even when an assertion above fails.
        engine.dispose()


def test_location_db_adoption_initializes_new_db(tmp_path: Path) -> None:
    database_path = tmp_path / "new_location.db"

    result = adopt_or_initialize_location_db(f"sqlite:///{database_path}")

    assert result == "initialized"
    assert database_path.exists()
    assert _read_alembic_revision(database_path) == LOCATION_BASELINE_REVISION
    with closing(sqlite3.connect(database_path)) as conn:
        location_table = conn.execute(
            "SELECT name FROM sqlite_master WHERE type = 'table' AND name = 'location'"
        ).fetchone()
    assert location_table is not None


def test_location_db_adoption_validates_and_stamps_legacy_db(tmp_path: Path) -> None:
    database_path = tmp_path / "legacy_location.db"
    _create_legacy_location_db(database_path, user_version=EXPECTED_USER_VERSION)

    result = adopt_or_initialize_location_db(f"sqlite:///{database_path}")

    assert result == "adopted"
    assert _read_alembic_revision(database_path) == LOCATION_BASELINE_REVISION


def test_location_db_adoption_fails_closed_on_schema_mismatch(tmp_path: Path) -> None:
    database_path = tmp_path / "bad_schema.db"
    _create_legacy_location_db(
        database_path,
        user_version=EXPECTED_USER_VERSION,
        table_ddl=_LOCATION_TABLE_DDL_MISSING_ALTITUDE,
    )

    with pytest.raises(LocationDatabaseAdoptionError, match="schema does not match"):
        adopt_or_initialize_location_db(f"sqlite:///{database_path}")


def test_location_db_adoption_fails_closed_on_user_version_mismatch(
    tmp_path: Path,
) -> None:
    database_path = tmp_path / "bad_user_version.db"
    _create_legacy_location_db(database_path, user_version=999)

    with pytest.raises(
        LocationDatabaseAdoptionError, match="Expected PRAGMA user_version"
    ):
        adopt_or_initialize_location_db(f"sqlite:///{database_path}")