diff --git a/requirements.txt b/requirements.txt index b523b95..e709024 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,6 +10,7 @@ fastapi==0.112.1 fastapi-cli==0.0.5 fastapi-mqtt==2.2.0 gmqtt==0.6.16 +gpxpy==1.6.2 greenlet==3.0.3 h11==0.14.0 httpcore==1.0.5 @@ -29,6 +30,7 @@ pydantic==2.8.2 pydantic_core==2.20.1 Pygments==2.18.0 pytest==8.3.2 +pytest-asyncio==0.24.0 python-dotenv==1.0.1 python-multipart==0.0.9 PyYAML==6.0.2 diff --git a/src/util/location_recorder.py b/src/util/location_recorder.py index 16e4889..8b82791 100644 --- a/src/util/location_recorder.py +++ b/src/util/location_recorder.py @@ -12,11 +12,11 @@ class Base(DeclarativeBase): class Location(Base): __tablename__ = "location" - person: Mapped[str] = mapped_column(type_=TEXT, primary_key=True) - datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True) - latitude: Mapped[float] = mapped_column(type_=REAL) - longitude: Mapped[float] = mapped_column(type_=REAL) - altitude: Mapped[float] = mapped_column(type_=REAL) + person: Mapped[str] = mapped_column(type_=TEXT, primary_key=True, nullable=False) + datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True, nullable=False) + latitude: Mapped[float] = mapped_column(type_=REAL, nullable=False) + longitude: Mapped[float] = mapped_column(type_=REAL, nullable=False) + altitude: Mapped[float] = mapped_column(type_=REAL, nullable=True) @dataclass @@ -27,7 +27,7 @@ class LocationData: class LocationRecorder: - USER_VERSION = 2 + USER_VERSION = 3 def __init__(self, db_path: str) -> None: self._db_path = "sqlite+aiosqlite:///" + db_path @@ -49,9 +49,11 @@ class LocationRecorder: if date_time.tzinfo != UTC: date_time = date_time.astimezone(UTC) date_time_str = date_time.strftime("%Y-%m-%dT%H:%M:%S%z") - async with self._engine.connect() as conn: + async with self._engine.begin() as conn: await conn.execute( - insert(Location).values( + insert(Location) + .prefix_with("OR IGNORE") + .values( person=person, datetime=date_time_str, latitude=location.latitude, @@ -59,8 +61,6 @@ class LocationRecorder: altitude=location.altitude, ), ) - await conn.commit() - await conn.aclose() async def insert_location_now(self, person: str, location: LocationData) -> None: now_utc = datetime.now(tz=UTC) @@ -77,9 +77,25 @@ class LocationRecorder: if user_version == 1: await self._migrate_1_2(conn=conn) user_version = (await conn.execute(text("PRAGMA user_version"))).first()[0] + if user_version == 2: # noqa: PLR2004 + await self._migrate_2_3(conn=conn) + user_version = (await conn.execute(text("PRAGMA user_version"))).first()[0] async def _migrate_1_2(self, conn: AsyncConnection) -> None: print("Location Recorder: migrate from db ver 1 to 2.") await conn.execute(text("DROP TABLE version")) await conn.execute(text("ALTER TABLE location RENAME people TO person")) await self._set_user_version(conn=conn, user_version=2) + + async def _migrate_2_3(self, conn: AsyncConnection) -> None: + print("Location Recorder: migrate from db ver 2 to 3.") + await conn.execute(text("ALTER TABLE location RENAME TO location_old")) + await conn.run_sync(Base.metadata.create_all) + await conn.execute( + text(""" + INSERT INTO location (person, datetime, latitude, longitude, altitude) + SELECT person, datetime, latitude, longitude, altitude FROM location_old; + """), + ) + await conn.execute(text("DROP TABLE location_old")) + await self._set_user_version(conn=conn, user_version=3) diff --git a/src/util/tests/test_location_recorder.py b/src/util/tests/test_location_recorder.py index 0775a18..b0af54b 100644 --- a/src/util/tests/test_location_recorder.py +++ b/src/util/tests/test_location_recorder.py @@ -14,6 +14,18 @@ DB_PATH = Path(__file__).resolve().parent / "test.db" DB_PATH_STR = str(DB_PATH) +@pytest.fixture +def _reset_event_loop() -> any: + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + loop.stop() + loop.close() + except RuntimeError: + pass + asyncio.set_event_loop(asyncio.new_event_loop()) + + @pytest.fixture def _teardown() -> any: yield @@ -65,6 +77,28 @@ def _create_v2_db() -> None: engine = create_engine(db) Base.metadata.create_all(engine) + with engine.begin() as conn: + conn.execute(text("PRAGMA user_version = 2")) + + +@pytest.fixture +def _create_latest_db() -> None: + db = "sqlite:///" + DB_PATH_STR + + class Base(DeclarativeBase): + pass + + class Location(Base): + __tablename__ = "location" + person: Mapped[str] = mapped_column(type_=TEXT, primary_key=True, nullable=False) + datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True, nullable=False) + latitude: Mapped[float] = mapped_column(type_=REAL, nullable=False) + longitude: Mapped[float] = mapped_column(type_=REAL, nullable=False) + altitude: Mapped[float] = mapped_column(type_=REAL, nullable=True) + + engine = create_engine(db) + + Base.metadata.create_all(engine) @pytest.mark.usefixtures("_create_v1_db") @@ -116,10 +150,57 @@ def test_migration_1_latest() -> None: assert table_info_location[1] == (1, "datetime", "TEXT", 1, None, 2) assert table_info_location[2] == (2, "latitude", "REAL", 1, None, 0) assert table_info_location[3] == (3, "longitude", "REAL", 1, None, 0) - assert table_info_location[4] == (4, "altitude", "REAL", 1, None, 0) + assert table_info_location[4] == (4, "altitude", "REAL", 0, None, 0) sqlite3_cursor.close() +@pytest.mark.usefixtures("_create_v2_db") +@pytest.mark.usefixtures("_teardown") +def test_migration_2_latest() -> None: + nr_tables_ver_2 = 1 + table_ver_2_0 = "location" + nr_column_ver_2_location = 5 + nr_tables_ver_3 = 1 + table_ver_3_0 = "location" + nr_column_ver_3_location = 5 + + sqlite3_db = sqlite3.connect(DB_PATH_STR) + sqlite3_cursor = sqlite3_db.cursor() + sqlite3_cursor.execute("PRAGMA user_version") + assert sqlite3_cursor.fetchone()[0] == 2 # noqa: PLR2004 + sqlite3_cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';") + tables = sqlite3_cursor.fetchall() + assert len(tables) == nr_tables_ver_2 + assert tables[0][0] == table_ver_2_0 + sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_2_0})") + table_info_location = sqlite3_cursor.fetchall() + assert len(table_info_location) == nr_column_ver_2_location + assert table_info_location[0] == (0, "person", "TEXT", 1, None, 1) + assert table_info_location[1] == (1, "datetime", "TEXT", 1, None, 2) + assert table_info_location[2] == (2, "latitude", "REAL", 1, None, 0) + assert table_info_location[3] == (3, "longitude", "REAL", 1, None, 0) + assert table_info_location[4] == (4, "altitude", "REAL", 1, None, 0) + + location_recorder = LocationRecorder(db_path=DB_PATH_STR) + asyncio.run(location_recorder.create_db_engine()) + sqlite3_cursor = sqlite3_db.cursor() + sqlite3_cursor.execute("PRAGMA user_version") + assert sqlite3_cursor.fetchone()[0] == LocationRecorder.USER_VERSION + sqlite3_cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';") + tables = sqlite3_cursor.fetchall() + assert len(tables) == nr_tables_ver_3 + sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_3_0})") + table_info_location = sqlite3_cursor.fetchall() + assert len(table_info_location) == nr_column_ver_3_location + assert table_info_location[0] == (0, "person", "TEXT", 1, None, 1) + assert table_info_location[1] == (1, "datetime", "TEXT", 1, None, 2) + assert table_info_location[2] == (2, "latitude", "REAL", 1, None, 0) + assert table_info_location[3] == (3, "longitude", "REAL", 1, None, 0) + assert table_info_location[4] == (4, "altitude", "REAL", 0, None, 0) + sqlite3_cursor.close() + + +@pytest.mark.usefixtures("_reset_event_loop") @pytest.mark.usefixtures("_teardown") def test_create_db() -> None: location_recorder = LocationRecorder(db_path=DB_PATH_STR) @@ -133,7 +214,8 @@ def test_create_db() -> None: assert sqlite3_cursor.fetchone()[0] == LocationRecorder.USER_VERSION -@pytest.mark.usefixtures("_create_v2_db") +@pytest.mark.usefixtures("_reset_event_loop") +@pytest.mark.usefixtures("_create_latest_db") @pytest.mark.usefixtures("_teardown") def test_inser_location_utc() -> None: latitude = 1.0 @@ -165,7 +247,8 @@ def test_inser_location_utc() -> None: sqlite3_cursor.close() -@pytest.mark.usefixtures("_create_v2_db") +@pytest.mark.usefixtures("_reset_event_loop") +@pytest.mark.usefixtures("_create_latest_db") @pytest.mark.usefixtures("_teardown") def test_inser_location_other() -> None: latitude = 1.0 @@ -198,7 +281,8 @@ def test_inser_location_other() -> None: sqlite3_cursor.close() -@pytest.mark.usefixtures("_create_v2_db") +@pytest.mark.usefixtures("_reset_event_loop") +@pytest.mark.usefixtures("_create_latest_db") @pytest.mark.usefixtures("_teardown") def test_insert_location_now() -> None: latitude = 1.0