Files
home-automation/tests/test_location.py
T

173 lines
4.9 KiB
Python
Raw Normal View History

from datetime import datetime
from pathlib import Path
import sqlite3
import pytest
from alembic import command
from alembic.config import Config
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
import app.db
from app.main import create_app
# Alembic revision id of the location baseline migration. The legacy-adoption
# test stamps a hand-built, pre-Alembic database at this revision.
LOCATION_BASELINE_REVISION = "20260419_01_location_baseline"


def _make_alembic_config(database_url: str) -> Config:
    """Return an Alembic ``Config`` loaded from ``alembic.ini`` but targeting *database_url*.

    ``alembic.ini`` is opened by relative path, i.e. relative to the current
    working directory, so these tests expect to be run from the project root.
    """
    alembic_cfg = Config("alembic.ini")
    # Override whatever URL alembic.ini declares so migrations hit the test DB.
    alembic_cfg.set_main_option("sqlalchemy.url", database_url)
    return alembic_cfg
@pytest.fixture
def location_client(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
    """Yield a ``(TestClient, Engine)`` pair backed by a fresh, fully migrated SQLite DB.

    The database file is created under *tmp_path* and upgraded to Alembic
    ``head`` before the application is built. ``app.db.engine`` and
    ``app.db.SessionLocal`` are monkeypatched to point at it, and the engine is
    also yielded so tests can query the stored rows directly.
    """
    from fastapi.testclient import TestClient

    database_path = tmp_path / "location_test.db"
    database_url = f"sqlite:///{database_path}"
    command.upgrade(_make_alembic_config(database_url), "head")

    # check_same_thread=False: the app's sessions may be used from a different
    # thread than the one that created the engine (e.g. TestClient's portal thread).
    engine = create_engine(database_url, connect_args={"check_same_thread": False})
    # try/finally so pooled connections are released even if app construction or
    # TestClient startup/shutdown raises; previously only the happy path reached
    # engine.dispose().
    try:
        session_local = sessionmaker(bind=engine, autoflush=False, autocommit=False)
        # NOTE(review): this only takes effect for code that looks these names up
        # on the app.db module at call time; a module that did
        # `from app.db import SessionLocal` at import would keep the old object.
        monkeypatch.setattr(app.db, "engine", engine)
        monkeypatch.setattr(app.db, "SessionLocal", session_local)
        fastapi_app = create_app()
        with TestClient(fastapi_app) as client:
            yield client, engine
    finally:
        engine.dispose()
def test_location_record_endpoint_writes_row(location_client) -> None:
    """POST /location/record stores the person, float coordinates and an ISO-8601 timestamp."""
    client, engine = location_client
    payload = {
        "person": "tianyu",
        "latitude": "1.23",
        "longitude": "4.56",
        "altitude": "7.89",
    }

    response = client.post("/location/record", json=payload)

    assert response.status_code == 200
    assert response.text == ""

    latest_row_sql = text(
        "SELECT person, datetime, latitude, longitude, altitude "
        "FROM location ORDER BY datetime DESC LIMIT 1"
    )
    with engine.connect() as conn:
        stored = conn.execute(latest_row_sql).one()

    assert stored.person == "tianyu"
    # Coordinates are sent as strings but must come back as numbers.
    for column, expected in (("latitude", 1.23), ("longitude", 4.56), ("altitude", 7.89)):
        assert getattr(stored, column) == pytest.approx(expected)
    # The timestamp must parse as ISO-8601. A trailing "Z" is rewritten first
    # because datetime.fromisoformat() only accepts it from Python 3.11 on.
    datetime.fromisoformat(stored.datetime.replace("Z", "+00:00"))
def test_location_record_endpoint_rejects_unknown_fields(location_client) -> None:
    """A payload containing a field the endpoint does not declare is answered with HTTP 400."""
    client = location_client[0]
    body = dict(
        person="tianyu",
        latitude="1.23",
        longitude="4.56",
        extra="not-allowed",
    )
    # NOTE(review): "altitude" is omitted here. If it were a required field, the
    # 400 could come from the missing key rather than "extra" — confirm the intent.
    assert client.post("/location/record", json=body).status_code == 400
def test_location_record_endpoint_keeps_legacy_lenient_number_parsing(location_client) -> None:
    """Unparseable coordinate strings are still accepted and stored as 0.0 (legacy behaviour)."""
    client, engine = location_client
    unparseable = {
        "latitude": "bad-lat",
        "longitude": "bad-long",
        "altitude": "bad-alt",
    }

    response = client.post("/location/record", json={"person": "tianyu", **unparseable})
    assert response.status_code == 200

    with engine.connect() as conn:
        latest = conn.execute(
            text(
                "SELECT latitude, longitude, altitude "
                "FROM location ORDER BY datetime DESC LIMIT 1"
            )
        ).one()

    # Every garbage value must have been coerced to zero rather than rejected.
    for value in (latest.latitude, latest.longitude, latest.altitude):
        assert value == pytest.approx(0.0)
def test_legacy_style_location_db_can_be_stamped_and_adopted(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """A pre-Alembic location DB can be stamped at the baseline and then used by the app.

    Steps:
    1. Build the legacy ``location`` table by hand and tag it with
       ``PRAGMA user_version = 2``.
    2. Stamp it at ``LOCATION_BASELINE_REVISION`` without running any migration.
    3. Post one location through the app.

    Checks that exactly one row was written and that the recorded Alembic
    revision is still the baseline afterwards.
    """
    from fastapi.testclient import TestClient

    database_path = tmp_path / "legacy_location.db"
    conn = sqlite3.connect(database_path)
    # sqlite3's connection context manager only commits/rolls back and does not
    # close, so close explicitly; the handle is then released even if a
    # statement fails.
    try:
        conn.execute(
            """
            CREATE TABLE location (
                person TEXT NOT NULL,
                datetime TEXT NOT NULL,
                latitude REAL NOT NULL,
                longitude REAL NOT NULL,
                altitude REAL,
                PRIMARY KEY (person, datetime)
            )
            """
        )
        conn.execute("PRAGMA user_version = 2")
        conn.commit()
    finally:
        conn.close()

    database_url = f"sqlite:///{database_path}"
    # Record the baseline revision in alembic_version without executing any
    # migration: the hand-built table stands in for the baseline schema.
    command.stamp(_make_alembic_config(database_url), LOCATION_BASELINE_REVISION)

    engine = create_engine(database_url, connect_args={"check_same_thread": False})
    # This is a plain test, not a fixture, so nothing else tears the engine down.
    # Dispose it in finally so a failing assertion does not leak the SQLite
    # file handle.
    try:
        session_local = sessionmaker(bind=engine, autoflush=False, autocommit=False)
        monkeypatch.setattr(app.db, "engine", engine)
        monkeypatch.setattr(app.db, "SessionLocal", session_local)
        fastapi_app = create_app()
        with TestClient(fastapi_app) as client:
            response = client.post(
                "/location/record",
                json={
                    "person": "legacy-user",
                    "latitude": "12.3",
                    "longitude": "45.6",
                    "altitude": "7.8",
                },
            )
            assert response.status_code == 200

        with engine.connect() as db_conn:
            revision = db_conn.execute(text("SELECT version_num FROM alembic_version")).scalar_one()
            row_count = db_conn.execute(text("SELECT COUNT(*) FROM location")).scalar_one()
        # The app must adopt the stamped DB as-is: revision unchanged, one row written.
        assert revision == LOCATION_BASELINE_REVISION
        assert row_count == 1
    finally:
        engine.dispose()