# Source: home-automation/tests/test_location.py (Python, 393 lines, 11 KiB)

from datetime import datetime
from pathlib import Path
import sqlite3
import pytest
from alembic import command
from alembic.config import Config
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
import app.db as app_db
from app.main import create_app
from scripts.location_db_adopt import (
EXPECTED_USER_VERSION,
LOCATION_BASELINE_REVISION,
LocationDatabaseAdoptionError,
adopt_or_initialize_location_db,
)
from tests.conftest import _make_app_alembic_config, _make_poo_alembic_config
def _make_alembic_config(database_url: str) -> Config:
    """Return an Alembic ``Config`` for the location database.

    Loads ``alembic_location.ini`` and overrides its ``sqlalchemy.url`` with
    *database_url*. The ini path is relative, so it is presumably resolved
    against the current working directory (tests run from the project root).
    """
    location_cfg = Config("alembic_location.ini")
    location_cfg.set_main_option("sqlalchemy.url", database_url)
    return location_cfg
def test_location_record_endpoint_writes_row(location_client) -> None:
    """POST /location/record persists the submitted fix and answers an empty 200."""
    client, engine = location_client
    payload = {
        "person": "tianyu",
        "latitude": "1.23",
        "longitude": "4.56",
        "altitude": "7.89",
    }

    response = client.post("/location/record", json=payload)

    assert response.status_code == 200
    assert response.text == ""

    latest_row_query = text(
        "SELECT person, datetime, latitude, longitude, altitude "
        "FROM location ORDER BY datetime DESC LIMIT 1"
    )
    with engine.connect() as conn:
        stored = conn.execute(latest_row_query).one()

    assert stored.person == "tianyu"
    assert stored.latitude == pytest.approx(1.23)
    assert stored.longitude == pytest.approx(4.56)
    assert stored.altitude == pytest.approx(7.89)
    # The timestamp must parse as ISO-8601. A trailing "Z" is rewritten to
    # "+00:00", presumably so older fromisoformat versions accept it.
    datetime.fromisoformat(stored.datetime.replace("Z", "+00:00"))
def test_location_record_endpoint_rejects_unknown_fields(location_client) -> None:
    """An unexpected key yields a generic 400 that reveals no validation detail."""
    client, _ = location_client
    body = {
        "person": "tianyu",
        "latitude": "1.23",
        "longitude": "4.56",
        "extra": "not-allowed",
    }

    response = client.post("/location/record", json=body)

    assert response.status_code == 400
    assert response.text == "bad request"
    # Neither the offending field name nor validation internals may be echoed.
    for leaked in ("extra", "ValidationError"):
        assert leaked not in response.text
def test_location_record_endpoint_rejects_missing_latitude(location_client) -> None:
    """Omitting latitude gives a generic 400 that does not name the field."""
    client, _ = location_client
    response = client.post(
        "/location/record",
        json=dict(person="tianyu", longitude="4.56"),
    )
    assert response.status_code == 400
    assert response.text == "bad request"
    assert "latitude" not in response.text
def test_location_record_endpoint_rejects_missing_longitude(location_client) -> None:
    """Omitting longitude gives a generic 400 that does not name the field."""
    client, _ = location_client
    response = client.post(
        "/location/record",
        json=dict(person="tianyu", latitude="1.23"),
    )
    assert response.status_code == 400
    assert response.text == "bad request"
    assert "longitude" not in response.text
def test_location_record_endpoint_rejects_invalid_latitude(location_client) -> None:
    """A non-numeric latitude gives a generic 400 without echoing the input."""
    client, _ = location_client
    bad_value = "bad-lat"
    response = client.post(
        "/location/record",
        json={"person": "tianyu", "latitude": bad_value, "longitude": "4.56"},
    )
    assert response.status_code == 400
    assert response.text == "bad request"
    # Neither the rejected value nor the field name should leak into the body.
    assert bad_value not in response.text
    assert "latitude" not in response.text
def test_location_record_endpoint_rejects_invalid_longitude(location_client) -> None:
    """A non-numeric longitude gives a generic 400 without echoing the input."""
    client, _ = location_client
    bad_value = "bad-long"
    response = client.post(
        "/location/record",
        json={"person": "tianyu", "latitude": "1.23", "longitude": bad_value},
    )
    assert response.status_code == 400
    assert response.text == "bad request"
    # Neither the rejected value nor the field name should leak into the body.
    assert bad_value not in response.text
    assert "longitude" not in response.text
def test_location_record_endpoint_defaults_missing_altitude_to_zero(location_client) -> None:
    """A record sent without altitude is accepted and stored with altitude 0.0."""
    client, engine = location_client
    response = client.post(
        "/location/record",
        json={"person": "tianyu", "latitude": "1.23", "longitude": "4.56"},
    )
    assert response.status_code == 200

    latest_row_query = text(
        "SELECT latitude, longitude, altitude "
        "FROM location ORDER BY datetime DESC LIMIT 1"
    )
    with engine.connect() as conn:
        latest = conn.execute(latest_row_query).one()

    expected_values = {"latitude": 1.23, "longitude": 4.56, "altitude": 0.0}
    for column, value in expected_values.items():
        assert getattr(latest, column) == pytest.approx(value)
def test_location_record_endpoint_defaults_invalid_altitude_to_zero(location_client) -> None:
    """A non-numeric altitude does not reject the record; it is stored as 0.0."""
    client, engine = location_client
    payload = {
        "person": "tianyu",
        "latitude": "1.23",
        "longitude": "4.56",
        "altitude": "bad-alt",
    }
    response = client.post("/location/record", json=payload)
    assert response.status_code == 200

    latest_row_query = text(
        "SELECT latitude, longitude, altitude "
        "FROM location ORDER BY datetime DESC LIMIT 1"
    )
    with engine.connect() as conn:
        latest = conn.execute(latest_row_query).one()

    assert latest.latitude == pytest.approx(1.23)
    assert latest.longitude == pytest.approx(4.56)
    # The unparseable altitude falls back to zero rather than failing the request.
    assert latest.altitude == pytest.approx(0.0)
def test_legacy_style_location_db_can_be_stamped_and_adopted(
    test_database_urls, monkeypatch: pytest.MonkeyPatch
) -> None:
    """A hand-built (pre-Alembic) location DB stamped at baseline works with the app.

    Creates a ``location`` table with no ``alembic_version`` table, stamps it
    at ``LOCATION_BASELINE_REVISION`` without running migrations, and points
    the app at it. It then checks that a record can be written and that the
    stamped revision is unchanged afterwards.
    """
    app_database_url = test_database_urls["app_url"]
    database_path = test_database_urls["location_path"]
    database_url = test_database_urls["location_url"]
    poo_database_url = test_database_urls["poo_url"]

    # Build the legacy schema by hand. Close the connection even if the DDL
    # fails so the database file is not left open.
    conn = sqlite3.connect(database_path)
    try:
        conn.execute(
            """
            CREATE TABLE location (
                person TEXT NOT NULL,
                datetime TEXT NOT NULL,
                latitude REAL NOT NULL,
                longitude REAL NOT NULL,
                altitude REAL,
                PRIMARY KEY (person, datetime)
            )
            """
        )
        conn.execute("PRAGMA user_version = 2")
        conn.commit()
    finally:
        conn.close()

    # The app and poo databases are migrated normally; only the location DB
    # is stamped (marked as baseline) instead of being migrated.
    command.upgrade(_make_app_alembic_config(app_database_url), "head")
    command.stamp(_make_alembic_config(database_url), LOCATION_BASELINE_REVISION)
    command.upgrade(_make_poo_alembic_config(poo_database_url), "head")

    engine = create_engine(database_url, connect_args={"check_same_thread": False})
    try:
        session_local = sessionmaker(bind=engine, autoflush=False, autocommit=False)
        # Route the app's DB access to the legacy database; monkeypatch
        # restores the originals at teardown.
        monkeypatch.setattr(app_db, "engine", engine)
        monkeypatch.setattr(app_db, "SessionLocal", session_local)

        from fastapi.testclient import TestClient

        fastapi_app = create_app()
        with TestClient(fastapi_app) as client:
            response = client.post(
                "/location/record",
                json={
                    "person": "legacy-user",
                    "latitude": "12.3",
                    "longitude": "45.6",
                    "altitude": "7.8",
                },
            )
            assert response.status_code == 200

        with engine.connect() as db_conn:
            revision = db_conn.execute(
                text("SELECT version_num FROM alembic_version")
            ).scalar_one()
            row_count = db_conn.execute(text("SELECT COUNT(*) FROM location")).scalar_one()

        assert revision == LOCATION_BASELINE_REVISION
        assert row_count == 1
    finally:
        # Release pooled connections even when an assertion above fails, so
        # the SQLite file is not held open past the test.
        engine.dispose()
def test_location_db_adoption_initializes_new_db(tmp_path: Path) -> None:
    """Adopting a non-existent DB creates it, builds ``location`` and stamps baseline."""
    db_file = tmp_path / "new_location.db"

    outcome = adopt_or_initialize_location_db(f"sqlite:///{db_file}")

    assert outcome == "initialized"
    assert db_file.exists()

    conn = sqlite3.connect(db_file)
    try:
        (revision,) = conn.execute("SELECT version_num FROM alembic_version").fetchone()
        table_row = conn.execute(
            "SELECT name FROM sqlite_master WHERE type = 'table' AND name = 'location'"
        ).fetchone()
    finally:
        conn.close()

    assert revision == LOCATION_BASELINE_REVISION
    assert table_row is not None
def test_location_db_adoption_validates_and_stamps_legacy_db(tmp_path: Path) -> None:
    """An unmanaged DB with the expected schema and user_version is adopted.

    The database has a ``location`` table but no ``alembic_version`` table.
    Adoption must report ``"adopted"`` and leave it stamped at
    ``LOCATION_BASELINE_REVISION``.
    """
    database_path = tmp_path / "legacy_location.db"
    # Build the legacy schema by hand. Close the connection even if the DDL
    # fails, matching the try/finally used for the read-back below.
    conn = sqlite3.connect(database_path)
    try:
        conn.execute(
            """
            CREATE TABLE location (
                person TEXT NOT NULL,
                datetime TEXT NOT NULL,
                latitude REAL NOT NULL,
                longitude REAL NOT NULL,
                altitude REAL,
                PRIMARY KEY (person, datetime)
            )
            """
        )
        conn.execute(f"PRAGMA user_version = {EXPECTED_USER_VERSION}")
        conn.commit()
    finally:
        conn.close()

    result = adopt_or_initialize_location_db(f"sqlite:///{database_path}")
    assert result == "adopted"

    conn = sqlite3.connect(database_path)
    try:
        revision = conn.execute("SELECT version_num FROM alembic_version").fetchone()[0]
    finally:
        conn.close()
    assert revision == LOCATION_BASELINE_REVISION
def test_location_db_adoption_accepts_already_managed_matching_revision(
    tmp_path: Path,
) -> None:
    """A DB already migrated to ``head`` by Alembic is reported as already managed."""
    managed_url = f"sqlite:///{tmp_path / 'managed_location.db'}"
    command.upgrade(_make_alembic_config(managed_url), "head")
    assert adopt_or_initialize_location_db(managed_url) == "already_managed"
def test_location_db_adoption_fails_closed_on_alembic_revision_mismatch(
    tmp_path: Path,
) -> None:
    """A DB stamped with an unrecognised Alembic revision is refused.

    The schema and user_version match the adoptable legacy layout, so the
    expected error message pins the failure to the revision check.
    """
    database_path = tmp_path / "wrong_revision.db"
    # Close the setup connection even if any statement fails.
    conn = sqlite3.connect(database_path)
    try:
        conn.execute(
            """
            CREATE TABLE location (
                person TEXT NOT NULL,
                datetime TEXT NOT NULL,
                latitude REAL NOT NULL,
                longitude REAL NOT NULL,
                altitude REAL,
                PRIMARY KEY (person, datetime)
            )
            """
        )
        # Hand-stamp a revision id that does not exist in the migration history.
        conn.execute("CREATE TABLE alembic_version (version_num VARCHAR(32) NOT NULL)")
        conn.execute("INSERT INTO alembic_version (version_num) VALUES ('wrong_revision')")
        conn.execute(f"PRAGMA user_version = {EXPECTED_USER_VERSION}")
        conn.commit()
    finally:
        conn.close()

    with pytest.raises(LocationDatabaseAdoptionError, match="revision does not match"):
        adopt_or_initialize_location_db(f"sqlite:///{database_path}")
def test_location_db_adoption_fails_closed_on_schema_mismatch(tmp_path: Path) -> None:
    """An unmanaged DB whose ``location`` table differs from the expected schema is refused.

    The table deliberately omits the ``altitude`` column. user_version is set
    to the expected value so that only the schema check can reject it.
    """
    database_path = tmp_path / "bad_schema.db"
    # Close the setup connection even if any statement fails.
    conn = sqlite3.connect(database_path)
    try:
        conn.execute(
            """
            CREATE TABLE location (
                person TEXT NOT NULL,
                datetime TEXT NOT NULL,
                latitude REAL NOT NULL,
                longitude REAL NOT NULL,
                PRIMARY KEY (person, datetime)
            )
            """
        )
        conn.execute(f"PRAGMA user_version = {EXPECTED_USER_VERSION}")
        conn.commit()
    finally:
        conn.close()

    with pytest.raises(LocationDatabaseAdoptionError, match="schema does not match"):
        adopt_or_initialize_location_db(f"sqlite:///{database_path}")
def test_location_db_adoption_fails_closed_on_user_version_mismatch(tmp_path: Path) -> None:
    """An unmanaged DB with the right schema but the wrong user_version is refused.

    999 is presumably never a valid ``EXPECTED_USER_VERSION``. Confirm if that
    constant changes.
    """
    database_path = tmp_path / "bad_user_version.db"
    # Close the setup connection even if any statement fails.
    conn = sqlite3.connect(database_path)
    try:
        conn.execute(
            """
            CREATE TABLE location (
                person TEXT NOT NULL,
                datetime TEXT NOT NULL,
                latitude REAL NOT NULL,
                longitude REAL NOT NULL,
                altitude REAL,
                PRIMARY KEY (person, datetime)
            )
            """
        )
        conn.execute("PRAGMA user_version = 999")
        conn.commit()
    finally:
        conn.close()

    with pytest.raises(LocationDatabaseAdoptionError, match="Expected PRAGMA user_version"):
        adopt_or_initialize_location_db(f"sqlite:///{database_path}")