"""Tests for the ``/location/record`` endpoint and location database adoption."""

from datetime import datetime
from pathlib import Path
import sqlite3

import pytest
from alembic import command
from alembic.config import Config
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

import app.db as app_db
from app.main import create_app
from scripts.location_db_adopt import (
    EXPECTED_USER_VERSION,
    LOCATION_BASELINE_REVISION,
    LocationDatabaseAdoptionError,
    adopt_or_initialize_location_db,
)
from tests.conftest import _make_poo_alembic_config


# Hand-written "legacy" location table, i.e. one created directly through
# sqlite3 rather than by an Alembic migration.
_LEGACY_LOCATION_DDL = """
    CREATE TABLE location (
        person TEXT NOT NULL,
        datetime TEXT NOT NULL,
        latitude REAL NOT NULL,
        longitude REAL NOT NULL,
        altitude REAL,
        PRIMARY KEY (person, datetime)
    )
"""

# Same table without the ``altitude`` column; used to provoke the
# schema-mismatch failure.
_LOCATION_DDL_WITHOUT_ALTITUDE = """
    CREATE TABLE location (
        person TEXT NOT NULL,
        datetime TEXT NOT NULL,
        latitude REAL NOT NULL,
        longitude REAL NOT NULL,
        PRIMARY KEY (person, datetime)
    )
"""


def _make_alembic_config(database_url: str) -> Config:
    """Return an Alembic config for ``alembic_location.ini`` bound to ``database_url``."""
    config = Config("alembic_location.ini")
    config.set_main_option("sqlalchemy.url", database_url)
    return config


def _run_sqlite_statements(database_path, *statements: str) -> None:
    """Execute ``statements`` in order against a SQLite file, then commit.

    ``database_path`` is passed straight to ``sqlite3.connect``. The
    connection is closed in a ``finally`` block so that a failing statement
    cannot leave the database file open for the remainder of the test.
    """
    conn = sqlite3.connect(database_path)
    try:
        for statement in statements:
            conn.execute(statement)
        conn.commit()
    finally:
        conn.close()


def test_location_record_endpoint_writes_row(location_client) -> None:
    """A valid payload returns an empty 200 response and stores one row."""
    client, engine = location_client

    response = client.post(
        "/location/record",
        json={
            "person": "tianyu",
            "latitude": "1.23",
            "longitude": "4.56",
            "altitude": "7.89",
        },
    )

    assert response.status_code == 200
    assert response.text == ""

    with engine.connect() as conn:
        row = conn.execute(
            text(
                "SELECT person, datetime, latitude, longitude, altitude "
                "FROM location ORDER BY datetime DESC LIMIT 1"
            )
        ).one()

    assert row.person == "tianyu"
    assert row.latitude == pytest.approx(1.23)
    assert row.longitude == pytest.approx(4.56)
    assert row.altitude == pytest.approx(7.89)
    # Raises ValueError if the stored timestamp is not ISO 8601; a trailing
    # "Z" is rewritten because older ``fromisoformat`` versions reject it.
    datetime.fromisoformat(row.datetime.replace("Z", "+00:00"))


def test_location_record_endpoint_rejects_unknown_fields(location_client) -> None:
    """An unexpected field yields a generic 400 that leaks no validation detail."""
    client, _ = location_client

    response = client.post(
        "/location/record",
        json={
            "person": "tianyu",
            "latitude": "1.23",
            "longitude": "4.56",
            "extra": "not-allowed",
        },
    )

    assert response.status_code == 400
    assert response.text == "bad request"
    assert "extra" not in response.text
    assert "ValidationError" not in response.text


def test_location_record_endpoint_rejects_missing_latitude(location_client) -> None:
    """Omitting ``latitude`` yields a generic 400 that does not name the field."""
    client, _ = location_client

    response = client.post(
        "/location/record",
        json={
            "person": "tianyu",
            "longitude": "4.56",
        },
    )

    assert response.status_code == 400
    assert response.text == "bad request"
    assert "latitude" not in response.text


def test_location_record_endpoint_rejects_missing_longitude(location_client) -> None:
    """Omitting ``longitude`` yields a generic 400 that does not name the field."""
    client, _ = location_client

    response = client.post(
        "/location/record",
        json={
            "person": "tianyu",
            "latitude": "1.23",
        },
    )

    assert response.status_code == 400
    assert response.text == "bad request"
    assert "longitude" not in response.text


def test_location_record_endpoint_rejects_invalid_latitude(location_client) -> None:
    """A non-numeric ``latitude`` yields a generic 400 that echoes nothing back."""
    client, _ = location_client

    response = client.post(
        "/location/record",
        json={
            "person": "tianyu",
            "latitude": "bad-lat",
            "longitude": "4.56",
        },
    )

    assert response.status_code == 400
    assert response.text == "bad request"
    assert "bad-lat" not in response.text
    assert "latitude" not in response.text


def test_location_record_endpoint_rejects_invalid_longitude(location_client) -> None:
    """A non-numeric ``longitude`` yields a generic 400 that echoes nothing back."""
    client, _ = location_client

    response = client.post(
        "/location/record",
        json={
            "person": "tianyu",
            "latitude": "1.23",
            "longitude": "bad-long",
        },
    )

    assert response.status_code == 400
    assert response.text == "bad request"
    assert "bad-long" not in response.text
    assert "longitude" not in response.text


def test_location_record_endpoint_defaults_missing_altitude_to_zero(location_client) -> None:
    """A payload without ``altitude`` is accepted and stored with altitude 0.0."""
    client, engine = location_client

    response = client.post(
        "/location/record",
        json={
            "person": "tianyu",
            "latitude": "1.23",
            "longitude": "4.56",
        },
    )

    assert response.status_code == 200

    with engine.connect() as conn:
        row = conn.execute(
            text(
                "SELECT latitude, longitude, altitude "
                "FROM location ORDER BY datetime DESC LIMIT 1"
            )
        ).one()

    assert row.latitude == pytest.approx(1.23)
    assert row.longitude == pytest.approx(4.56)
    assert row.altitude == pytest.approx(0.0)


def test_location_record_endpoint_defaults_invalid_altitude_to_zero(location_client) -> None:
    """A non-numeric ``altitude`` is accepted and stored as altitude 0.0."""
    client, engine = location_client

    response = client.post(
        "/location/record",
        json={
            "person": "tianyu",
            "latitude": "1.23",
            "longitude": "4.56",
            "altitude": "bad-alt",
        },
    )

    assert response.status_code == 200

    with engine.connect() as conn:
        row = conn.execute(
            text(
                "SELECT latitude, longitude, altitude "
                "FROM location ORDER BY datetime DESC LIMIT 1"
            )
        ).one()

    assert row.latitude == pytest.approx(1.23)
    assert row.longitude == pytest.approx(4.56)
    assert row.altitude == pytest.approx(0.0)


def test_legacy_style_location_db_can_be_stamped_and_adopted(
    test_database_urls, monkeypatch: pytest.MonkeyPatch
) -> None:
    """A hand-made location DB that is stamped with Alembic still serves writes."""
    database_path = test_database_urls["location_path"]
    database_url = test_database_urls["location_url"]
    poo_database_url = test_database_urls["poo_url"]

    # NOTE(review): the literal 2 presumably mirrors EXPECTED_USER_VERSION;
    # it is kept hard-coded as in the original test - confirm.
    _run_sqlite_statements(
        database_path,
        _LEGACY_LOCATION_DDL,
        "PRAGMA user_version = 2",
    )

    command.stamp(_make_alembic_config(database_url), LOCATION_BASELINE_REVISION)
    command.upgrade(_make_poo_alembic_config(poo_database_url), "head")

    engine = create_engine(database_url, connect_args={"check_same_thread": False})
    try:
        session_local = sessionmaker(bind=engine, autoflush=False, autocommit=False)
        monkeypatch.setattr(app_db, "engine", engine)
        monkeypatch.setattr(app_db, "SessionLocal", session_local)

        from fastapi.testclient import TestClient

        fastapi_app = create_app()
        with TestClient(fastapi_app) as client:
            response = client.post(
                "/location/record",
                json={
                    "person": "legacy-user",
                    "latitude": "12.3",
                    "longitude": "45.6",
                    "altitude": "7.8",
                },
            )

        assert response.status_code == 200

        with engine.connect() as db_conn:
            revision = db_conn.execute(
                text("SELECT version_num FROM alembic_version")
            ).scalar_one()
            row_count = db_conn.execute(
                text("SELECT COUNT(*) FROM location")
            ).scalar_one()

        assert revision == LOCATION_BASELINE_REVISION
        assert row_count == 1
    finally:
        # Release pooled connections even when an assertion above fails.
        engine.dispose()


def test_location_db_adoption_initializes_new_db(tmp_path: Path) -> None:
    """Adopting a non-existent DB creates it at the baseline revision."""
    database_path = tmp_path / "new_location.db"

    result = adopt_or_initialize_location_db(f"sqlite:///{database_path}")

    assert result == "initialized"
    assert database_path.exists()

    conn = sqlite3.connect(database_path)
    try:
        revision = conn.execute("SELECT version_num FROM alembic_version").fetchone()[0]
        location_table = conn.execute(
            "SELECT name FROM sqlite_master WHERE type = 'table' AND name = 'location'"
        ).fetchone()
    finally:
        conn.close()

    assert revision == LOCATION_BASELINE_REVISION
    assert location_table is not None


def test_location_db_adoption_validates_and_stamps_legacy_db(tmp_path: Path) -> None:
    """A legacy DB with the expected schema and user_version gets stamped."""
    database_path = tmp_path / "legacy_location.db"
    _run_sqlite_statements(
        database_path,
        _LEGACY_LOCATION_DDL,
        f"PRAGMA user_version = {EXPECTED_USER_VERSION}",
    )

    result = adopt_or_initialize_location_db(f"sqlite:///{database_path}")

    assert result == "adopted"

    conn = sqlite3.connect(database_path)
    try:
        revision = conn.execute("SELECT version_num FROM alembic_version").fetchone()[0]
    finally:
        conn.close()

    assert revision == LOCATION_BASELINE_REVISION


def test_location_db_adoption_accepts_already_managed_matching_revision(
    tmp_path: Path,
) -> None:
    """A DB already upgraded to head is reported as ``already_managed``."""
    database_path = tmp_path / "managed_location.db"
    command.upgrade(_make_alembic_config(f"sqlite:///{database_path}"), "head")

    result = adopt_or_initialize_location_db(f"sqlite:///{database_path}")

    assert result == "already_managed"


def test_location_db_adoption_fails_closed_on_alembic_revision_mismatch(
    tmp_path: Path,
) -> None:
    """An ``alembic_version`` row holding an unexpected revision aborts adoption."""
    database_path = tmp_path / "wrong_revision.db"
    _run_sqlite_statements(
        database_path,
        _LEGACY_LOCATION_DDL,
        "CREATE TABLE alembic_version (version_num VARCHAR(32) NOT NULL)",
        "INSERT INTO alembic_version (version_num) VALUES ('wrong_revision')",
        f"PRAGMA user_version = {EXPECTED_USER_VERSION}",
    )

    with pytest.raises(LocationDatabaseAdoptionError, match="revision does not match"):
        adopt_or_initialize_location_db(f"sqlite:///{database_path}")


def test_location_db_adoption_fails_closed_on_schema_mismatch(tmp_path: Path) -> None:
    """A location table lacking the ``altitude`` column aborts adoption."""
    database_path = tmp_path / "bad_schema.db"
    _run_sqlite_statements(
        database_path,
        _LOCATION_DDL_WITHOUT_ALTITUDE,
        f"PRAGMA user_version = {EXPECTED_USER_VERSION}",
    )

    with pytest.raises(LocationDatabaseAdoptionError, match="schema does not match"):
        adopt_or_initialize_location_db(f"sqlite:///{database_path}")


def test_location_db_adoption_fails_closed_on_user_version_mismatch(tmp_path: Path) -> None:
    """A ``PRAGMA user_version`` of 999 aborts adoption."""
    database_path = tmp_path / "bad_user_version.db"
    _run_sqlite_statements(
        database_path,
        _LEGACY_LOCATION_DDL,
        "PRAGMA user_version = 999",
    )

    with pytest.raises(LocationDatabaseAdoptionError, match="Expected PRAGMA user_version"):
        adopt_or_initialize_location_db(f"sqlite:///{database_path}")