diff --git a/src/helper/location_recorder/__init__.py b/src/helper/location_recorder/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/util/location_recorder.py b/src/util/location_recorder.py index 5364769..16e4889 100644 --- a/src/util/location_recorder.py +++ b/src/util/location_recorder.py @@ -1,5 +1,5 @@ -import datetime from dataclasses import dataclass +from datetime import UTC, datetime from sqlalchemy import REAL, TEXT, insert, text from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine @@ -45,12 +45,15 @@ class LocationRecorder: async def dispose_db_engine(self) -> None: await self._engine.dispose() - async def insert_location(self, person: str, datetime: str, location: LocationData) -> None: + async def insert_location(self, person: str, date_time: datetime, location: LocationData) -> None: + if date_time.tzinfo != UTC: + date_time = date_time.astimezone(UTC) + date_time_str = date_time.strftime("%Y-%m-%dT%H:%M:%S%z") async with self._engine.connect() as conn: await conn.execute( insert(Location).values( person=person, - datetime=datetime, + datetime=date_time_str, latitude=location.latitude, longitude=location.longitude, altitude=location.altitude, @@ -60,9 +63,8 @@ class LocationRecorder: await conn.aclose() async def insert_location_now(self, person: str, location: LocationData) -> None: - now = datetime.datetime.now(tz=datetime.UTC) - now_str = now.strftime("%Y-%m-%dT%H:%M:%S%z") - await self.insert_location(person, now_str, location) + now_utc = datetime.now(tz=UTC) + await self.insert_location(person, now_utc, location) async def _get_user_version(self, conn: AsyncConnection) -> int: return (await conn.execute(text("PRAGMA user_version"))).first()[0] diff --git a/src/util/tests/test_location_recorder.py b/src/util/tests/test_location_recorder.py index f3cd8f6..0775a18 100644 --- a/src/util/tests/test_location_recorder.py +++ b/src/util/tests/test_location_recorder.py @@ -1,12 +1,14 @@ import asyncio import sqlite3 +from datetime import UTC, datetime from pathlib import Path +from zoneinfo import ZoneInfo import pytest from sqlalchemy import INTEGER, REAL, TEXT, create_engine, text from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column -from src.util.location_recorder import LocationRecorder +from src.util.location_recorder import LocationData, LocationRecorder DB_PATH = Path(__file__).resolve().parent / "test.db" DB_PATH_STR = str(DB_PATH) @@ -121,10 +123,108 @@ def test_migration_1_latest() -> None: @pytest.mark.usefixtures("_teardown") def test_create_db() -> None: location_recorder = LocationRecorder(db_path=DB_PATH_STR) - asyncio.run(location_recorder.create_db_engine()) - location_recorder.dispose_db_engine() + event_loop = asyncio.get_event_loop() + event_loop.run_until_complete(location_recorder.create_db_engine()) + event_loop.run_until_complete(location_recorder.dispose_db_engine()) assert DB_PATH.exists() sqlite3_db = sqlite3.connect(DB_PATH_STR) sqlite3_cursor = sqlite3_db.cursor() sqlite3_cursor.execute("PRAGMA user_version") assert sqlite3_cursor.fetchone()[0] == LocationRecorder.USER_VERSION + + +@pytest.mark.usefixtures("_create_v2_db") +@pytest.mark.usefixtures("_teardown") +def test_inser_location_utc() -> None: + latitude = 1.0 + longitude = 2.0 + altitude = 3.0 + person = "test_person" + date_time = datetime.now(tz=UTC) + location_recorder = LocationRecorder(db_path=DB_PATH_STR) + event_loop = asyncio.get_event_loop() + event_loop.run_until_complete(location_recorder.create_db_engine()) + location_data = LocationData(latitude=latitude, longitude=longitude, altitude=altitude) + event_loop.run_until_complete( + location_recorder.insert_location( + person=person, + date_time=date_time, + location=location_data, + ), + ) + event_loop.run_until_complete(location_recorder.dispose_db_engine()) + sqlite3_db = sqlite3.connect(DB_PATH_STR) + sqlite3_cursor = sqlite3_db.cursor() + sqlite3_cursor.execute("SELECT * FROM location") + location = sqlite3_cursor.fetchone() + assert location[0] == person + assert location[1] == date_time.strftime("%Y-%m-%dT%H:%M:%S%z") + assert location[2] == latitude + assert location[3] == longitude + assert location[4] == altitude + sqlite3_cursor.close() + + +@pytest.mark.usefixtures("_create_v2_db") +@pytest.mark.usefixtures("_teardown") +def test_inser_location_other() -> None: + latitude = 1.0 + longitude = 2.0 + altitude = 3.0 + person = "test_person" + tz = ZoneInfo("Asia/Shanghai") + date_time = datetime.now(tz=tz) + location_recorder = LocationRecorder(db_path=DB_PATH_STR) + event_loop = asyncio.get_event_loop() + event_loop.run_until_complete(location_recorder.create_db_engine()) + location_data = LocationData(latitude=latitude, longitude=longitude, altitude=altitude) + event_loop.run_until_complete( + location_recorder.insert_location( + person=person, + date_time=date_time, + location=location_data, + ), + ) + event_loop.run_until_complete(location_recorder.dispose_db_engine()) + sqlite3_db = sqlite3.connect(DB_PATH_STR) + sqlite3_cursor = sqlite3_db.cursor() + sqlite3_cursor.execute("SELECT * FROM location") + location = sqlite3_cursor.fetchone() + assert location[0] == person + assert location[1] == date_time.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%S%z") + assert location[2] == latitude + assert location[3] == longitude + assert location[4] == altitude + sqlite3_cursor.close() + + +@pytest.mark.usefixtures("_create_v2_db") +@pytest.mark.usefixtures("_teardown") +def test_insert_location_now() -> None: + latitude = 1.0 + longitude = 2.0 + altitude = 3.0 + person = "test_person" + date_time = datetime.now(tz=UTC) + location_recorder = LocationRecorder(db_path=DB_PATH_STR) + event_loop = asyncio.get_event_loop() + event_loop.run_until_complete(location_recorder.create_db_engine()) + location_data = LocationData(latitude=latitude, longitude=longitude, altitude=altitude) + event_loop.run_until_complete( + location_recorder.insert_location_now( + person=person, + location=location_data, + ), + ) + event_loop.run_until_complete(location_recorder.dispose_db_engine()) + sqlite3_db = sqlite3.connect(DB_PATH_STR) + sqlite3_cursor = sqlite3_db.cursor() + sqlite3_cursor.execute("SELECT * FROM location") + location = sqlite3_cursor.fetchone() + assert location[0] == person + date_time_act = datetime.strptime(location[1], "%Y-%m-%dT%H:%M:%S%z") + assert date_time.date() == date_time_act.date() + assert location[2] == latitude + assert location[3] == longitude + assert location[4] == altitude + sqlite3_cursor.close()