db version
This commit is contained in:
5
.vscode/settings.json
vendored
5
.vscode/settings.json
vendored
@@ -8,8 +8,9 @@
|
|||||||
"editor.defaultFormatter": "charliermarsh.ruff"
|
"editor.defaultFormatter": "charliermarsh.ruff"
|
||||||
},
|
},
|
||||||
"python.testing.pytestArgs": [
|
"python.testing.pytestArgs": [
|
||||||
"src/"
|
"src/",
|
||||||
|
"${workspaceFolder}"
|
||||||
],
|
],
|
||||||
"python.testing.unittestEnabled": false,
|
"python.testing.unittestEnabled": false,
|
||||||
"python.testing.pytestEnabled": true
|
"python.testing.pytestEnabled": true,
|
||||||
}
|
}
|
||||||
@@ -1,11 +1 @@
|
|||||||
from fastapi import status
|
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
|
|
||||||
from src.main import app
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_homeassistant_status() -> None:
|
|
||||||
client = TestClient(app)
|
|
||||||
response = client.get("/homeassistant/status")
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
|
||||||
assert response.json() == {"Status": "Ok"}
|
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ class HomeAssistant:
|
|||||||
if message.action == "record":
|
if message.action == "record":
|
||||||
location: dict[str, str] = ast.literal_eval(message.content)
|
location: dict[str, str] = ast.literal_eval(message.content)
|
||||||
await self._location_recorder.insert_location_now(
|
await self._location_recorder.insert_location_now(
|
||||||
people=location["person"],
|
person=location["person"],
|
||||||
location=LocationData(
|
location=LocationData(
|
||||||
latitude=float(location["latitude"]),
|
latitude=float(location["latitude"]),
|
||||||
longitude=float(location["longitude"]),
|
longitude=float(location["longitude"]),
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import datetime
|
import datetime
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from sqlalchemy import INTEGER, REAL, TEXT, insert
|
from sqlalchemy import REAL, TEXT, insert, text
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||||
|
|
||||||
|
|
||||||
@@ -10,15 +10,9 @@ class Base(DeclarativeBase):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Version(Base):
|
|
||||||
__tablename__ = "version"
|
|
||||||
version_type: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
|
|
||||||
version: Mapped[int] = mapped_column(type_=INTEGER)
|
|
||||||
|
|
||||||
|
|
||||||
class Location(Base):
|
class Location(Base):
|
||||||
__tablename__ = "location"
|
__tablename__ = "location"
|
||||||
people: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
|
person: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
|
||||||
datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
|
datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
|
||||||
latitude: Mapped[float] = mapped_column(type_=REAL)
|
latitude: Mapped[float] = mapped_column(type_=REAL)
|
||||||
longitude: Mapped[float] = mapped_column(type_=REAL)
|
longitude: Mapped[float] = mapped_column(type_=REAL)
|
||||||
@@ -33,22 +27,29 @@ class LocationData:
|
|||||||
|
|
||||||
|
|
||||||
class LocationRecorder:
|
class LocationRecorder:
|
||||||
|
USER_VERSION = 2
|
||||||
|
|
||||||
def __init__(self, db_path: str) -> None:
|
def __init__(self, db_path: str) -> None:
|
||||||
self._db_path = "sqlite+aiosqlite:///" + db_path
|
self._db_path = "sqlite+aiosqlite:///" + db_path
|
||||||
|
|
||||||
async def create_db_engine(self) -> None:
|
async def create_db_engine(self) -> None:
|
||||||
self._engine = create_async_engine(self._db_path)
|
self._engine = create_async_engine(self._db_path)
|
||||||
async with self._engine.begin() as conn:
|
async with self._engine.begin() as conn:
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
user_version = await self._get_user_version(conn=conn)
|
||||||
|
if user_version == 0:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
await self._set_user_version(conn=conn, user_version=2)
|
||||||
|
if user_version != LocationRecorder.USER_VERSION:
|
||||||
|
await self._migrate(conn=conn)
|
||||||
|
|
||||||
async def dispose_db_engine(self) -> None:
|
async def dispose_db_engine(self) -> None:
|
||||||
await self._engine.dispose()
|
await self._engine.dispose()
|
||||||
|
|
||||||
async def insert_location(self, people: str, datetime: str, location: LocationData) -> None:
|
async def insert_location(self, person: str, datetime: str, location: LocationData) -> None:
|
||||||
async with self._engine.connect() as conn:
|
async with self._engine.connect() as conn:
|
||||||
await conn.execute(
|
await conn.execute(
|
||||||
insert(Location).values(
|
insert(Location).values(
|
||||||
people=people,
|
person=person,
|
||||||
datetime=datetime,
|
datetime=datetime,
|
||||||
latitude=location.latitude,
|
latitude=location.latitude,
|
||||||
longitude=location.longitude,
|
longitude=location.longitude,
|
||||||
@@ -58,7 +59,25 @@ class LocationRecorder:
|
|||||||
await conn.commit()
|
await conn.commit()
|
||||||
await conn.aclose()
|
await conn.aclose()
|
||||||
|
|
||||||
async def insert_location_now(self, people: str, location: LocationData) -> None:
|
async def insert_location_now(self, person: str, location: LocationData) -> None:
|
||||||
now = datetime.datetime.now(tz=datetime.UTC)
|
now = datetime.datetime.now(tz=datetime.UTC)
|
||||||
now_str = now.strftime("%Y-%m-%dT%H:%M:%S%z")
|
now_str = now.strftime("%Y-%m-%dT%H:%M:%S%z")
|
||||||
await self.insert_location(people, now_str, location)
|
await self.insert_location(person, now_str, location)
|
||||||
|
|
||||||
|
async def _get_user_version(self, conn: AsyncConnection) -> int:
|
||||||
|
return (await conn.execute(text("PRAGMA user_version"))).first()[0]
|
||||||
|
|
||||||
|
async def _set_user_version(self, conn: AsyncConnection, user_version: int) -> None:
|
||||||
|
await conn.execute(text("PRAGMA user_version = " + str(user_version)))
|
||||||
|
|
||||||
|
async def _migrate(self, conn: AsyncConnection) -> None:
|
||||||
|
user_version = (await conn.execute(text("PRAGMA user_version"))).first()[0]
|
||||||
|
if user_version == 1:
|
||||||
|
await self._migrate_1_2(conn=conn)
|
||||||
|
user_version = (await conn.execute(text("PRAGMA user_version"))).first()[0]
|
||||||
|
|
||||||
|
async def _migrate_1_2(self, conn: AsyncConnection) -> None:
|
||||||
|
print("Location Recorder: migrate from db ver 1 to 2.")
|
||||||
|
await conn.execute(text("DROP TABLE version"))
|
||||||
|
await conn.execute(text("ALTER TABLE location RENAME people TO person"))
|
||||||
|
await self._set_user_version(conn=conn, user_version=2)
|
||||||
|
|||||||
130
src/util/tests/test_location_recorder.py
Normal file
130
src/util/tests/test_location_recorder.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
import asyncio
|
||||||
|
import sqlite3
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import INTEGER, REAL, TEXT, create_engine, text
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||||
|
|
||||||
|
from src.util.location_recorder import LocationRecorder
|
||||||
|
|
||||||
|
DB_PATH = Path(__file__).resolve().parent / "test.db"
|
||||||
|
DB_PATH_STR = str(DB_PATH)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def _teardown() -> any:
|
||||||
|
yield
|
||||||
|
DB_PATH.unlink()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def _create_v1_db() -> None:
|
||||||
|
db = "sqlite:///" + DB_PATH_STR
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Version(Base):
|
||||||
|
__tablename__ = "version"
|
||||||
|
version_type: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
|
||||||
|
version: Mapped[int] = mapped_column(type_=INTEGER)
|
||||||
|
|
||||||
|
class Location(Base):
|
||||||
|
__tablename__ = "location"
|
||||||
|
people: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
|
||||||
|
datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
|
||||||
|
latitude: Mapped[float] = mapped_column(type_=REAL)
|
||||||
|
longitude: Mapped[float] = mapped_column(type_=REAL)
|
||||||
|
altitude: Mapped[float] = mapped_column(type_=REAL)
|
||||||
|
|
||||||
|
engine = create_engine(db)
|
||||||
|
|
||||||
|
Base.metadata.create_all(engine)
|
||||||
|
with engine.begin() as conn:
|
||||||
|
conn.execute(text("PRAGMA user_version = 1"))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def _create_v2_db() -> None:
|
||||||
|
db = "sqlite:///" + DB_PATH_STR
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Location(Base):
|
||||||
|
__tablename__ = "location"
|
||||||
|
person: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
|
||||||
|
datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
|
||||||
|
latitude: Mapped[float] = mapped_column(type_=REAL)
|
||||||
|
longitude: Mapped[float] = mapped_column(type_=REAL)
|
||||||
|
altitude: Mapped[float] = mapped_column(type_=REAL)
|
||||||
|
|
||||||
|
engine = create_engine(db)
|
||||||
|
|
||||||
|
Base.metadata.create_all(engine)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("_create_v1_db")
|
||||||
|
@pytest.mark.usefixtures("_teardown")
|
||||||
|
def test_migration_1_latest() -> None:
|
||||||
|
nr_tables_ver_1 = 2
|
||||||
|
table_ver_1_0 = "version"
|
||||||
|
nr_column_ver_1_version = 2
|
||||||
|
table_ver_1_1 = "location"
|
||||||
|
nr_column_ver_1_location = 5
|
||||||
|
nr_tables_ver_2 = 1
|
||||||
|
table_ver_2_0 = "location"
|
||||||
|
|
||||||
|
sqlite3_db = sqlite3.connect(DB_PATH_STR)
|
||||||
|
sqlite3_cursor = sqlite3_db.cursor()
|
||||||
|
sqlite3_cursor.execute("PRAGMA user_version")
|
||||||
|
assert sqlite3_cursor.fetchone()[0] == 1
|
||||||
|
sqlite3_cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';")
|
||||||
|
tables = sqlite3_cursor.fetchall()
|
||||||
|
assert len(tables) == nr_tables_ver_1
|
||||||
|
assert tables[0][0] == table_ver_1_0
|
||||||
|
assert tables[1][0] == table_ver_1_1
|
||||||
|
sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_1_0})")
|
||||||
|
table_info_version = sqlite3_cursor.fetchall()
|
||||||
|
assert len(table_info_version) == nr_column_ver_1_version
|
||||||
|
assert table_info_version[0] == (0, "version_type", "TEXT", 1, None, 1)
|
||||||
|
assert table_info_version[1] == (1, "version", "INTEGER", 1, None, 0)
|
||||||
|
sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_1_1})")
|
||||||
|
table_info_location = sqlite3_cursor.fetchall()
|
||||||
|
assert len(table_info_location) == nr_column_ver_1_location
|
||||||
|
assert table_info_location[0] == (0, "people", "TEXT", 1, None, 1)
|
||||||
|
assert table_info_location[1] == (1, "datetime", "TEXT", 1, None, 2)
|
||||||
|
assert table_info_location[2] == (2, "latitude", "REAL", 1, None, 0)
|
||||||
|
assert table_info_location[3] == (3, "longitude", "REAL", 1, None, 0)
|
||||||
|
assert table_info_location[4] == (4, "altitude", "REAL", 1, None, 0)
|
||||||
|
|
||||||
|
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
|
||||||
|
asyncio.run(location_recorder.create_db_engine())
|
||||||
|
sqlite3_cursor = sqlite3_db.cursor()
|
||||||
|
sqlite3_cursor.execute("PRAGMA user_version")
|
||||||
|
assert sqlite3_cursor.fetchone()[0] == LocationRecorder.USER_VERSION
|
||||||
|
sqlite3_cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';")
|
||||||
|
tables = sqlite3_cursor.fetchall()
|
||||||
|
assert len(tables) == nr_tables_ver_2
|
||||||
|
sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_2_0})")
|
||||||
|
table_info_location = sqlite3_cursor.fetchall()
|
||||||
|
assert len(table_info_location) == nr_column_ver_1_location
|
||||||
|
assert table_info_location[0] == (0, "person", "TEXT", 1, None, 1)
|
||||||
|
assert table_info_location[1] == (1, "datetime", "TEXT", 1, None, 2)
|
||||||
|
assert table_info_location[2] == (2, "latitude", "REAL", 1, None, 0)
|
||||||
|
assert table_info_location[3] == (3, "longitude", "REAL", 1, None, 0)
|
||||||
|
assert table_info_location[4] == (4, "altitude", "REAL", 1, None, 0)
|
||||||
|
sqlite3_cursor.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("_teardown")
|
||||||
|
def test_create_db() -> None:
|
||||||
|
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
|
||||||
|
asyncio.run(location_recorder.create_db_engine())
|
||||||
|
location_recorder.dispose_db_engine()
|
||||||
|
assert DB_PATH.exists()
|
||||||
|
sqlite3_db = sqlite3.connect(DB_PATH_STR)
|
||||||
|
sqlite3_cursor = sqlite3_db.cursor()
|
||||||
|
sqlite3_cursor.execute("PRAGMA user_version")
|
||||||
|
assert sqlite3_cursor.fetchone()[0] == LocationRecorder.USER_VERSION
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
import pytest
|
|
||||||
|
|
||||||
from util.ticktick import TickTick
|
|
||||||
|
|
||||||
|
|
||||||
def test_ticktick_begin_auth() -> None:
|
|
||||||
auth_url = TickTick.begin_auth()
|
|
||||||
Reference in New Issue
Block a user