Start with go version
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,71 +0,0 @@
|
||||
import ast
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from src.config import Config
|
||||
from src.util.location_recorder import LocationData, LocationRecorder
|
||||
from src.util.ticktick import TickTick
|
||||
|
||||
|
||||
class HomeAssistant:
|
||||
class Message(BaseModel):
|
||||
target: str
|
||||
action: str
|
||||
content: str
|
||||
|
||||
def __init__(self, ticktick: TickTick, location_recorder: LocationRecorder) -> None:
|
||||
self._ticktick = ticktick
|
||||
self._location_recorder = location_recorder
|
||||
|
||||
async def process_message(self, message: Message) -> dict[str, str]:
|
||||
if message.target == "ticktick":
|
||||
return await self._process_ticktick_message(message=message)
|
||||
if message.target == "location_recorder":
|
||||
return await self._process_location(message=message)
|
||||
return {"Status": "Unknown target"}
|
||||
|
||||
async def trigger_webhook(self, payload: dict[str, str], webhook_id: str) -> None:
|
||||
token: str = Config.get_env("HOMEASSISTANT_TOKEN")
|
||||
webhook_url: str = Config.get_env("HOMEASSISTANT_URL") + "/api/webhook/" + webhook_id
|
||||
headers: dict[str, str] = {"Authorization": f"Bearer {token}"}
|
||||
await httpx.AsyncClient().post(webhook_url, json=payload, headers=headers)
|
||||
|
||||
async def _process_ticktick_message(self, message: Message) -> dict[str, str]:
|
||||
if message.action == "create_shopping_list":
|
||||
return await self._create_shopping_list(content=message.content)
|
||||
if message.action == "create_action_task":
|
||||
return await self._create_action_task(content=message.content)
|
||||
|
||||
return {"Status": "Unknown action"}
|
||||
|
||||
async def _process_location(self, message: Message) -> dict[str, str]:
|
||||
if message.action == "record":
|
||||
location: dict[str, str] = ast.literal_eval(message.content)
|
||||
await self._location_recorder.insert_location_now(
|
||||
person=location["person"],
|
||||
location=LocationData(
|
||||
latitude=float(location["latitude"]),
|
||||
longitude=float(location["longitude"]),
|
||||
altitude=float(location["altitude"]),
|
||||
),
|
||||
)
|
||||
return {"Status": "Location recorded"}
|
||||
return {"Status": "Unknown action"}
|
||||
|
||||
async def _create_shopping_list(self, content: str) -> dict[str, str]:
|
||||
project_id = Config.get_env("TICKTICK_SHOPPING_LIST")
|
||||
item: dict[str, str] = ast.literal_eval(content)
|
||||
task = TickTick.Task(projectId=project_id, title=item["item"])
|
||||
return await self._ticktick.create_task(task=task)
|
||||
|
||||
async def _create_action_task(self, content: str) -> dict[str, str]:
|
||||
detail: dict[str, str] = ast.literal_eval(content)
|
||||
project_id = Config.get_env("TICKTICK_HOME_TASK_LIST")
|
||||
due_hour = detail["due_hour"]
|
||||
due = datetime.now(tz=datetime.now().astimezone().tzinfo) + timedelta(hours=due_hour)
|
||||
due = (due + timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
due = due.astimezone(timezone.utc)
|
||||
task = TickTick.Task(projectId=project_id, title=detail["action"], dueDate=TickTick.datetime_to_ticktick_format(due))
|
||||
return await self._ticktick.create_task(task=task)
|
||||
@@ -1,120 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import REAL, TEXT, insert, text
|
||||
from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
class Location(Base):
|
||||
__tablename__ = "location"
|
||||
person: Mapped[str] = mapped_column(type_=TEXT, primary_key=True, nullable=False)
|
||||
datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True, nullable=False)
|
||||
latitude: Mapped[float] = mapped_column(type_=REAL, nullable=False)
|
||||
longitude: Mapped[float] = mapped_column(type_=REAL, nullable=False)
|
||||
altitude: Mapped[float] = mapped_column(type_=REAL, nullable=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LocationData:
|
||||
latitude: float
|
||||
longitude: float
|
||||
altitude: float
|
||||
|
||||
|
||||
class LocationRecorder:
|
||||
USER_VERSION = 3
|
||||
|
||||
def __init__(self, db_path: str) -> None:
|
||||
self._db_path = "sqlite+aiosqlite:///" + db_path
|
||||
|
||||
async def create_db_engine(self) -> None:
|
||||
self._engine = create_async_engine(self._db_path)
|
||||
async with self._engine.begin() as conn:
|
||||
user_version = await self._get_user_version(conn=conn)
|
||||
if user_version == 0:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
await self._set_user_version(conn=conn, user_version=2)
|
||||
if user_version != LocationRecorder.USER_VERSION:
|
||||
await self._migrate(conn=conn)
|
||||
|
||||
async def dispose_db_engine(self) -> None:
|
||||
await self._engine.dispose()
|
||||
|
||||
async def insert_location(self, person: str, date_time: datetime, location: LocationData) -> None:
|
||||
if date_time.tzinfo != UTC:
|
||||
date_time = date_time.astimezone(UTC)
|
||||
date_time_str = date_time.strftime("%Y-%m-%dT%H:%M:%S%z")
|
||||
async with self._engine.begin() as conn:
|
||||
await conn.execute(
|
||||
insert(Location)
|
||||
.prefix_with("OR IGNORE")
|
||||
.values(
|
||||
person=person,
|
||||
datetime=date_time_str,
|
||||
latitude=location.latitude,
|
||||
longitude=location.longitude,
|
||||
altitude=location.altitude,
|
||||
),
|
||||
)
|
||||
|
||||
async def insert_locations(self, person: str, locations: dict[datetime, LocationData]) -> None:
|
||||
async with self._engine.begin() as conn:
|
||||
for k, v in locations.items():
|
||||
dt = k
|
||||
if k.tzinfo != UTC:
|
||||
dt = k.astimezone(UTC)
|
||||
date_time_str = dt.strftime("%Y-%m-%dT%H:%M:%S%z")
|
||||
await conn.execute(
|
||||
insert(Location)
|
||||
.prefix_with("OR IGNORE")
|
||||
.values(
|
||||
person=person,
|
||||
datetime=date_time_str,
|
||||
latitude=v.latitude,
|
||||
longitude=v.longitude,
|
||||
altitude=v.altitude,
|
||||
),
|
||||
)
|
||||
|
||||
async def insert_location_now(self, person: str, location: LocationData) -> None:
|
||||
now_utc = datetime.now(tz=UTC)
|
||||
await self.insert_location(person, now_utc, location)
|
||||
|
||||
async def _get_user_version(self, conn: AsyncConnection) -> int:
|
||||
return (await conn.execute(text("PRAGMA user_version"))).first()[0]
|
||||
|
||||
async def _set_user_version(self, conn: AsyncConnection, user_version: int) -> None:
|
||||
await conn.execute(text("PRAGMA user_version = " + str(user_version)))
|
||||
|
||||
async def _migrate(self, conn: AsyncConnection) -> None:
|
||||
user_version = (await conn.execute(text("PRAGMA user_version"))).first()[0]
|
||||
if user_version == 1:
|
||||
await self._migrate_1_2(conn=conn)
|
||||
user_version = (await conn.execute(text("PRAGMA user_version"))).first()[0]
|
||||
if user_version == 2: # noqa: PLR2004
|
||||
await self._migrate_2_3(conn=conn)
|
||||
user_version = (await conn.execute(text("PRAGMA user_version"))).first()[0]
|
||||
|
||||
async def _migrate_1_2(self, conn: AsyncConnection) -> None:
|
||||
print("Location Recorder: migrate from db ver 1 to 2.")
|
||||
await conn.execute(text("DROP TABLE version"))
|
||||
await conn.execute(text("ALTER TABLE location RENAME people TO person"))
|
||||
await self._set_user_version(conn=conn, user_version=2)
|
||||
|
||||
async def _migrate_2_3(self, conn: AsyncConnection) -> None:
|
||||
print("Location Recorder: migrate from db ver 2 to 3.")
|
||||
await conn.execute(text("ALTER TABLE location RENAME TO location_old"))
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
await conn.execute(
|
||||
text("""
|
||||
INSERT INTO location (person, datetime, latitude, longitude, altitude)
|
||||
SELECT person, datetime, latitude, longitude, altitude FROM location_old;
|
||||
"""),
|
||||
)
|
||||
await conn.execute(text("DROP TABLE location_old"))
|
||||
await self._set_user_version(conn=conn, user_version=3)
|
||||
@@ -1,73 +0,0 @@
|
||||
import queue
|
||||
from dataclasses import dataclass
|
||||
|
||||
from fastapi_mqtt import FastMQTT, MQTTConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class MQTTSubscription:
|
||||
topic: str
|
||||
callback: callable
|
||||
subscribed: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class MQTTPendingMessage:
|
||||
topic: str
|
||||
payload: dict
|
||||
retain: bool
|
||||
|
||||
|
||||
class MQTT:
|
||||
_instance = None
|
||||
|
||||
def __new__(cls, *args, **kwargs): # noqa: ANN002, ANN003, ANN204
|
||||
if not cls._instance:
|
||||
cls._instance = super().__new__(cls, *args, **kwargs)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._mqtt_config = MQTTConfig(username="mqtt", password="mqtt", reconnect_retries=-1) # noqa: S106
|
||||
self._mqtt = FastMQTT(config=self._mqtt_config, client_id="home_automation_backend")
|
||||
self._mqtt.mqtt_handlers.user_connect_handler = self.on_connect
|
||||
self._mqtt.mqtt_handlers.user_message_handler = self.on_message
|
||||
self._connected = False
|
||||
self._subscribed_topic: dict[str, MQTTSubscription] = {}
|
||||
self._queued_message: queue.Queue[MQTTPendingMessage] = queue.Queue()
|
||||
|
||||
async def start(self) -> None:
|
||||
print("MQTT Starting...")
|
||||
await self._mqtt.mqtt_startup()
|
||||
|
||||
async def stop(self) -> None:
|
||||
print("MQTT Stopping...")
|
||||
await self._mqtt.mqtt_shutdown()
|
||||
|
||||
def on_connect(self, client, flags, rc, properties) -> None: # noqa: ANN001, ARG002
|
||||
print("Connected")
|
||||
self._connected = True
|
||||
while not self._queued_message.empty():
|
||||
msg = self._queued_message.get(block=False)
|
||||
self.publish(msg.topic, msg.payload, retain=msg.retain)
|
||||
for topic, subscription in self._subscribed_topic.items():
|
||||
if subscription.subscribed is False:
|
||||
self.subscribe(topic, subscription.callback)
|
||||
|
||||
async def on_message(self, client, topic: str, payload: bytes, qos: int, properties: any) -> any: # noqa: ANN001, ARG002
|
||||
print("On message")
|
||||
if topic in self._subscribed_topic and self._subscribed_topic[topic].callback is not None:
|
||||
await self._subscribed_topic[topic].callback(payload)
|
||||
|
||||
def subscribe(self, topic: str, callback: callable) -> None:
|
||||
if self._connected:
|
||||
print("Subscribe to topic: ", topic)
|
||||
self._mqtt.client.subscribe(topic)
|
||||
self._subscribed_topic[topic] = MQTTSubscription(topic, callback, subscribed=True)
|
||||
else:
|
||||
self._subscribed_topic[topic] = MQTTSubscription(topic, callback, subscribed=False)
|
||||
|
||||
def publish(self, topic: str, payload: dict, *, retain: bool) -> None:
|
||||
if self._connected:
|
||||
self._mqtt.publish(topic, payload=payload, retain=retain)
|
||||
else:
|
||||
self._queued_message.put(MQTTPendingMessage(topic, payload, retain=retain))
|
||||
@@ -1,82 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
|
||||
from notion_client import AsyncClient as Client
|
||||
|
||||
|
||||
@dataclass
|
||||
class Text:
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class RichText:
|
||||
type: str
|
||||
href: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RichTextText(RichText):
|
||||
type: str = "text"
|
||||
text: Text = field(default_factory=lambda: Text(content=""))
|
||||
|
||||
|
||||
class NotionAsync:
|
||||
def __init__(self, token: str) -> None:
|
||||
self._client = Client(auth=token)
|
||||
|
||||
def update_token(self, token: str) -> None:
|
||||
self._client.aclose()
|
||||
self._client = Client(auth=token)
|
||||
|
||||
async def get_block(self, block_id: str) -> dict:
|
||||
return await self._client.blocks.retrieve(block_id=block_id)
|
||||
|
||||
async def get_block_children(self, block_id: str, start_cursor: str | None = None, page_size: int = 100) -> dict:
|
||||
return await self._client.blocks.children.list(
|
||||
block_id=block_id,
|
||||
start_cursor=start_cursor,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
async def block_is_table(self, block_id: str) -> bool:
|
||||
block: dict = await self.get_block(block_id=block_id)
|
||||
return block["type"] == "table"
|
||||
|
||||
async def get_table_width(self, table_id: str) -> int:
|
||||
table = await self._client.blocks.retrieve(block_id=table_id)
|
||||
return table["table"]["table_width"]
|
||||
|
||||
async def append_table_row_text(self, table_id: str, text_list: list[str], after: str | None = None) -> None:
|
||||
cells: list[RichText] = []
|
||||
for content in text_list:
|
||||
cells.append([asdict(RichTextText(text=Text(content)))]) # noqa: PERF401
|
||||
await self.append_table_row(table_id=table_id, cells=cells, after=after)
|
||||
|
||||
async def append_table_row(self, table_id: str, cells: list[RichText], after: str | None = None) -> None:
|
||||
if not await self.block_is_table(table_id):
|
||||
return
|
||||
table_width = await self.get_table_width(table_id=table_id)
|
||||
if table_width != len(cells):
|
||||
return
|
||||
children = [
|
||||
{
|
||||
"object": "block",
|
||||
"type": "table_row",
|
||||
"table_row": {
|
||||
"cells": cells,
|
||||
},
|
||||
},
|
||||
]
|
||||
if after is None:
|
||||
await self._client.blocks.children.append(
|
||||
block_id=table_id,
|
||||
children=children,
|
||||
)
|
||||
else:
|
||||
await self._client.blocks.children.append(
|
||||
block_id=table_id,
|
||||
children=children,
|
||||
after=after,
|
||||
)
|
||||
62
src/util/notion/notion.go
Normal file
62
src/util/notion/notion.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package notion
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jomei/notionapi"
|
||||
)
|
||||
|
||||
var client *notionapi.Client
|
||||
var authToken string
|
||||
|
||||
func Init(token string) {
|
||||
authToken = token
|
||||
client = notionapi.NewClient(notionapi.Token(token))
|
||||
}
|
||||
|
||||
func GetClient() *notionapi.Client {
|
||||
return client
|
||||
}
|
||||
|
||||
func AGetBlock(blockId string) chan struct {
|
||||
Block notionapi.Block
|
||||
Error error
|
||||
} {
|
||||
retval := make(chan struct {
|
||||
Block notionapi.Block
|
||||
Error error
|
||||
})
|
||||
go func() {
|
||||
block, err := client.Block.Get(context.Background(), notionapi.BlockID(blockId))
|
||||
retval <- struct {
|
||||
Block notionapi.Block
|
||||
Error error
|
||||
}{block, err}
|
||||
}()
|
||||
return retval
|
||||
}
|
||||
|
||||
func AGetBlockChildren(blockId string, startCursor string, pageSize int) chan struct {
|
||||
Children []notionapi.Block
|
||||
NextCursor string
|
||||
HasMore bool
|
||||
Error error
|
||||
} {
|
||||
retval := make(chan struct {
|
||||
Children []notionapi.Block
|
||||
NextCursor string
|
||||
HasMore bool
|
||||
Error error
|
||||
})
|
||||
pagination := notionapi.Pagination{StartCursor: notionapi.Cursor(startCursor), PageSize: pageSize}
|
||||
go func() {
|
||||
children, err := client.Block.GetChildren(context.Background(), notionapi.BlockID(blockId), &pagination)
|
||||
retval <- struct {
|
||||
Children []notionapi.Block
|
||||
NextCursor string
|
||||
HasMore bool
|
||||
Error error
|
||||
}{children.Results, children.NextCursor, children.HasMore, err}
|
||||
}()
|
||||
return retval
|
||||
}
|
||||
27
src/util/notion/notion_test.go
Normal file
27
src/util/notion/notion_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package notion_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/t-liu93/home-automation-backend/util/notion"
|
||||
)
|
||||
|
||||
func TestGetBlockBasic(t *testing.T) {
|
||||
notion.Init("")
|
||||
block := <-notion.AGetBlock("")
|
||||
actObj := block.Block.GetObject()
|
||||
expObj := "block"
|
||||
if string(actObj) != expObj {
|
||||
t.Errorf("Expected %s, but got %s", expObj, actObj)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetBlockChildrenBasic(t *testing.T) {
|
||||
notion.Init("")
|
||||
blockChildren := <-notion.AGetBlockChildren("", "", 100)
|
||||
actLen := len(blockChildren.Children)
|
||||
expLen := 100
|
||||
if actLen != expLen {
|
||||
t.Errorf("Expected %d, but got %d", expLen, actLen)
|
||||
}
|
||||
}
|
||||
Binary file not shown.
Binary file not shown.
@@ -1,354 +0,0 @@
|
||||
import asyncio
|
||||
import sqlite3
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import INTEGER, REAL, TEXT, create_engine, text
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
from src.util.location_recorder import LocationData, LocationRecorder
|
||||
|
||||
DB_PATH = Path(__file__).resolve().parent / "test.db"
|
||||
DB_PATH_STR = str(DB_PATH)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _reset_event_loop() -> any:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
loop.stop()
|
||||
loop.close()
|
||||
except RuntimeError:
|
||||
pass
|
||||
asyncio.set_event_loop(asyncio.new_event_loop())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _teardown() -> any:
|
||||
yield
|
||||
DB_PATH.unlink()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _create_v1_db() -> None:
|
||||
db = "sqlite:///" + DB_PATH_STR
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
class Version(Base):
|
||||
__tablename__ = "version"
|
||||
version_type: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
|
||||
version: Mapped[int] = mapped_column(type_=INTEGER)
|
||||
|
||||
class Location(Base):
|
||||
__tablename__ = "location"
|
||||
people: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
|
||||
datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
|
||||
latitude: Mapped[float] = mapped_column(type_=REAL)
|
||||
longitude: Mapped[float] = mapped_column(type_=REAL)
|
||||
altitude: Mapped[float] = mapped_column(type_=REAL)
|
||||
|
||||
engine = create_engine(db)
|
||||
|
||||
Base.metadata.create_all(engine)
|
||||
with engine.begin() as conn:
|
||||
conn.execute(text("PRAGMA user_version = 1"))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _create_v2_db() -> None:
|
||||
db = "sqlite:///" + DB_PATH_STR
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
class Location(Base):
|
||||
__tablename__ = "location"
|
||||
person: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
|
||||
datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
|
||||
latitude: Mapped[float] = mapped_column(type_=REAL)
|
||||
longitude: Mapped[float] = mapped_column(type_=REAL)
|
||||
altitude: Mapped[float] = mapped_column(type_=REAL)
|
||||
|
||||
engine = create_engine(db)
|
||||
|
||||
Base.metadata.create_all(engine)
|
||||
with engine.begin() as conn:
|
||||
conn.execute(text("PRAGMA user_version = 2"))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _create_latest_db() -> None:
|
||||
db = "sqlite:///" + DB_PATH_STR
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
class Location(Base):
|
||||
__tablename__ = "location"
|
||||
person: Mapped[str] = mapped_column(type_=TEXT, primary_key=True, nullable=False)
|
||||
datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True, nullable=False)
|
||||
latitude: Mapped[float] = mapped_column(type_=REAL, nullable=False)
|
||||
longitude: Mapped[float] = mapped_column(type_=REAL, nullable=False)
|
||||
altitude: Mapped[float] = mapped_column(type_=REAL, nullable=True)
|
||||
|
||||
engine = create_engine(db)
|
||||
|
||||
Base.metadata.create_all(engine)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_create_v1_db")
|
||||
@pytest.mark.usefixtures("_teardown")
|
||||
def test_migration_1_latest() -> None:
|
||||
nr_tables_ver_1 = 2
|
||||
table_ver_1_0 = "version"
|
||||
nr_column_ver_1_version = 2
|
||||
table_ver_1_1 = "location"
|
||||
nr_column_ver_1_location = 5
|
||||
nr_tables_ver_2 = 1
|
||||
table_ver_2_0 = "location"
|
||||
|
||||
sqlite3_db = sqlite3.connect(DB_PATH_STR)
|
||||
sqlite3_cursor = sqlite3_db.cursor()
|
||||
sqlite3_cursor.execute("PRAGMA user_version")
|
||||
assert sqlite3_cursor.fetchone()[0] == 1
|
||||
sqlite3_cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';")
|
||||
tables = sqlite3_cursor.fetchall()
|
||||
assert len(tables) == nr_tables_ver_1
|
||||
assert tables[0][0] == table_ver_1_0
|
||||
assert tables[1][0] == table_ver_1_1
|
||||
sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_1_0})")
|
||||
table_info_version = sqlite3_cursor.fetchall()
|
||||
assert len(table_info_version) == nr_column_ver_1_version
|
||||
assert table_info_version[0] == (0, "version_type", "TEXT", 1, None, 1)
|
||||
assert table_info_version[1] == (1, "version", "INTEGER", 1, None, 0)
|
||||
sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_1_1})")
|
||||
table_info_location = sqlite3_cursor.fetchall()
|
||||
assert len(table_info_location) == nr_column_ver_1_location
|
||||
assert table_info_location[0] == (0, "people", "TEXT", 1, None, 1)
|
||||
assert table_info_location[1] == (1, "datetime", "TEXT", 1, None, 2)
|
||||
assert table_info_location[2] == (2, "latitude", "REAL", 1, None, 0)
|
||||
assert table_info_location[3] == (3, "longitude", "REAL", 1, None, 0)
|
||||
assert table_info_location[4] == (4, "altitude", "REAL", 1, None, 0)
|
||||
|
||||
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
|
||||
asyncio.run(location_recorder.create_db_engine())
|
||||
sqlite3_cursor = sqlite3_db.cursor()
|
||||
sqlite3_cursor.execute("PRAGMA user_version")
|
||||
assert sqlite3_cursor.fetchone()[0] == LocationRecorder.USER_VERSION
|
||||
sqlite3_cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';")
|
||||
tables = sqlite3_cursor.fetchall()
|
||||
assert len(tables) == nr_tables_ver_2
|
||||
sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_2_0})")
|
||||
table_info_location = sqlite3_cursor.fetchall()
|
||||
assert len(table_info_location) == nr_column_ver_1_location
|
||||
assert table_info_location[0] == (0, "person", "TEXT", 1, None, 1)
|
||||
assert table_info_location[1] == (1, "datetime", "TEXT", 1, None, 2)
|
||||
assert table_info_location[2] == (2, "latitude", "REAL", 1, None, 0)
|
||||
assert table_info_location[3] == (3, "longitude", "REAL", 1, None, 0)
|
||||
assert table_info_location[4] == (4, "altitude", "REAL", 0, None, 0)
|
||||
sqlite3_cursor.close()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_create_v2_db")
|
||||
@pytest.mark.usefixtures("_teardown")
|
||||
def test_migration_2_latest() -> None:
|
||||
nr_tables_ver_2 = 1
|
||||
table_ver_2_0 = "location"
|
||||
nr_column_ver_2_location = 5
|
||||
nr_tables_ver_3 = 1
|
||||
table_ver_3_0 = "location"
|
||||
nr_column_ver_3_location = 5
|
||||
|
||||
sqlite3_db = sqlite3.connect(DB_PATH_STR)
|
||||
sqlite3_cursor = sqlite3_db.cursor()
|
||||
sqlite3_cursor.execute("PRAGMA user_version")
|
||||
assert sqlite3_cursor.fetchone()[0] == 2 # noqa: PLR2004
|
||||
sqlite3_cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';")
|
||||
tables = sqlite3_cursor.fetchall()
|
||||
assert len(tables) == nr_tables_ver_2
|
||||
assert tables[0][0] == table_ver_2_0
|
||||
sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_2_0})")
|
||||
table_info_location = sqlite3_cursor.fetchall()
|
||||
assert len(table_info_location) == nr_column_ver_2_location
|
||||
assert table_info_location[0] == (0, "person", "TEXT", 1, None, 1)
|
||||
assert table_info_location[1] == (1, "datetime", "TEXT", 1, None, 2)
|
||||
assert table_info_location[2] == (2, "latitude", "REAL", 1, None, 0)
|
||||
assert table_info_location[3] == (3, "longitude", "REAL", 1, None, 0)
|
||||
assert table_info_location[4] == (4, "altitude", "REAL", 1, None, 0)
|
||||
|
||||
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
|
||||
asyncio.run(location_recorder.create_db_engine())
|
||||
sqlite3_cursor = sqlite3_db.cursor()
|
||||
sqlite3_cursor.execute("PRAGMA user_version")
|
||||
assert sqlite3_cursor.fetchone()[0] == LocationRecorder.USER_VERSION
|
||||
sqlite3_cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';")
|
||||
tables = sqlite3_cursor.fetchall()
|
||||
assert len(tables) == nr_tables_ver_3
|
||||
sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_3_0})")
|
||||
table_info_location = sqlite3_cursor.fetchall()
|
||||
assert len(table_info_location) == nr_column_ver_3_location
|
||||
assert table_info_location[0] == (0, "person", "TEXT", 1, None, 1)
|
||||
assert table_info_location[1] == (1, "datetime", "TEXT", 1, None, 2)
|
||||
assert table_info_location[2] == (2, "latitude", "REAL", 1, None, 0)
|
||||
assert table_info_location[3] == (3, "longitude", "REAL", 1, None, 0)
|
||||
assert table_info_location[4] == (4, "altitude", "REAL", 0, None, 0)
|
||||
sqlite3_cursor.close()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_reset_event_loop")
|
||||
@pytest.mark.usefixtures("_teardown")
|
||||
def test_create_db() -> None:
|
||||
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
|
||||
event_loop = asyncio.get_event_loop()
|
||||
event_loop.run_until_complete(location_recorder.create_db_engine())
|
||||
event_loop.run_until_complete(location_recorder.dispose_db_engine())
|
||||
assert DB_PATH.exists()
|
||||
sqlite3_db = sqlite3.connect(DB_PATH_STR)
|
||||
sqlite3_cursor = sqlite3_db.cursor()
|
||||
sqlite3_cursor.execute("PRAGMA user_version")
|
||||
assert sqlite3_cursor.fetchone()[0] == LocationRecorder.USER_VERSION
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_reset_event_loop")
|
||||
@pytest.mark.usefixtures("_create_latest_db")
|
||||
@pytest.mark.usefixtures("_teardown")
|
||||
def test_inser_location_utc() -> None:
|
||||
latitude = 1.0
|
||||
longitude = 2.0
|
||||
altitude = 3.0
|
||||
person = "test_person"
|
||||
date_time = datetime.now(tz=UTC)
|
||||
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
|
||||
event_loop = asyncio.get_event_loop()
|
||||
event_loop.run_until_complete(location_recorder.create_db_engine())
|
||||
location_data = LocationData(latitude=latitude, longitude=longitude, altitude=altitude)
|
||||
event_loop.run_until_complete(
|
||||
location_recorder.insert_location(
|
||||
person=person,
|
||||
date_time=date_time,
|
||||
location=location_data,
|
||||
),
|
||||
)
|
||||
event_loop.run_until_complete(location_recorder.dispose_db_engine())
|
||||
sqlite3_db = sqlite3.connect(DB_PATH_STR)
|
||||
sqlite3_cursor = sqlite3_db.cursor()
|
||||
sqlite3_cursor.execute("SELECT * FROM location")
|
||||
location = sqlite3_cursor.fetchone()
|
||||
assert location[0] == person
|
||||
assert location[1] == date_time.strftime("%Y-%m-%dT%H:%M:%S%z")
|
||||
assert location[2] == latitude
|
||||
assert location[3] == longitude
|
||||
assert location[4] == altitude
|
||||
sqlite3_cursor.close()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_reset_event_loop")
|
||||
@pytest.mark.usefixtures("_create_latest_db")
|
||||
@pytest.mark.usefixtures("_teardown")
|
||||
def test_inser_location_other() -> None:
|
||||
latitude = 1.0
|
||||
longitude = 2.0
|
||||
altitude = 3.0
|
||||
person = "test_person"
|
||||
tz = ZoneInfo("Asia/Shanghai")
|
||||
date_time = datetime.now(tz=tz)
|
||||
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
|
||||
event_loop = asyncio.get_event_loop()
|
||||
event_loop.run_until_complete(location_recorder.create_db_engine())
|
||||
location_data = LocationData(latitude=latitude, longitude=longitude, altitude=altitude)
|
||||
event_loop.run_until_complete(
|
||||
location_recorder.insert_location(
|
||||
person=person,
|
||||
date_time=date_time,
|
||||
location=location_data,
|
||||
),
|
||||
)
|
||||
event_loop.run_until_complete(location_recorder.dispose_db_engine())
|
||||
sqlite3_db = sqlite3.connect(DB_PATH_STR)
|
||||
sqlite3_cursor = sqlite3_db.cursor()
|
||||
sqlite3_cursor.execute("SELECT * FROM location")
|
||||
location = sqlite3_cursor.fetchone()
|
||||
assert location[0] == person
|
||||
assert location[1] == date_time.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%S%z")
|
||||
assert location[2] == latitude
|
||||
assert location[3] == longitude
|
||||
assert location[4] == altitude
|
||||
sqlite3_cursor.close()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_reset_event_loop")
|
||||
@pytest.mark.usefixtures("_create_latest_db")
|
||||
@pytest.mark.usefixtures("_teardown")
|
||||
def test_insert_location_now() -> None:
|
||||
latitude = 1.0
|
||||
longitude = 2.0
|
||||
altitude = 3.0
|
||||
person = "test_person"
|
||||
date_time = datetime.now(tz=UTC)
|
||||
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
|
||||
event_loop = asyncio.get_event_loop()
|
||||
event_loop.run_until_complete(location_recorder.create_db_engine())
|
||||
location_data = LocationData(latitude=latitude, longitude=longitude, altitude=altitude)
|
||||
event_loop.run_until_complete(
|
||||
location_recorder.insert_location_now(
|
||||
person=person,
|
||||
location=location_data,
|
||||
),
|
||||
)
|
||||
event_loop.run_until_complete(location_recorder.dispose_db_engine())
|
||||
sqlite3_db = sqlite3.connect(DB_PATH_STR)
|
||||
sqlite3_cursor = sqlite3_db.cursor()
|
||||
sqlite3_cursor.execute("SELECT * FROM location")
|
||||
location = sqlite3_cursor.fetchone()
|
||||
assert location[0] == person
|
||||
date_time_act = datetime.strptime(location[1], "%Y-%m-%dT%H:%M:%S%z")
|
||||
assert date_time.date() == date_time_act.date()
|
||||
assert location[2] == latitude
|
||||
assert location[3] == longitude
|
||||
assert location[4] == altitude
|
||||
sqlite3_cursor.close()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_reset_event_loop")
|
||||
@pytest.mark.usefixtures("_create_latest_db")
|
||||
@pytest.mark.usefixtures("_teardown")
|
||||
def test_insert_locations() -> None:
|
||||
locations: dict[datetime, LocationData] = {}
|
||||
person = "Tianyu"
|
||||
time_0 = datetime.now(tz=UTC)
|
||||
lat_0 = 1.0
|
||||
lon_0 = 2.0
|
||||
alt_0 = 3.0
|
||||
time_1 = datetime(2021, 8, 30, 10, 20, 15, tzinfo=UTC)
|
||||
lat_1 = 155.0
|
||||
lon_1 = 33.36
|
||||
alt_1 = 1058
|
||||
locations[time_0] = LocationData(lat_0, lon_0, alt_0)
|
||||
locations[time_1] = LocationData(lat_1, lon_1, alt_1)
|
||||
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
|
||||
event_loop = asyncio.get_event_loop()
|
||||
event_loop.run_until_complete(location_recorder.create_db_engine())
|
||||
event_loop.run_until_complete(
|
||||
location_recorder.insert_locations(person=person, locations=locations),
|
||||
)
|
||||
sqlite3_db = sqlite3.connect(DB_PATH_STR)
|
||||
sqlite3_cursor = sqlite3_db.cursor()
|
||||
sqlite3_cursor.execute("SELECT * FROM location")
|
||||
locations = sqlite3_cursor.fetchall()
|
||||
assert len(locations) == 2 # noqa: PLR2004
|
||||
assert locations[0][0] == person
|
||||
assert locations[0][1] == time_0.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%S%z")
|
||||
assert locations[0][2] == lat_0
|
||||
assert locations[0][3] == lon_0
|
||||
assert locations[0][4] == alt_0
|
||||
assert locations[1][0] == person
|
||||
assert locations[1][1] == time_1.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%S%z")
|
||||
assert locations[1][2] == lat_1
|
||||
assert locations[1][3] == lon_1
|
||||
assert locations[1][4] == alt_1
|
||||
sqlite3_cursor.close()
|
||||
@@ -1,84 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import urllib.parse
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import datetime
|
||||
|
||||
import httpx
|
||||
|
||||
from src.config import Config
|
||||
|
||||
|
||||
class TickTick:
|
||||
@dataclass
|
||||
class Task:
|
||||
projectId: str # noqa: N815
|
||||
title: str
|
||||
dueDate: str | None = None # noqa: N815
|
||||
content: str | None = None
|
||||
desc: str | None = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
print("Initializing TickTick...")
|
||||
if Config.get_env("TICKTICK_ACCESS_TOKEN") is None:
|
||||
self._begin_auth()
|
||||
else:
|
||||
self._access_token = Config.get_env("TICKTICK_ACCESS_TOKEN")
|
||||
|
||||
def _begin_auth(self) -> None:
|
||||
ticktick_code_auth_url = "https://ticktick.com/oauth/authorize?"
|
||||
ticktick_code_auth_params = {
|
||||
"client_id": Config.get_env("TICKTICK_CLIENT_ID"),
|
||||
"scope": "tasks:read tasks:write",
|
||||
"state": "begin_auth",
|
||||
"redirect_uri": Config.get_env("TICKTICK_CODE_REDIRECT_URI"),
|
||||
"response_type": "code",
|
||||
}
|
||||
ticktick_auth_url_encoded = urllib.parse.urlencode(ticktick_code_auth_params)
|
||||
print("Visit: ", ticktick_code_auth_url + ticktick_auth_url_encoded, " to authenticate.")
|
||||
|
||||
async def retrieve_access_token(self, code: str, state: str) -> bool:
|
||||
if state != "begin_auth":
|
||||
print("Invalid state.")
|
||||
return False
|
||||
ticktick_token_url = "https://ticktick.com/oauth/token" # noqa: S105
|
||||
ticktick_token_auth_params: dict[str, str] = {
|
||||
"code": code,
|
||||
"grant_type": "authorization_code",
|
||||
"scope": "tasks:write tasks:read",
|
||||
"redirect_uri": Config.get_env("TICKTICK_CODE_REDIRECT_URI"),
|
||||
}
|
||||
client_id = Config.get_env("TICKTICK_CLIENT_ID")
|
||||
client_secret = Config.get_env("TICKTICK_CLIENT_SECRET")
|
||||
response = await httpx.AsyncClient().post(
|
||||
ticktick_token_url,
|
||||
data=ticktick_token_auth_params,
|
||||
auth=httpx.BasicAuth(username=client_id, password=client_secret),
|
||||
timeout=10,
|
||||
)
|
||||
Config.update_env("TICKTICK_ACCESS_TOKEN", response.json()["access_token"])
|
||||
return True
|
||||
|
||||
async def get_tasks(self, project_id: str) -> list[dict]:
|
||||
ticktick_get_tasks_url = "https://api.ticktick.com/open/v1/project/" + project_id + "/data"
|
||||
header: dict[str, str] = {"Authorization": f"Bearer {self._access_token}"}
|
||||
response = await httpx.AsyncClient().get(ticktick_get_tasks_url, headers=header, timeout=10)
|
||||
return response.json()["tasks"]
|
||||
|
||||
async def has_duplicate_task(self, project_id: str, task_title: str) -> bool:
|
||||
tasks = await self.get_tasks(project_id=project_id)
|
||||
return any(task["title"] == task_title for task in tasks)
|
||||
|
||||
async def create_task(self, task: TickTick.Task) -> dict[str, str]:
|
||||
if not await self.has_duplicate_task(project_id=task.projectId, task_title=task.title):
|
||||
ticktick_task_creation_url = "https://api.ticktick.com/open/v1/task"
|
||||
header: dict[str, str] = {"Authorization": f"Bearer {self._access_token}"}
|
||||
await httpx.AsyncClient().post(ticktick_task_creation_url, headers=header, json=asdict(task), timeout=10)
|
||||
return {"title": task.title}
|
||||
|
||||
@staticmethod
|
||||
def datetime_to_ticktick_format(datetime: datetime) -> str:
|
||||
return datetime.strftime("%Y-%m-%dT%H:%M:%S") + "+0000"
|
||||
Reference in New Issue
Block a user