Start with go version

This commit is contained in:
2024-09-10 10:08:12 +02:00
parent af8b4db718
commit ea9c650f82
47 changed files with 272 additions and 1182 deletions

View File

View File

@@ -1,71 +0,0 @@
import ast
from datetime import datetime, timedelta, timezone
import httpx
from pydantic import BaseModel
from src.config import Config
from src.util.location_recorder import LocationData, LocationRecorder
from src.util.ticktick import TickTick
class HomeAssistant:
class Message(BaseModel):
target: str
action: str
content: str
def __init__(self, ticktick: TickTick, location_recorder: LocationRecorder) -> None:
self._ticktick = ticktick
self._location_recorder = location_recorder
async def process_message(self, message: Message) -> dict[str, str]:
if message.target == "ticktick":
return await self._process_ticktick_message(message=message)
if message.target == "location_recorder":
return await self._process_location(message=message)
return {"Status": "Unknown target"}
async def trigger_webhook(self, payload: dict[str, str], webhook_id: str) -> None:
token: str = Config.get_env("HOMEASSISTANT_TOKEN")
webhook_url: str = Config.get_env("HOMEASSISTANT_URL") + "/api/webhook/" + webhook_id
headers: dict[str, str] = {"Authorization": f"Bearer {token}"}
await httpx.AsyncClient().post(webhook_url, json=payload, headers=headers)
async def _process_ticktick_message(self, message: Message) -> dict[str, str]:
if message.action == "create_shopping_list":
return await self._create_shopping_list(content=message.content)
if message.action == "create_action_task":
return await self._create_action_task(content=message.content)
return {"Status": "Unknown action"}
async def _process_location(self, message: Message) -> dict[str, str]:
if message.action == "record":
location: dict[str, str] = ast.literal_eval(message.content)
await self._location_recorder.insert_location_now(
person=location["person"],
location=LocationData(
latitude=float(location["latitude"]),
longitude=float(location["longitude"]),
altitude=float(location["altitude"]),
),
)
return {"Status": "Location recorded"}
return {"Status": "Unknown action"}
async def _create_shopping_list(self, content: str) -> dict[str, str]:
project_id = Config.get_env("TICKTICK_SHOPPING_LIST")
item: dict[str, str] = ast.literal_eval(content)
task = TickTick.Task(projectId=project_id, title=item["item"])
return await self._ticktick.create_task(task=task)
async def _create_action_task(self, content: str) -> dict[str, str]:
detail: dict[str, str] = ast.literal_eval(content)
project_id = Config.get_env("TICKTICK_HOME_TASK_LIST")
due_hour = detail["due_hour"]
due = datetime.now(tz=datetime.now().astimezone().tzinfo) + timedelta(hours=due_hour)
due = (due + timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0)
due = due.astimezone(timezone.utc)
task = TickTick.Task(projectId=project_id, title=detail["action"], dueDate=TickTick.datetime_to_ticktick_format(due))
return await self._ticktick.create_task(task=task)

View File

@@ -1,120 +0,0 @@
from dataclasses import dataclass
from datetime import UTC, datetime
from sqlalchemy import REAL, TEXT, insert, text
from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(DeclarativeBase):
pass
class Location(Base):
__tablename__ = "location"
person: Mapped[str] = mapped_column(type_=TEXT, primary_key=True, nullable=False)
datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True, nullable=False)
latitude: Mapped[float] = mapped_column(type_=REAL, nullable=False)
longitude: Mapped[float] = mapped_column(type_=REAL, nullable=False)
altitude: Mapped[float] = mapped_column(type_=REAL, nullable=True)
@dataclass
class LocationData:
latitude: float
longitude: float
altitude: float
class LocationRecorder:
USER_VERSION = 3
def __init__(self, db_path: str) -> None:
self._db_path = "sqlite+aiosqlite:///" + db_path
async def create_db_engine(self) -> None:
self._engine = create_async_engine(self._db_path)
async with self._engine.begin() as conn:
user_version = await self._get_user_version(conn=conn)
if user_version == 0:
await conn.run_sync(Base.metadata.create_all)
await self._set_user_version(conn=conn, user_version=2)
if user_version != LocationRecorder.USER_VERSION:
await self._migrate(conn=conn)
async def dispose_db_engine(self) -> None:
await self._engine.dispose()
async def insert_location(self, person: str, date_time: datetime, location: LocationData) -> None:
if date_time.tzinfo != UTC:
date_time = date_time.astimezone(UTC)
date_time_str = date_time.strftime("%Y-%m-%dT%H:%M:%S%z")
async with self._engine.begin() as conn:
await conn.execute(
insert(Location)
.prefix_with("OR IGNORE")
.values(
person=person,
datetime=date_time_str,
latitude=location.latitude,
longitude=location.longitude,
altitude=location.altitude,
),
)
async def insert_locations(self, person: str, locations: dict[datetime, LocationData]) -> None:
async with self._engine.begin() as conn:
for k, v in locations.items():
dt = k
if k.tzinfo != UTC:
dt = k.astimezone(UTC)
date_time_str = dt.strftime("%Y-%m-%dT%H:%M:%S%z")
await conn.execute(
insert(Location)
.prefix_with("OR IGNORE")
.values(
person=person,
datetime=date_time_str,
latitude=v.latitude,
longitude=v.longitude,
altitude=v.altitude,
),
)
async def insert_location_now(self, person: str, location: LocationData) -> None:
now_utc = datetime.now(tz=UTC)
await self.insert_location(person, now_utc, location)
async def _get_user_version(self, conn: AsyncConnection) -> int:
return (await conn.execute(text("PRAGMA user_version"))).first()[0]
async def _set_user_version(self, conn: AsyncConnection, user_version: int) -> None:
await conn.execute(text("PRAGMA user_version = " + str(user_version)))
async def _migrate(self, conn: AsyncConnection) -> None:
user_version = (await conn.execute(text("PRAGMA user_version"))).first()[0]
if user_version == 1:
await self._migrate_1_2(conn=conn)
user_version = (await conn.execute(text("PRAGMA user_version"))).first()[0]
if user_version == 2: # noqa: PLR2004
await self._migrate_2_3(conn=conn)
user_version = (await conn.execute(text("PRAGMA user_version"))).first()[0]
async def _migrate_1_2(self, conn: AsyncConnection) -> None:
print("Location Recorder: migrate from db ver 1 to 2.")
await conn.execute(text("DROP TABLE version"))
await conn.execute(text("ALTER TABLE location RENAME people TO person"))
await self._set_user_version(conn=conn, user_version=2)
async def _migrate_2_3(self, conn: AsyncConnection) -> None:
print("Location Recorder: migrate from db ver 2 to 3.")
await conn.execute(text("ALTER TABLE location RENAME TO location_old"))
await conn.run_sync(Base.metadata.create_all)
await conn.execute(
text("""
INSERT INTO location (person, datetime, latitude, longitude, altitude)
SELECT person, datetime, latitude, longitude, altitude FROM location_old;
"""),
)
await conn.execute(text("DROP TABLE location_old"))
await self._set_user_version(conn=conn, user_version=3)

View File

@@ -1,73 +0,0 @@
import queue
from dataclasses import dataclass
from fastapi_mqtt import FastMQTT, MQTTConfig
@dataclass
class MQTTSubscription:
topic: str
callback: callable
subscribed: bool
@dataclass
class MQTTPendingMessage:
topic: str
payload: dict
retain: bool
class MQTT:
_instance = None
def __new__(cls, *args, **kwargs): # noqa: ANN002, ANN003, ANN204
if not cls._instance:
cls._instance = super().__new__(cls, *args, **kwargs)
return cls._instance
def __init__(self) -> None:
self._mqtt_config = MQTTConfig(username="mqtt", password="mqtt", reconnect_retries=-1) # noqa: S106
self._mqtt = FastMQTT(config=self._mqtt_config, client_id="home_automation_backend")
self._mqtt.mqtt_handlers.user_connect_handler = self.on_connect
self._mqtt.mqtt_handlers.user_message_handler = self.on_message
self._connected = False
self._subscribed_topic: dict[str, MQTTSubscription] = {}
self._queued_message: queue.Queue[MQTTPendingMessage] = queue.Queue()
async def start(self) -> None:
print("MQTT Starting...")
await self._mqtt.mqtt_startup()
async def stop(self) -> None:
print("MQTT Stopping...")
await self._mqtt.mqtt_shutdown()
def on_connect(self, client, flags, rc, properties) -> None: # noqa: ANN001, ARG002
print("Connected")
self._connected = True
while not self._queued_message.empty():
msg = self._queued_message.get(block=False)
self.publish(msg.topic, msg.payload, retain=msg.retain)
for topic, subscription in self._subscribed_topic.items():
if subscription.subscribed is False:
self.subscribe(topic, subscription.callback)
async def on_message(self, client, topic: str, payload: bytes, qos: int, properties: any) -> any: # noqa: ANN001, ARG002
print("On message")
if topic in self._subscribed_topic and self._subscribed_topic[topic].callback is not None:
await self._subscribed_topic[topic].callback(payload)
def subscribe(self, topic: str, callback: callable) -> None:
if self._connected:
print("Subscribe to topic: ", topic)
self._mqtt.client.subscribe(topic)
self._subscribed_topic[topic] = MQTTSubscription(topic, callback, subscribed=True)
else:
self._subscribed_topic[topic] = MQTTSubscription(topic, callback, subscribed=False)
def publish(self, topic: str, payload: dict, *, retain: bool) -> None:
if self._connected:
self._mqtt.publish(topic, payload=payload, retain=retain)
else:
self._queued_message.put(MQTTPendingMessage(topic, payload, retain=retain))

View File

@@ -1,82 +0,0 @@
from __future__ import annotations
from dataclasses import asdict, dataclass, field
from notion_client import AsyncClient as Client
@dataclass
class Text:
content: str
@dataclass
class RichText:
type: str
href: str | None = None
@dataclass
class RichTextText(RichText):
type: str = "text"
text: Text = field(default_factory=lambda: Text(content=""))
class NotionAsync:
def __init__(self, token: str) -> None:
self._client = Client(auth=token)
def update_token(self, token: str) -> None:
self._client.aclose()
self._client = Client(auth=token)
async def get_block(self, block_id: str) -> dict:
return await self._client.blocks.retrieve(block_id=block_id)
async def get_block_children(self, block_id: str, start_cursor: str | None = None, page_size: int = 100) -> dict:
return await self._client.blocks.children.list(
block_id=block_id,
start_cursor=start_cursor,
page_size=page_size,
)
async def block_is_table(self, block_id: str) -> bool:
block: dict = await self.get_block(block_id=block_id)
return block["type"] == "table"
async def get_table_width(self, table_id: str) -> int:
table = await self._client.blocks.retrieve(block_id=table_id)
return table["table"]["table_width"]
async def append_table_row_text(self, table_id: str, text_list: list[str], after: str | None = None) -> None:
cells: list[RichText] = []
for content in text_list:
cells.append([asdict(RichTextText(text=Text(content)))]) # noqa: PERF401
await self.append_table_row(table_id=table_id, cells=cells, after=after)
async def append_table_row(self, table_id: str, cells: list[RichText], after: str | None = None) -> None:
if not await self.block_is_table(table_id):
return
table_width = await self.get_table_width(table_id=table_id)
if table_width != len(cells):
return
children = [
{
"object": "block",
"type": "table_row",
"table_row": {
"cells": cells,
},
},
]
if after is None:
await self._client.blocks.children.append(
block_id=table_id,
children=children,
)
else:
await self._client.blocks.children.append(
block_id=table_id,
children=children,
after=after,
)

62
src/util/notion/notion.go Normal file
View File

@@ -0,0 +1,62 @@
package notion
import (
"context"
"github.com/jomei/notionapi"
)
var client *notionapi.Client
var authToken string
func Init(token string) {
authToken = token
client = notionapi.NewClient(notionapi.Token(token))
}
func GetClient() *notionapi.Client {
return client
}
func AGetBlock(blockId string) chan struct {
Block notionapi.Block
Error error
} {
retval := make(chan struct {
Block notionapi.Block
Error error
})
go func() {
block, err := client.Block.Get(context.Background(), notionapi.BlockID(blockId))
retval <- struct {
Block notionapi.Block
Error error
}{block, err}
}()
return retval
}
func AGetBlockChildren(blockId string, startCursor string, pageSize int) chan struct {
Children []notionapi.Block
NextCursor string
HasMore bool
Error error
} {
retval := make(chan struct {
Children []notionapi.Block
NextCursor string
HasMore bool
Error error
})
pagination := notionapi.Pagination{StartCursor: notionapi.Cursor(startCursor), PageSize: pageSize}
go func() {
children, err := client.Block.GetChildren(context.Background(), notionapi.BlockID(blockId), &pagination)
retval <- struct {
Children []notionapi.Block
NextCursor string
HasMore bool
Error error
}{children.Results, children.NextCursor, children.HasMore, err}
}()
return retval
}

View File

@@ -0,0 +1,27 @@
package notion_test
import (
"testing"
"github.com/t-liu93/home-automation-backend/util/notion"
)
func TestGetBlockBasic(t *testing.T) {
notion.Init("")
block := <-notion.AGetBlock("")
actObj := block.Block.GetObject()
expObj := "block"
if string(actObj) != expObj {
t.Errorf("Expected %s, but got %s", expObj, actObj)
}
}
func TestGetBlockChildrenBasic(t *testing.T) {
notion.Init("")
blockChildren := <-notion.AGetBlockChildren("", "", 100)
actLen := len(blockChildren.Children)
expLen := 100
if actLen != expLen {
t.Errorf("Expected %d, but got %d", expLen, actLen)
}
}

View File

@@ -1,354 +0,0 @@
import asyncio
import sqlite3
from datetime import UTC, datetime
from pathlib import Path
from zoneinfo import ZoneInfo
import pytest
from sqlalchemy import INTEGER, REAL, TEXT, create_engine, text
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from src.util.location_recorder import LocationData, LocationRecorder
DB_PATH = Path(__file__).resolve().parent / "test.db"
DB_PATH_STR = str(DB_PATH)
@pytest.fixture
def _reset_event_loop() -> any:
try:
loop = asyncio.get_event_loop()
if loop.is_running():
loop.stop()
loop.close()
except RuntimeError:
pass
asyncio.set_event_loop(asyncio.new_event_loop())
@pytest.fixture
def _teardown() -> any:
yield
DB_PATH.unlink()
@pytest.fixture
def _create_v1_db() -> None:
db = "sqlite:///" + DB_PATH_STR
class Base(DeclarativeBase):
pass
class Version(Base):
__tablename__ = "version"
version_type: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
version: Mapped[int] = mapped_column(type_=INTEGER)
class Location(Base):
__tablename__ = "location"
people: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
latitude: Mapped[float] = mapped_column(type_=REAL)
longitude: Mapped[float] = mapped_column(type_=REAL)
altitude: Mapped[float] = mapped_column(type_=REAL)
engine = create_engine(db)
Base.metadata.create_all(engine)
with engine.begin() as conn:
conn.execute(text("PRAGMA user_version = 1"))
@pytest.fixture
def _create_v2_db() -> None:
db = "sqlite:///" + DB_PATH_STR
class Base(DeclarativeBase):
pass
class Location(Base):
__tablename__ = "location"
person: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
latitude: Mapped[float] = mapped_column(type_=REAL)
longitude: Mapped[float] = mapped_column(type_=REAL)
altitude: Mapped[float] = mapped_column(type_=REAL)
engine = create_engine(db)
Base.metadata.create_all(engine)
with engine.begin() as conn:
conn.execute(text("PRAGMA user_version = 2"))
@pytest.fixture
def _create_latest_db() -> None:
db = "sqlite:///" + DB_PATH_STR
class Base(DeclarativeBase):
pass
class Location(Base):
__tablename__ = "location"
person: Mapped[str] = mapped_column(type_=TEXT, primary_key=True, nullable=False)
datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True, nullable=False)
latitude: Mapped[float] = mapped_column(type_=REAL, nullable=False)
longitude: Mapped[float] = mapped_column(type_=REAL, nullable=False)
altitude: Mapped[float] = mapped_column(type_=REAL, nullable=True)
engine = create_engine(db)
Base.metadata.create_all(engine)
@pytest.mark.usefixtures("_create_v1_db")
@pytest.mark.usefixtures("_teardown")
def test_migration_1_latest() -> None:
nr_tables_ver_1 = 2
table_ver_1_0 = "version"
nr_column_ver_1_version = 2
table_ver_1_1 = "location"
nr_column_ver_1_location = 5
nr_tables_ver_2 = 1
table_ver_2_0 = "location"
sqlite3_db = sqlite3.connect(DB_PATH_STR)
sqlite3_cursor = sqlite3_db.cursor()
sqlite3_cursor.execute("PRAGMA user_version")
assert sqlite3_cursor.fetchone()[0] == 1
sqlite3_cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';")
tables = sqlite3_cursor.fetchall()
assert len(tables) == nr_tables_ver_1
assert tables[0][0] == table_ver_1_0
assert tables[1][0] == table_ver_1_1
sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_1_0})")
table_info_version = sqlite3_cursor.fetchall()
assert len(table_info_version) == nr_column_ver_1_version
assert table_info_version[0] == (0, "version_type", "TEXT", 1, None, 1)
assert table_info_version[1] == (1, "version", "INTEGER", 1, None, 0)
sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_1_1})")
table_info_location = sqlite3_cursor.fetchall()
assert len(table_info_location) == nr_column_ver_1_location
assert table_info_location[0] == (0, "people", "TEXT", 1, None, 1)
assert table_info_location[1] == (1, "datetime", "TEXT", 1, None, 2)
assert table_info_location[2] == (2, "latitude", "REAL", 1, None, 0)
assert table_info_location[3] == (3, "longitude", "REAL", 1, None, 0)
assert table_info_location[4] == (4, "altitude", "REAL", 1, None, 0)
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
asyncio.run(location_recorder.create_db_engine())
sqlite3_cursor = sqlite3_db.cursor()
sqlite3_cursor.execute("PRAGMA user_version")
assert sqlite3_cursor.fetchone()[0] == LocationRecorder.USER_VERSION
sqlite3_cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';")
tables = sqlite3_cursor.fetchall()
assert len(tables) == nr_tables_ver_2
sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_2_0})")
table_info_location = sqlite3_cursor.fetchall()
assert len(table_info_location) == nr_column_ver_1_location
assert table_info_location[0] == (0, "person", "TEXT", 1, None, 1)
assert table_info_location[1] == (1, "datetime", "TEXT", 1, None, 2)
assert table_info_location[2] == (2, "latitude", "REAL", 1, None, 0)
assert table_info_location[3] == (3, "longitude", "REAL", 1, None, 0)
assert table_info_location[4] == (4, "altitude", "REAL", 0, None, 0)
sqlite3_cursor.close()
@pytest.mark.usefixtures("_create_v2_db")
@pytest.mark.usefixtures("_teardown")
def test_migration_2_latest() -> None:
nr_tables_ver_2 = 1
table_ver_2_0 = "location"
nr_column_ver_2_location = 5
nr_tables_ver_3 = 1
table_ver_3_0 = "location"
nr_column_ver_3_location = 5
sqlite3_db = sqlite3.connect(DB_PATH_STR)
sqlite3_cursor = sqlite3_db.cursor()
sqlite3_cursor.execute("PRAGMA user_version")
assert sqlite3_cursor.fetchone()[0] == 2 # noqa: PLR2004
sqlite3_cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';")
tables = sqlite3_cursor.fetchall()
assert len(tables) == nr_tables_ver_2
assert tables[0][0] == table_ver_2_0
sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_2_0})")
table_info_location = sqlite3_cursor.fetchall()
assert len(table_info_location) == nr_column_ver_2_location
assert table_info_location[0] == (0, "person", "TEXT", 1, None, 1)
assert table_info_location[1] == (1, "datetime", "TEXT", 1, None, 2)
assert table_info_location[2] == (2, "latitude", "REAL", 1, None, 0)
assert table_info_location[3] == (3, "longitude", "REAL", 1, None, 0)
assert table_info_location[4] == (4, "altitude", "REAL", 1, None, 0)
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
asyncio.run(location_recorder.create_db_engine())
sqlite3_cursor = sqlite3_db.cursor()
sqlite3_cursor.execute("PRAGMA user_version")
assert sqlite3_cursor.fetchone()[0] == LocationRecorder.USER_VERSION
sqlite3_cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';")
tables = sqlite3_cursor.fetchall()
assert len(tables) == nr_tables_ver_3
sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_3_0})")
table_info_location = sqlite3_cursor.fetchall()
assert len(table_info_location) == nr_column_ver_3_location
assert table_info_location[0] == (0, "person", "TEXT", 1, None, 1)
assert table_info_location[1] == (1, "datetime", "TEXT", 1, None, 2)
assert table_info_location[2] == (2, "latitude", "REAL", 1, None, 0)
assert table_info_location[3] == (3, "longitude", "REAL", 1, None, 0)
assert table_info_location[4] == (4, "altitude", "REAL", 0, None, 0)
sqlite3_cursor.close()
@pytest.mark.usefixtures("_reset_event_loop")
@pytest.mark.usefixtures("_teardown")
def test_create_db() -> None:
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
event_loop = asyncio.get_event_loop()
event_loop.run_until_complete(location_recorder.create_db_engine())
event_loop.run_until_complete(location_recorder.dispose_db_engine())
assert DB_PATH.exists()
sqlite3_db = sqlite3.connect(DB_PATH_STR)
sqlite3_cursor = sqlite3_db.cursor()
sqlite3_cursor.execute("PRAGMA user_version")
assert sqlite3_cursor.fetchone()[0] == LocationRecorder.USER_VERSION
@pytest.mark.usefixtures("_reset_event_loop")
@pytest.mark.usefixtures("_create_latest_db")
@pytest.mark.usefixtures("_teardown")
def test_inser_location_utc() -> None:
latitude = 1.0
longitude = 2.0
altitude = 3.0
person = "test_person"
date_time = datetime.now(tz=UTC)
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
event_loop = asyncio.get_event_loop()
event_loop.run_until_complete(location_recorder.create_db_engine())
location_data = LocationData(latitude=latitude, longitude=longitude, altitude=altitude)
event_loop.run_until_complete(
location_recorder.insert_location(
person=person,
date_time=date_time,
location=location_data,
),
)
event_loop.run_until_complete(location_recorder.dispose_db_engine())
sqlite3_db = sqlite3.connect(DB_PATH_STR)
sqlite3_cursor = sqlite3_db.cursor()
sqlite3_cursor.execute("SELECT * FROM location")
location = sqlite3_cursor.fetchone()
assert location[0] == person
assert location[1] == date_time.strftime("%Y-%m-%dT%H:%M:%S%z")
assert location[2] == latitude
assert location[3] == longitude
assert location[4] == altitude
sqlite3_cursor.close()
@pytest.mark.usefixtures("_reset_event_loop")
@pytest.mark.usefixtures("_create_latest_db")
@pytest.mark.usefixtures("_teardown")
def test_inser_location_other() -> None:
latitude = 1.0
longitude = 2.0
altitude = 3.0
person = "test_person"
tz = ZoneInfo("Asia/Shanghai")
date_time = datetime.now(tz=tz)
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
event_loop = asyncio.get_event_loop()
event_loop.run_until_complete(location_recorder.create_db_engine())
location_data = LocationData(latitude=latitude, longitude=longitude, altitude=altitude)
event_loop.run_until_complete(
location_recorder.insert_location(
person=person,
date_time=date_time,
location=location_data,
),
)
event_loop.run_until_complete(location_recorder.dispose_db_engine())
sqlite3_db = sqlite3.connect(DB_PATH_STR)
sqlite3_cursor = sqlite3_db.cursor()
sqlite3_cursor.execute("SELECT * FROM location")
location = sqlite3_cursor.fetchone()
assert location[0] == person
assert location[1] == date_time.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%S%z")
assert location[2] == latitude
assert location[3] == longitude
assert location[4] == altitude
sqlite3_cursor.close()
@pytest.mark.usefixtures("_reset_event_loop")
@pytest.mark.usefixtures("_create_latest_db")
@pytest.mark.usefixtures("_teardown")
def test_insert_location_now() -> None:
latitude = 1.0
longitude = 2.0
altitude = 3.0
person = "test_person"
date_time = datetime.now(tz=UTC)
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
event_loop = asyncio.get_event_loop()
event_loop.run_until_complete(location_recorder.create_db_engine())
location_data = LocationData(latitude=latitude, longitude=longitude, altitude=altitude)
event_loop.run_until_complete(
location_recorder.insert_location_now(
person=person,
location=location_data,
),
)
event_loop.run_until_complete(location_recorder.dispose_db_engine())
sqlite3_db = sqlite3.connect(DB_PATH_STR)
sqlite3_cursor = sqlite3_db.cursor()
sqlite3_cursor.execute("SELECT * FROM location")
location = sqlite3_cursor.fetchone()
assert location[0] == person
date_time_act = datetime.strptime(location[1], "%Y-%m-%dT%H:%M:%S%z")
assert date_time.date() == date_time_act.date()
assert location[2] == latitude
assert location[3] == longitude
assert location[4] == altitude
sqlite3_cursor.close()
@pytest.mark.usefixtures("_reset_event_loop")
@pytest.mark.usefixtures("_create_latest_db")
@pytest.mark.usefixtures("_teardown")
def test_insert_locations() -> None:
locations: dict[datetime, LocationData] = {}
person = "Tianyu"
time_0 = datetime.now(tz=UTC)
lat_0 = 1.0
lon_0 = 2.0
alt_0 = 3.0
time_1 = datetime(2021, 8, 30, 10, 20, 15, tzinfo=UTC)
lat_1 = 155.0
lon_1 = 33.36
alt_1 = 1058
locations[time_0] = LocationData(lat_0, lon_0, alt_0)
locations[time_1] = LocationData(lat_1, lon_1, alt_1)
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
event_loop = asyncio.get_event_loop()
event_loop.run_until_complete(location_recorder.create_db_engine())
event_loop.run_until_complete(
location_recorder.insert_locations(person=person, locations=locations),
)
sqlite3_db = sqlite3.connect(DB_PATH_STR)
sqlite3_cursor = sqlite3_db.cursor()
sqlite3_cursor.execute("SELECT * FROM location")
locations = sqlite3_cursor.fetchall()
assert len(locations) == 2 # noqa: PLR2004
assert locations[0][0] == person
assert locations[0][1] == time_0.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%S%z")
assert locations[0][2] == lat_0
assert locations[0][3] == lon_0
assert locations[0][4] == alt_0
assert locations[1][0] == person
assert locations[1][1] == time_1.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%S%z")
assert locations[1][2] == lat_1
assert locations[1][3] == lon_1
assert locations[1][4] == alt_1
sqlite3_cursor.close()

View File

@@ -1,84 +0,0 @@
from __future__ import annotations
import urllib.parse
from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from datetime import datetime
import httpx
from src.config import Config
class TickTick:
@dataclass
class Task:
projectId: str # noqa: N815
title: str
dueDate: str | None = None # noqa: N815
content: str | None = None
desc: str | None = None
def __init__(self) -> None:
print("Initializing TickTick...")
if Config.get_env("TICKTICK_ACCESS_TOKEN") is None:
self._begin_auth()
else:
self._access_token = Config.get_env("TICKTICK_ACCESS_TOKEN")
def _begin_auth(self) -> None:
ticktick_code_auth_url = "https://ticktick.com/oauth/authorize?"
ticktick_code_auth_params = {
"client_id": Config.get_env("TICKTICK_CLIENT_ID"),
"scope": "tasks:read tasks:write",
"state": "begin_auth",
"redirect_uri": Config.get_env("TICKTICK_CODE_REDIRECT_URI"),
"response_type": "code",
}
ticktick_auth_url_encoded = urllib.parse.urlencode(ticktick_code_auth_params)
print("Visit: ", ticktick_code_auth_url + ticktick_auth_url_encoded, " to authenticate.")
async def retrieve_access_token(self, code: str, state: str) -> bool:
if state != "begin_auth":
print("Invalid state.")
return False
ticktick_token_url = "https://ticktick.com/oauth/token" # noqa: S105
ticktick_token_auth_params: dict[str, str] = {
"code": code,
"grant_type": "authorization_code",
"scope": "tasks:write tasks:read",
"redirect_uri": Config.get_env("TICKTICK_CODE_REDIRECT_URI"),
}
client_id = Config.get_env("TICKTICK_CLIENT_ID")
client_secret = Config.get_env("TICKTICK_CLIENT_SECRET")
response = await httpx.AsyncClient().post(
ticktick_token_url,
data=ticktick_token_auth_params,
auth=httpx.BasicAuth(username=client_id, password=client_secret),
timeout=10,
)
Config.update_env("TICKTICK_ACCESS_TOKEN", response.json()["access_token"])
return True
async def get_tasks(self, project_id: str) -> list[dict]:
ticktick_get_tasks_url = "https://api.ticktick.com/open/v1/project/" + project_id + "/data"
header: dict[str, str] = {"Authorization": f"Bearer {self._access_token}"}
response = await httpx.AsyncClient().get(ticktick_get_tasks_url, headers=header, timeout=10)
return response.json()["tasks"]
async def has_duplicate_task(self, project_id: str, task_title: str) -> bool:
tasks = await self.get_tasks(project_id=project_id)
return any(task["title"] == task_title for task in tasks)
async def create_task(self, task: TickTick.Task) -> dict[str, str]:
if not await self.has_duplicate_task(project_id=task.projectId, task_title=task.title):
ticktick_task_creation_url = "https://api.ticktick.com/open/v1/task"
header: dict[str, str] = {"Authorization": f"Bearer {self._access_token}"}
await httpx.AsyncClient().post(ticktick_task_creation_url, headers=header, json=asdict(task), timeout=10)
return {"title": task.title}
@staticmethod
def datetime_to_ticktick_format(datetime: datetime) -> str:
return datetime.strftime("%Y-%m-%dT%H:%M:%S") + "+0000"