from dataclasses import dataclass
from datetime import UTC, datetime

from sqlalchemy import REAL, TEXT, insert, text
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    """Declarative base whose metadata is used to create the schema."""


class Location(Base):
    """ORM mapping of the ``location`` table.

    The primary key is the pair (person, datetime).
    """

    __tablename__ = "location"

    person: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
    # Stored as text formatted "%Y-%m-%dT%H:%M:%S%z" in UTC (see insert_location).
    datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
    latitude: Mapped[float] = mapped_column(type_=REAL)
    longitude: Mapped[float] = mapped_column(type_=REAL)
    altitude: Mapped[float] = mapped_column(type_=REAL)


@dataclass
class LocationData:
    """A single position sample to be recorded."""

    latitude: float
    longitude: float
    altitude: float


class LocationRecorder:
    """Record per-person location samples into a SQLite database.

    The schema version is tracked with SQLite's ``PRAGMA user_version``.
    Call :meth:`create_db_engine` before inserting and
    :meth:`dispose_db_engine` when finished.
    """

    # Schema version this code reads and writes.
    USER_VERSION = 2

    def __init__(self, db_path: str) -> None:
        """Remember the database location; no connection is opened here.

        Args:
            db_path: Filesystem path of the SQLite database file.
        """
        self._db_path = "sqlite+aiosqlite:///" + db_path
        self._engine: AsyncEngine | None = None

    async def create_db_engine(self) -> None:
        """Create the async engine, then create or migrate the schema.

        Runs inside a single transaction, so a failed migration is rolled back.

        Raises:
            RuntimeError: if the stored schema version cannot be migrated to
                ``USER_VERSION``.
        """
        engine = create_async_engine(self._db_path)
        self._engine = engine
        async with engine.begin() as conn:
            user_version = await self._get_user_version(conn=conn)
            if user_version == 0:
                # 0 is SQLite's default user_version, so treat the file as new.
                # NOTE(review): if version-1 databases never set the pragma
                # (tracking the version only in the `version` table that
                # _migrate_1_2 drops), they would land here and skip the
                # migration — confirm v1 set PRAGMA user_version.
                await conn.run_sync(Base.metadata.create_all)
                await self._set_user_version(conn=conn, user_version=LocationRecorder.USER_VERSION)
            elif user_version != LocationRecorder.USER_VERSION:
                await self._migrate(conn=conn)

    async def dispose_db_engine(self) -> None:
        """Dispose of the engine's connection pool (no-op if never created)."""
        if self._engine is not None:
            await self._engine.dispose()

    async def insert_location(self, person: str, date_time: datetime, location: LocationData) -> None:
        """Insert one location sample for ``person`` taken at ``date_time``.

        ``date_time`` is converted to UTC. Per ``datetime.astimezone``, a naive
        value is interpreted as system local time. The timestamp is stored with
        whole-second precision (microseconds are dropped by the format string).
        Because (person, datetime) is the primary key, two samples for the same
        person within the same second will conflict.

        Raises:
            RuntimeError: if :meth:`create_db_engine` has not been called.
        """
        # astimezone(UTC) on a value already in UTC leaves it unchanged.
        date_time_str = date_time.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%S%z")
        # begin() commits on normal exit and rolls back on error; the context
        # manager also closes the connection.
        async with self._require_engine().begin() as conn:
            await conn.execute(
                insert(Location).values(
                    person=person,
                    datetime=date_time_str,
                    latitude=location.latitude,
                    longitude=location.longitude,
                    altitude=location.altitude,
                ),
            )

    async def insert_location_now(self, person: str, location: LocationData) -> None:
        """Insert a location sample for ``person`` timestamped with the current UTC time."""
        now_utc = datetime.now(tz=UTC)
        await self.insert_location(person, now_utc, location)

    def _require_engine(self) -> AsyncEngine:
        """Return the engine, or raise if create_db_engine() was not called."""
        if self._engine is None:
            raise RuntimeError("Location Recorder: create_db_engine() must be called first.")
        return self._engine

    async def _get_user_version(self, conn: AsyncConnection) -> int:
        """Return the database's ``PRAGMA user_version`` value."""
        return (await conn.execute(text("PRAGMA user_version"))).scalar_one()

    async def _set_user_version(self, conn: AsyncConnection, user_version: int) -> None:
        """Set ``PRAGMA user_version``.

        PRAGMA values cannot be bound parameters, so the value is coerced to
        ``int`` before being formatted into the statement.
        """
        await conn.execute(text(f"PRAGMA user_version = {int(user_version)}"))

    async def _migrate(self, conn: AsyncConnection) -> None:
        """Apply migration steps one at a time until the schema is ``USER_VERSION``.

        Each step is responsible for bumping ``user_version``.

        Raises:
            RuntimeError: if there is no migration step for the stored version
                (including versions newer than this code supports).
        """
        steps = {1: self._migrate_1_2}
        user_version = await self._get_user_version(conn=conn)
        while user_version != LocationRecorder.USER_VERSION:
            step = steps.get(user_version)
            if step is None:
                raise RuntimeError(
                    f"Location Recorder: cannot migrate db from ver {user_version} "
                    f"to {LocationRecorder.USER_VERSION}."
                )
            await step(conn=conn)
            user_version = await self._get_user_version(conn=conn)

    async def _migrate_1_2(self, conn: AsyncConnection) -> None:
        """Migrate schema 1 -> 2.

        Drops the ``version`` table and renames ``location.people`` to
        ``location.person``.
        """
        print("Location Recorder: migrate from db ver 1 to 2.")
        await conn.execute(text("DROP TABLE IF EXISTS version"))
        await conn.execute(text("ALTER TABLE location RENAME COLUMN people TO person"))
        await self._set_user_version(conn=conn, user_version=2)