From b8faf975bf3fddf38a1e1a0efda3b9554d919d12 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Wed, 28 Aug 2024 22:53:58 +0200 Subject: [PATCH] db version --- .vscode/settings.json | 5 +- src/tests/test_main.py | 10 -- src/util/homeassistant.py | 2 +- src/util/location_recorder.py | 47 +++++--- src/util/tests/test_location_recorder.py | 130 +++++++++++++++++++++++ src/util/tests/test_ticktick.py | 7 -- 6 files changed, 167 insertions(+), 34 deletions(-) create mode 100644 src/util/tests/test_location_recorder.py delete mode 100644 src/util/tests/test_ticktick.py diff --git a/.vscode/settings.json b/.vscode/settings.json index 97d59ff..e6b94ee 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -8,8 +8,9 @@ "editor.defaultFormatter": "charliermarsh.ruff" }, "python.testing.pytestArgs": [ - "src/" + "src/", + "${workspaceFolder}" ], "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true + "python.testing.pytestEnabled": true, } \ No newline at end of file diff --git a/src/tests/test_main.py b/src/tests/test_main.py index 3480612..8b13789 100644 --- a/src/tests/test_main.py +++ b/src/tests/test_main.py @@ -1,11 +1 @@ -from fastapi import status -from fastapi.testclient import TestClient -from src.main import app - - -def test_get_homeassistant_status() -> None: - client = TestClient(app) - response = client.get("/homeassistant/status") - assert response.status_code == status.HTTP_200_OK - assert response.json() == {"Status": "Ok"} diff --git a/src/util/homeassistant.py b/src/util/homeassistant.py index d43b54a..a2498bd 100644 --- a/src/util/homeassistant.py +++ b/src/util/homeassistant.py @@ -44,7 +44,7 @@ class HomeAssistant: if message.action == "record": location: dict[str, str] = ast.literal_eval(message.content) await self._location_recorder.insert_location_now( - people=location["person"], + person=location["person"], location=LocationData( latitude=float(location["latitude"]), longitude=float(location["longitude"]), diff --git a/src/util/location_recorder.py b/src/util/location_recorder.py index f6158a8..5364769 100644 --- a/src/util/location_recorder.py +++ b/src/util/location_recorder.py @@ -1,8 +1,8 @@ import datetime from dataclasses import dataclass -from sqlalchemy import INTEGER, REAL, TEXT, insert -from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy import REAL, TEXT, insert, text +from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column @@ -10,15 +10,9 @@ class Base(DeclarativeBase): pass -class Version(Base): - __tablename__ = "version" - version_type: Mapped[str] = mapped_column(type_=TEXT, primary_key=True) - version: Mapped[int] = mapped_column(type_=INTEGER) - - class Location(Base): __tablename__ = "location" - people: Mapped[str] = mapped_column(type_=TEXT, primary_key=True) + person: Mapped[str] = mapped_column(type_=TEXT, primary_key=True) datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True) latitude: Mapped[float] = mapped_column(type_=REAL) longitude: Mapped[float] = mapped_column(type_=REAL) @@ -33,22 +27,29 @@ class LocationData: class LocationRecorder: + USER_VERSION = 2 + def __init__(self, db_path: str) -> None: self._db_path = "sqlite+aiosqlite:///" + db_path async def create_db_engine(self) -> None: self._engine = create_async_engine(self._db_path) async with self._engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) + user_version = await self._get_user_version(conn=conn) + if user_version == 0: + await conn.run_sync(Base.metadata.create_all) + await self._set_user_version(conn=conn, user_version=2) + if user_version != LocationRecorder.USER_VERSION: + await self._migrate(conn=conn) async def dispose_db_engine(self) -> None: await self._engine.dispose() - async def insert_location(self, people: str, datetime: str, location: LocationData) -> None: + async def insert_location(self, person: str, datetime: str, location: LocationData) -> None: async with self._engine.connect() as conn: await conn.execute( insert(Location).values( - people=people, + person=person, datetime=datetime, latitude=location.latitude, longitude=location.longitude, @@ -58,7 +59,25 @@ class LocationRecorder: await conn.commit() await conn.aclose() - async def insert_location_now(self, people: str, location: LocationData) -> None: + async def insert_location_now(self, person: str, location: LocationData) -> None: now = datetime.datetime.now(tz=datetime.UTC) now_str = now.strftime("%Y-%m-%dT%H:%M:%S%z") - await self.insert_location(people, now_str, location) + await self.insert_location(person, now_str, location) + + async def _get_user_version(self, conn: AsyncConnection) -> int: + return (await conn.execute(text("PRAGMA user_version"))).first()[0] + + async def _set_user_version(self, conn: AsyncConnection, user_version: int) -> None: + await conn.execute(text("PRAGMA user_version = " + str(user_version))) + + async def _migrate(self, conn: AsyncConnection) -> None: + user_version = (await conn.execute(text("PRAGMA user_version"))).first()[0] + if user_version == 1: + await self._migrate_1_2(conn=conn) + user_version = (await conn.execute(text("PRAGMA user_version"))).first()[0] + + async def _migrate_1_2(self, conn: AsyncConnection) -> None: + print("Location Recorder: migrate from db ver 1 to 2.") + await conn.execute(text("DROP TABLE version")) + await conn.execute(text("ALTER TABLE location RENAME people TO person")) + await self._set_user_version(conn=conn, user_version=2) diff --git a/src/util/tests/test_location_recorder.py b/src/util/tests/test_location_recorder.py new file mode 100644 index 0000000..f3cd8f6 --- /dev/null +++ b/src/util/tests/test_location_recorder.py @@ -0,0 +1,130 @@ +import asyncio +import sqlite3 +from pathlib import Path + +import pytest +from sqlalchemy import INTEGER, REAL, TEXT, create_engine, text +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +from src.util.location_recorder import LocationRecorder + +DB_PATH = Path(__file__).resolve().parent / "test.db" +DB_PATH_STR = str(DB_PATH) + + +@pytest.fixture +def _teardown() -> any: + yield + DB_PATH.unlink() + + +@pytest.fixture +def _create_v1_db() -> None: + db = "sqlite:///" + DB_PATH_STR + + class Base(DeclarativeBase): + pass + + class Version(Base): + __tablename__ = "version" + version_type: Mapped[str] = mapped_column(type_=TEXT, primary_key=True) + version: Mapped[int] = mapped_column(type_=INTEGER) + + class Location(Base): + __tablename__ = "location" + people: Mapped[str] = mapped_column(type_=TEXT, primary_key=True) + datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True) + latitude: Mapped[float] = mapped_column(type_=REAL) + longitude: Mapped[float] = mapped_column(type_=REAL) + altitude: Mapped[float] = mapped_column(type_=REAL) + + engine = create_engine(db) + + Base.metadata.create_all(engine) + with engine.begin() as conn: + conn.execute(text("PRAGMA user_version = 1")) + + +@pytest.fixture +def _create_v2_db() -> None: + db = "sqlite:///" + DB_PATH_STR + + class Base(DeclarativeBase): + pass + + class Location(Base): + __tablename__ = "location" + person: Mapped[str] = mapped_column(type_=TEXT, primary_key=True) + datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True) + latitude: Mapped[float] = mapped_column(type_=REAL) + longitude: Mapped[float] = mapped_column(type_=REAL) + altitude: Mapped[float] = mapped_column(type_=REAL) + + engine = create_engine(db) + + Base.metadata.create_all(engine) + + +@pytest.mark.usefixtures("_create_v1_db") +@pytest.mark.usefixtures("_teardown") +def test_migration_1_latest() -> None: + nr_tables_ver_1 = 2 + table_ver_1_0 = "version" + nr_column_ver_1_version = 2 + table_ver_1_1 = "location" + nr_column_ver_1_location = 5 + nr_tables_ver_2 = 1 + table_ver_2_0 = "location" + + sqlite3_db = sqlite3.connect(DB_PATH_STR) + sqlite3_cursor = sqlite3_db.cursor() + sqlite3_cursor.execute("PRAGMA user_version") + assert sqlite3_cursor.fetchone()[0] == 1 + sqlite3_cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';") + tables = sqlite3_cursor.fetchall() + assert len(tables) == nr_tables_ver_1 + assert tables[0][0] == table_ver_1_0 + assert tables[1][0] == table_ver_1_1 + sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_1_0})") + table_info_version = sqlite3_cursor.fetchall() + assert len(table_info_version) == nr_column_ver_1_version + assert table_info_version[0] == (0, "version_type", "TEXT", 1, None, 1) + assert table_info_version[1] == (1, "version", "INTEGER", 1, None, 0) + sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_1_1})") + table_info_location = sqlite3_cursor.fetchall() + assert len(table_info_location) == nr_column_ver_1_location + assert table_info_location[0] == (0, "people", "TEXT", 1, None, 1) + assert table_info_location[1] == (1, "datetime", "TEXT", 1, None, 2) + assert table_info_location[2] == (2, "latitude", "REAL", 1, None, 0) + assert table_info_location[3] == (3, "longitude", "REAL", 1, None, 0) + assert table_info_location[4] == (4, "altitude", "REAL", 1, None, 0) + + location_recorder = LocationRecorder(db_path=DB_PATH_STR) + asyncio.run(location_recorder.create_db_engine()) + sqlite3_cursor = sqlite3_db.cursor() + sqlite3_cursor.execute("PRAGMA user_version") + assert sqlite3_cursor.fetchone()[0] == LocationRecorder.USER_VERSION + sqlite3_cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';") + tables = sqlite3_cursor.fetchall() + assert len(tables) == nr_tables_ver_2 + sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_2_0})") + table_info_location = sqlite3_cursor.fetchall() + assert len(table_info_location) == nr_column_ver_1_location + assert table_info_location[0] == (0, "person", "TEXT", 1, None, 1) + assert table_info_location[1] == (1, "datetime", "TEXT", 1, None, 2) + assert table_info_location[2] == (2, "latitude", "REAL", 1, None, 0) + assert table_info_location[3] == (3, "longitude", "REAL", 1, None, 0) + assert table_info_location[4] == (4, "altitude", "REAL", 1, None, 0) + sqlite3_cursor.close() + + +@pytest.mark.usefixtures("_teardown") +def test_create_db() -> None: + location_recorder = LocationRecorder(db_path=DB_PATH_STR) + asyncio.run(location_recorder.create_db_engine()) + location_recorder.dispose_db_engine() + assert DB_PATH.exists() + sqlite3_db = sqlite3.connect(DB_PATH_STR) + sqlite3_cursor = sqlite3_db.cursor() + sqlite3_cursor.execute("PRAGMA user_version") + assert sqlite3_cursor.fetchone()[0] == LocationRecorder.USER_VERSION diff --git a/src/util/tests/test_ticktick.py b/src/util/tests/test_ticktick.py deleted file mode 100644 index 96e9c3e..0000000 --- a/src/util/tests/test_ticktick.py +++ /dev/null @@ -1,7 +0,0 @@ -import pytest - -from util.ticktick import TickTick - - -def test_ticktick_begin_auth() -> None: - auth_url = TickTick.begin_auth()