Start with go version

This commit is contained in:
2024-09-10 10:08:12 +02:00
parent af8b4db718
commit ea9c650f82
47 changed files with 272 additions and 1182 deletions

6
.gitignore vendored
View File

@@ -24,9 +24,9 @@ go.work.sum
# env file
.env
**temp_data/**
temp_data/
# py file for branch switching
*venv
*pytest_cache/**
*pycache*/**
__pycache__/
.pytest_cache/

8
.vscode/launch.json vendored
View File

@@ -10,6 +10,14 @@
"request": "launch",
"mode": "auto",
"program": "${workspaceFolder}"
},
{
"name": "Launch Poo Reverse",
"type": "go",
"request": "launch",
"mode": "auto",
"program": "${workspaceFolder}/src/helper/poo_recorder_helper/main.go",
"args": ["reverse"]
}
]
}

Binary file not shown.

View File

@@ -1,78 +0,0 @@
from datetime import datetime
from pydantic import BaseModel
from src.config import Config
from src.util.homeassistant import HomeAssistant
from src.util.mqtt import MQTT
from src.util.notion import NotionAsync
class PooRecorder:
CONFIG_TOPIC = "homeassistant/text/poo_recorder/config"
AVAILABILITY_TOPIC = "studiotj/poo_recorder/status"
COMMAND_TOPIC = "studiotj/poo_recorder/update_text"
STATE_TOPIC = "studiotj/poo_recorder/text"
JSON_TOPIC = "studiotj/poo_recorder/attributes"
ONLINE = "online"
OFFLINE = "offline"
class RecordField(BaseModel):
status: str
latitude: str
longitude: str
def __init__(self, mqtt: MQTT, notion: NotionAsync, homeassistant: HomeAssistant) -> None:
print("Poo Recorder Initialization...")
self._notion = notion
self._table_id = Config.get_env("POO_RECORD_NOTION_TABLE_ID")
self._mqtt = mqtt
self._mqtt.publish(PooRecorder.CONFIG_TOPIC, PooRecorder.compose_config(), retain=True)
self._mqtt.publish(PooRecorder.AVAILABILITY_TOPIC, PooRecorder.ONLINE, retain=True)
self._homeassistant = homeassistant
async def _note(self, now: datetime, status: str, latitude: str, longitude: str) -> None:
formatted_date = now.strftime("%Y-%m-%d")
formatted_time = now.strftime("%H:%M")
status.strip()
await self._notion.append_table_row_text(self._table_id, [formatted_date, formatted_time, status, latitude + "," + longitude])
async def record(self, record_detail: RecordField) -> None:
webhook_id: str = Config.get_env("HOMEASSISTANT_POO_TRIGGER_ID")
self._publish_text(record_detail.status)
now = datetime.now(tz=datetime.now().astimezone().tzinfo)
self._publish_time(now)
await self._note(now, record_detail.status, record_detail.latitude, record_detail.longitude)
await self._homeassistant.trigger_webhook(payload={"status": record_detail.status}, webhook_id=webhook_id)
def _publish_text(self, new_text: str) -> None:
self._mqtt.publish(PooRecorder.AVAILABILITY_TOPIC, PooRecorder.ONLINE, retain=True)
self._mqtt.publish(PooRecorder.STATE_TOPIC, new_text, retain=True)
def _publish_time(self, time: datetime) -> None:
formatted_time = time.strftime("%a | %Y-%m-%d | %H:%M")
self._mqtt.publish(PooRecorder.AVAILABILITY_TOPIC, PooRecorder.ONLINE, retain=True)
json_string = {"last_poo": formatted_time}
self._mqtt.publish(PooRecorder.JSON_TOPIC, json_string, retain=True)
@staticmethod
def compose_config() -> dict:
return {
"device": {
"name": "Dog Poop Recorder",
"model": "poop-recorder-backend",
"sw_version": Config.VERSION,
"identifiers": ["poo_recorder"],
"manufacturer": "Studio TJ",
},
"unique_id": "poo_recorder",
"name": "Poo Status",
"availability_topic": PooRecorder.AVAILABILITY_TOPIC,
"availability_template": "{{ value_json.availability }}",
"json_attributes_topic": PooRecorder.JSON_TOPIC,
"min": 0,
"max": 255,
"mode": "text",
"command_topic": PooRecorder.COMMAND_TOPIC,
"state_topic": PooRecorder.STATE_TOPIC,
}

View File

@@ -1,40 +0,0 @@
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING, ClassVar
from dotenv import dotenv_values, set_key, unset_key
if TYPE_CHECKING:
from collections import OrderedDict
config_path = Path(__file__).parent.resolve()
DOT_ENV_PATH = Path(config_path, ".env")
DOT_ENV_PATH.touch(mode=0o600, exist_ok=True)
class Config:
env_dict: ClassVar[OrderedDict[str, str]] = {}
dot_env_path = DOT_ENV_PATH
VERSION = "2.0"
@staticmethod
def init(dotenv_path: str = DOT_ENV_PATH) -> None:
Config.dot_env_path = dotenv_path
Config.env_dict = dotenv_values(dotenv_path=dotenv_path)
@staticmethod
def get_env(key: str) -> str | None:
if key in Config.env_dict:
return Config.env_dict[key]
return None
@staticmethod
def update_env(key: str, value: str) -> None:
set_key(Config.dot_env_path, key, value)
Config.env_dict = dotenv_values(dotenv_path=Config.dot_env_path)
@staticmethod
def remove_env(key: str) -> None:
unset_key(Config.dot_env_path, key)
Config.env_dict = dotenv_values(dotenv_path=Config.dot_env_path)

11
src/go.mod Normal file
View File

@@ -0,0 +1,11 @@
module github.com/t-liu93/home-automation-backend
go 1.23.0
require github.com/jomei/notionapi v1.13.2
require (
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/spf13/cobra v1.8.1 // indirect
github.com/spf13/pflag v1.0.5 // indirect
)

12
src/go.sum Normal file
View File

@@ -0,0 +1,12 @@
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jomei/notionapi v1.13.2 h1:YpHKNpkoTMlUfWTlVIodOmQDgRKjfwmtSNVa6/6yC9E=
github.com/jomei/notionapi v1.13.2/go.mod h1:BqzP6JBddpBnXvMSIxiR5dCoCjKngmz5QNl1ONDlDoM=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

View File

@@ -1,61 +0,0 @@
import argparse
import asyncio
import json
import sys
from datetime import datetime
from pathlib import Path
current_file_path = Path(__file__).resolve().parent
sys.path.append(str(current_file_path / ".." / ".." / ".."))
from src.util.location_recorder import LocationData, LocationRecorder # noqa: E402
# Create an argument parser
parser = argparse.ArgumentParser(description="Google Location Reader")
# Add an argument for the JSON file path
parser.add_argument("--json-file", type=str, help="Path to the JSON file")
# Parse the command-line arguments
args = parser.parse_args()
json_file_path: str = args.json_file
db_path = current_file_path / ".." / ".." / ".." / "temp_data" / "test.db"
location_recorder = LocationRecorder(db_path=str(db_path))
# Open the JSON file
with Path.open(json_file_path) as json_file:
data = json.load(json_file)
locations: list[dict] = data["locations"]
print(type(locations), len(locations))
async def insert() -> None:
nr_waypoints = 0
await location_recorder.create_db_engine()
locations_dict: dict[datetime, LocationData] = {}
for location in locations:
nr_waypoints += 1
try:
latitude: float = location["latitudeE7"] / 1e7
longitude: float = location["longitudeE7"] / 1e7
except KeyError:
continue
altitude: float = location.get("altitude", None)
try:
date_time = datetime.strptime(location["timestamp"], "%Y-%m-%dT%H:%M:%S.%f%z")
except ValueError:
date_time = datetime.strptime(location["timestamp"], "%Y-%m-%dT%H:%M:%S%z")
locations_dict[date_time] = LocationData(
latitude=latitude,
longitude=longitude,
altitude=altitude,
)
await location_recorder.insert_locations("Tianyu", locations=locations_dict)
print(nr_waypoints)
await location_recorder.dispose_db_engine()
asyncio.run(insert())

View File

@@ -1,44 +0,0 @@
import argparse
import asyncio
import sys
from datetime import UTC
from pathlib import Path
import gpxpy
import gpxpy.gpx
current_file_path = Path(__file__).resolve().parent
sys.path.append(str(current_file_path / ".." / ".." / ".."))
from src.util.location_recorder import LocationData, LocationRecorder # noqa: E402
parser = argparse.ArgumentParser(description="GPX Location Reader")
parser.add_argument("--gpx-file", type=str, help="Path to the GPX file")
args = parser.parse_args()
gpx_location = args.gpx_file
gpx_file = Path.open(gpx_location)
gpx = gpxpy.parse(gpx_file)
db_path = current_file_path / ".." / ".." / ".." / "temp_data" / "test.db"
location_recorder = LocationRecorder(db_path=str(db_path))
async def iterate_and_insert() -> None:
nr_waypoints = 0
await location_recorder.create_db_engine()
for track in gpx.tracks:
for segment in track.segments:
for point in segment.points:
nr_waypoints += 1
print(f"Point at ({point.latitude},{point.longitude}) -> {point.time}")
point.time = point.time.replace(tzinfo=UTC)
location_data = LocationData(latitude=point.latitude, longitude=point.longitude, altitude=point.elevation)
await location_recorder.insert_location(person="Tianyu", date_time=point.time, location=location_data)
await location_recorder.dispose_db_engine()
print(nr_waypoints)
asyncio.run(iterate_and_insert())

View File

@@ -1,41 +0,0 @@
import asyncio
import datetime
from pathlib import Path
from src.config import Config
from src.util.notion import NotionAsync
Config.init()
notion = NotionAsync(token=Config.get_env("NOTION_TOKEN"))
current_file_path = Path(__file__).resolve()
current_dir = str(current_file_path.parent)
rows: list[str] = []
async def update_rows() -> None:
header: dict = await notion.get_block_children(block_id=Config.get_env("POO_RECORD_NOTION_TABLE_ID"), page_size=1)
header_id = header["results"][0]["id"]
with Path.open(current_dir + "/../../../temp_data/old_poo_record.txt") as file:
content = file.read()
rows = content.split("\n")
rows.reverse()
for row in rows:
t = row[0:5]
date = row[8:19]
formatted_date = datetime.datetime.strptime(date, "%a, %d %b").astimezone().replace(year=2024).strftime("%Y-%m-%d")
status = row[20:]
print(f"{formatted_date} {t} {status}")
await notion.append_table_row_text(
table_id=Config.get_env("POO_RECORD_NOTION_TABLE_ID"),
text_list=[formatted_date, t, status, "0,0"],
after=header_id,
)
asyncio.run(update_rows())

View File

@@ -0,0 +1,87 @@
/*
Copyright © 2024 Tianyu Liu
*/
package cmd
import (
"context"
"fmt"
"time"
"github.com/jomei/notionapi"
"github.com/spf13/cobra"
"github.com/t-liu93/home-automation-backend/util/notion"
)
var notionToken string
var notionTableId string
// reverseCmd represents the reverse command
var reverseCmd = &cobra.Command{
Use: "reverse",
Short: "Reverse given poo recording table",
Run: reverseRun,
}
func reverseRun(cmd *cobra.Command, args []string) {
rows := []notionapi.Block{}
fmt.Println("Reverse table ID: ", notionTableId)
notion.Init(notionToken)
headerBlock := <-notion.AGetBlockChildren(notionTableId, "", 100)
headerId := headerBlock.Children[0].GetID()
nextCursor := headerId.String()
hasMore := true
for hasMore {
blockChildren := <-notion.AGetBlockChildren(notionTableId, nextCursor, 100)
rows = append(rows, blockChildren.Children...)
hasMore = blockChildren.HasMore
nextCursor = blockChildren.NextCursor
}
rows = rows[1:]
rowsR := reverseTable(rows)
for _, row := range rowsR {
notion.GetClient().Block.Delete(context.Background(), row.GetID())
fmt.Println("Deleted row: ", row.GetID())
time.Sleep(400 * time.Millisecond)
}
after := headerId
for len(rowsR) > 0 {
var rowsToWrite []notionapi.Block
if len(rowsR) > 100 {
rowsToWrite = rowsR[:100]
} else {
rowsToWrite = rowsR
}
res, err := notion.GetClient().Block.AppendChildren(context.Background(), notionapi.BlockID(notionTableId), &notionapi.AppendBlockChildrenRequest{
After: after,
Children: rowsToWrite,
})
after = rowsToWrite[len(rowsToWrite)-1].GetID()
rowsR = rowsR[len(rowsToWrite):]
fmt.Println(res, err)
}
}
func reverseTable[T any](rows []T) []T {
for i, j := 0, len(rows)-1; i < j; i, j = i+1, j-1 {
rows[i], rows[j] = rows[j], rows[i]
}
return rows
}
func init() {
rootCmd.AddCommand(reverseCmd)
// Here you will define your flags and configuration settings.
// Cobra supports Persistent Flags which will work for this command
// and all subcommands, e.g.:
// reverseCmd.PersistentFlags().String("foo", "", "A help for foo")
// Cobra supports local flags which will only run when this command
// is called directly, e.g.:
// reverseCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
reverseCmd.Flags().StringVar(&notionToken, "token", "", "Notion API token")
reverseCmd.Flags().StringVar(&notionTableId, "table-id", "", "Notion table id to reverse")
}

View File

@@ -0,0 +1,51 @@
/*
Copyright © 2024 Tianyu Liu
*/
package cmd
import (
"os"
"github.com/spf13/cobra"
)
// rootCmd represents the base command when called without any subcommands
var rootCmd = &cobra.Command{
Use: "poo_recorder_helper",
Short: "A brief description of your application",
Long: `A longer description that spans multiple lines and likely contains
examples and usage of using your application. For example:
Cobra is a CLI library for Go that empowers applications.
This application is a tool to generate the needed files
to quickly create a Cobra application.`,
// Uncomment the following line if your bare application
// has an action associated with it:
// Run: func(cmd *cobra.Command, args []string) { },
}
// Execute adds all child commands to the root command and sets flags appropriately.
// This is called by main.main(). It only needs to happen once to the rootCmd.
func Execute() {
err := rootCmd.Execute()
if err != nil {
os.Exit(1)
}
}
func init() {
// Here you will define your flags and configuration settings.
// Cobra supports persistent flags, which, if defined here,
// will be global for your application.
// rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.poo_recorder_helper.yaml)")
// Cobra also supports local flags, which will only run
// when this action is called directly.
rootCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
}

View File

@@ -0,0 +1,11 @@
/*
Copyright © 2024 Tianyu Liu
*/
package main
import "github.com/t-liu93/home-automation-backend/helper/poo_recorder_helper/cmd"
func main() {
cmd.Execute()
}

View File

@@ -1,60 +0,0 @@
from contextlib import asynccontextmanager
from pathlib import Path
from fastapi import FastAPI
from src.components.poo_recorder import PooRecorder
from src.config import Config
from src.util.homeassistant import HomeAssistant
from src.util.location_recorder import LocationRecorder
from src.util.mqtt import MQTT
from src.util.notion import NotionAsync
from src.util.ticktick import TickTick
Config.init()
location_recorder_db = str(Path(__file__).resolve().parent / ".." / "location_recorder.db")
ticktick = TickTick()
notion = NotionAsync(token=Config.get_env(key="NOTION_TOKEN"))
mqtt = MQTT()
location_recorder = LocationRecorder(db_path=location_recorder_db)
homeassistant = HomeAssistant(ticktick=ticktick, location_recorder=location_recorder)
poo_recorder = PooRecorder(mqtt=mqtt, notion=notion, homeassistant=homeassistant)
@asynccontextmanager
async def _lifespan(_app: FastAPI): # noqa: ANN202
await mqtt.start()
await location_recorder.create_db_engine()
yield
await mqtt.stop()
await location_recorder.dispose_db_engine()
app = FastAPI(lifespan=_lifespan)
@app.get("/homeassistant/status")
async def get_status() -> dict:
return {"Status": "Ok"}
@app.post("/homeassistant/publish")
async def homeassistant_publish(payload: HomeAssistant.Message) -> dict:
return await homeassistant.process_message(message=payload)
# Poo recorder
@app.post("/poo/record")
async def record(record_detail: PooRecorder.RecordField) -> PooRecorder.RecordField:
await poo_recorder.record(record_detail)
return record_detail
# ticktick
@app.get("/ticktick/auth/code")
async def ticktick_auth(code: str, state: str) -> dict:
if await ticktick.retrieve_access_token(code, state):
return {"State": "Token Retrieved"}
return {"State": "Token Retrieval Failed"}

View File

View File

@@ -1,70 +0,0 @@
from collections import OrderedDict
from pathlib import Path
import pytest
from dotenv import dotenv_values, set_key
from src.config import Config
CONFIG_PATH = Path(__file__).parent.resolve()
TEST_DOT_ENV_PATH = Path(CONFIG_PATH, ".env_test")
EXPECTED_ENV_DICT: OrderedDict[str, str] = OrderedDict(
{
"KEY_1": "VALUE_1",
"KEY_2": "VALUE_2",
"NOTION_TOKEN": "1234454_234324",
},
)
@pytest.fixture
def _prepare_test_dot_env() -> any:
TEST_DOT_ENV_PATH.touch(mode=0o600, exist_ok=True)
for key, value in EXPECTED_ENV_DICT.items():
set_key(TEST_DOT_ENV_PATH, key, value)
yield
TEST_DOT_ENV_PATH.unlink()
@pytest.fixture
def _load_test_dot_env(_prepare_test_dot_env: any) -> None:
Config.init(dotenv_path=TEST_DOT_ENV_PATH)
@pytest.mark.usefixtures("_prepare_test_dot_env")
def test_init_config() -> None:
assert Config.env_dict == {}
Config.init(dotenv_path=TEST_DOT_ENV_PATH)
assert Config.env_dict == EXPECTED_ENV_DICT
dict_from_file = dotenv_values(dotenv_path=TEST_DOT_ENV_PATH)
assert dict_from_file == EXPECTED_ENV_DICT
@pytest.mark.usefixtures("_load_test_dot_env")
def test_get_config() -> None:
assert Config.get_env("NON_EXISTING_KEY") is None
key_1 = "KEY_1"
assert Config.get_env(key_1) == EXPECTED_ENV_DICT[key_1]
@pytest.mark.usefixtures("_load_test_dot_env")
def test_update_config() -> None:
key = "KEY_1"
value = EXPECTED_ENV_DICT[key]
new_value = "NEW_VALUE"
assert Config.get_env(key) == value
Config.update_env(key, new_value)
assert Config.get_env(key) == new_value
dict_from_file = dotenv_values(dotenv_path=TEST_DOT_ENV_PATH)
assert dict_from_file[key] == new_value
@pytest.mark.usefixtures("_load_test_dot_env")
def test_remove_config() -> None:
key = "KEY_1"
assert Config.get_env(key) == EXPECTED_ENV_DICT[key]
Config.remove_env(key)
assert Config.get_env(key) is None
dict_from_file = dotenv_values(dotenv_path=TEST_DOT_ENV_PATH)
assert key not in dict_from_file

View File

@@ -1 +0,0 @@

View File

View File

@@ -1,71 +0,0 @@
import ast
from datetime import datetime, timedelta, timezone
import httpx
from pydantic import BaseModel
from src.config import Config
from src.util.location_recorder import LocationData, LocationRecorder
from src.util.ticktick import TickTick
class HomeAssistant:
class Message(BaseModel):
target: str
action: str
content: str
def __init__(self, ticktick: TickTick, location_recorder: LocationRecorder) -> None:
self._ticktick = ticktick
self._location_recorder = location_recorder
async def process_message(self, message: Message) -> dict[str, str]:
if message.target == "ticktick":
return await self._process_ticktick_message(message=message)
if message.target == "location_recorder":
return await self._process_location(message=message)
return {"Status": "Unknown target"}
async def trigger_webhook(self, payload: dict[str, str], webhook_id: str) -> None:
token: str = Config.get_env("HOMEASSISTANT_TOKEN")
webhook_url: str = Config.get_env("HOMEASSISTANT_URL") + "/api/webhook/" + webhook_id
headers: dict[str, str] = {"Authorization": f"Bearer {token}"}
await httpx.AsyncClient().post(webhook_url, json=payload, headers=headers)
async def _process_ticktick_message(self, message: Message) -> dict[str, str]:
if message.action == "create_shopping_list":
return await self._create_shopping_list(content=message.content)
if message.action == "create_action_task":
return await self._create_action_task(content=message.content)
return {"Status": "Unknown action"}
async def _process_location(self, message: Message) -> dict[str, str]:
if message.action == "record":
location: dict[str, str] = ast.literal_eval(message.content)
await self._location_recorder.insert_location_now(
person=location["person"],
location=LocationData(
latitude=float(location["latitude"]),
longitude=float(location["longitude"]),
altitude=float(location["altitude"]),
),
)
return {"Status": "Location recorded"}
return {"Status": "Unknown action"}
async def _create_shopping_list(self, content: str) -> dict[str, str]:
project_id = Config.get_env("TICKTICK_SHOPPING_LIST")
item: dict[str, str] = ast.literal_eval(content)
task = TickTick.Task(projectId=project_id, title=item["item"])
return await self._ticktick.create_task(task=task)
async def _create_action_task(self, content: str) -> dict[str, str]:
detail: dict[str, str] = ast.literal_eval(content)
project_id = Config.get_env("TICKTICK_HOME_TASK_LIST")
due_hour = detail["due_hour"]
due = datetime.now(tz=datetime.now().astimezone().tzinfo) + timedelta(hours=due_hour)
due = (due + timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0)
due = due.astimezone(timezone.utc)
task = TickTick.Task(projectId=project_id, title=detail["action"], dueDate=TickTick.datetime_to_ticktick_format(due))
return await self._ticktick.create_task(task=task)

View File

@@ -1,120 +0,0 @@
from dataclasses import dataclass
from datetime import UTC, datetime
from sqlalchemy import REAL, TEXT, insert, text
from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(DeclarativeBase):
pass
class Location(Base):
__tablename__ = "location"
person: Mapped[str] = mapped_column(type_=TEXT, primary_key=True, nullable=False)
datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True, nullable=False)
latitude: Mapped[float] = mapped_column(type_=REAL, nullable=False)
longitude: Mapped[float] = mapped_column(type_=REAL, nullable=False)
altitude: Mapped[float] = mapped_column(type_=REAL, nullable=True)
@dataclass
class LocationData:
latitude: float
longitude: float
altitude: float
class LocationRecorder:
USER_VERSION = 3
def __init__(self, db_path: str) -> None:
self._db_path = "sqlite+aiosqlite:///" + db_path
async def create_db_engine(self) -> None:
self._engine = create_async_engine(self._db_path)
async with self._engine.begin() as conn:
user_version = await self._get_user_version(conn=conn)
if user_version == 0:
await conn.run_sync(Base.metadata.create_all)
await self._set_user_version(conn=conn, user_version=2)
if user_version != LocationRecorder.USER_VERSION:
await self._migrate(conn=conn)
async def dispose_db_engine(self) -> None:
await self._engine.dispose()
async def insert_location(self, person: str, date_time: datetime, location: LocationData) -> None:
if date_time.tzinfo != UTC:
date_time = date_time.astimezone(UTC)
date_time_str = date_time.strftime("%Y-%m-%dT%H:%M:%S%z")
async with self._engine.begin() as conn:
await conn.execute(
insert(Location)
.prefix_with("OR IGNORE")
.values(
person=person,
datetime=date_time_str,
latitude=location.latitude,
longitude=location.longitude,
altitude=location.altitude,
),
)
async def insert_locations(self, person: str, locations: dict[datetime, LocationData]) -> None:
async with self._engine.begin() as conn:
for k, v in locations.items():
dt = k
if k.tzinfo != UTC:
dt = k.astimezone(UTC)
date_time_str = dt.strftime("%Y-%m-%dT%H:%M:%S%z")
await conn.execute(
insert(Location)
.prefix_with("OR IGNORE")
.values(
person=person,
datetime=date_time_str,
latitude=v.latitude,
longitude=v.longitude,
altitude=v.altitude,
),
)
async def insert_location_now(self, person: str, location: LocationData) -> None:
now_utc = datetime.now(tz=UTC)
await self.insert_location(person, now_utc, location)
async def _get_user_version(self, conn: AsyncConnection) -> int:
return (await conn.execute(text("PRAGMA user_version"))).first()[0]
async def _set_user_version(self, conn: AsyncConnection, user_version: int) -> None:
await conn.execute(text("PRAGMA user_version = " + str(user_version)))
async def _migrate(self, conn: AsyncConnection) -> None:
user_version = (await conn.execute(text("PRAGMA user_version"))).first()[0]
if user_version == 1:
await self._migrate_1_2(conn=conn)
user_version = (await conn.execute(text("PRAGMA user_version"))).first()[0]
if user_version == 2: # noqa: PLR2004
await self._migrate_2_3(conn=conn)
user_version = (await conn.execute(text("PRAGMA user_version"))).first()[0]
async def _migrate_1_2(self, conn: AsyncConnection) -> None:
print("Location Recorder: migrate from db ver 1 to 2.")
await conn.execute(text("DROP TABLE version"))
await conn.execute(text("ALTER TABLE location RENAME people TO person"))
await self._set_user_version(conn=conn, user_version=2)
async def _migrate_2_3(self, conn: AsyncConnection) -> None:
print("Location Recorder: migrate from db ver 2 to 3.")
await conn.execute(text("ALTER TABLE location RENAME TO location_old"))
await conn.run_sync(Base.metadata.create_all)
await conn.execute(
text("""
INSERT INTO location (person, datetime, latitude, longitude, altitude)
SELECT person, datetime, latitude, longitude, altitude FROM location_old;
"""),
)
await conn.execute(text("DROP TABLE location_old"))
await self._set_user_version(conn=conn, user_version=3)

View File

@@ -1,73 +0,0 @@
import queue
from dataclasses import dataclass
from fastapi_mqtt import FastMQTT, MQTTConfig
@dataclass
class MQTTSubscription:
topic: str
callback: callable
subscribed: bool
@dataclass
class MQTTPendingMessage:
topic: str
payload: dict
retain: bool
class MQTT:
_instance = None
def __new__(cls, *args, **kwargs): # noqa: ANN002, ANN003, ANN204
if not cls._instance:
cls._instance = super().__new__(cls, *args, **kwargs)
return cls._instance
def __init__(self) -> None:
self._mqtt_config = MQTTConfig(username="mqtt", password="mqtt", reconnect_retries=-1) # noqa: S106
self._mqtt = FastMQTT(config=self._mqtt_config, client_id="home_automation_backend")
self._mqtt.mqtt_handlers.user_connect_handler = self.on_connect
self._mqtt.mqtt_handlers.user_message_handler = self.on_message
self._connected = False
self._subscribed_topic: dict[str, MQTTSubscription] = {}
self._queued_message: queue.Queue[MQTTPendingMessage] = queue.Queue()
async def start(self) -> None:
print("MQTT Starting...")
await self._mqtt.mqtt_startup()
async def stop(self) -> None:
print("MQTT Stopping...")
await self._mqtt.mqtt_shutdown()
def on_connect(self, client, flags, rc, properties) -> None: # noqa: ANN001, ARG002
print("Connected")
self._connected = True
while not self._queued_message.empty():
msg = self._queued_message.get(block=False)
self.publish(msg.topic, msg.payload, retain=msg.retain)
for topic, subscription in self._subscribed_topic.items():
if subscription.subscribed is False:
self.subscribe(topic, subscription.callback)
async def on_message(self, client, topic: str, payload: bytes, qos: int, properties: any) -> any: # noqa: ANN001, ARG002
print("On message")
if topic in self._subscribed_topic and self._subscribed_topic[topic].callback is not None:
await self._subscribed_topic[topic].callback(payload)
def subscribe(self, topic: str, callback: callable) -> None:
if self._connected:
print("Subscribe to topic: ", topic)
self._mqtt.client.subscribe(topic)
self._subscribed_topic[topic] = MQTTSubscription(topic, callback, subscribed=True)
else:
self._subscribed_topic[topic] = MQTTSubscription(topic, callback, subscribed=False)
def publish(self, topic: str, payload: dict, *, retain: bool) -> None:
if self._connected:
self._mqtt.publish(topic, payload=payload, retain=retain)
else:
self._queued_message.put(MQTTPendingMessage(topic, payload, retain=retain))

View File

@@ -1,82 +0,0 @@
from __future__ import annotations
from dataclasses import asdict, dataclass, field
from notion_client import AsyncClient as Client
@dataclass
class Text:
content: str
@dataclass
class RichText:
type: str
href: str | None = None
@dataclass
class RichTextText(RichText):
type: str = "text"
text: Text = field(default_factory=lambda: Text(content=""))
class NotionAsync:
def __init__(self, token: str) -> None:
self._client = Client(auth=token)
def update_token(self, token: str) -> None:
self._client.aclose()
self._client = Client(auth=token)
async def get_block(self, block_id: str) -> dict:
return await self._client.blocks.retrieve(block_id=block_id)
async def get_block_children(self, block_id: str, start_cursor: str | None = None, page_size: int = 100) -> dict:
return await self._client.blocks.children.list(
block_id=block_id,
start_cursor=start_cursor,
page_size=page_size,
)
async def block_is_table(self, block_id: str) -> bool:
block: dict = await self.get_block(block_id=block_id)
return block["type"] == "table"
async def get_table_width(self, table_id: str) -> int:
table = await self._client.blocks.retrieve(block_id=table_id)
return table["table"]["table_width"]
async def append_table_row_text(self, table_id: str, text_list: list[str], after: str | None = None) -> None:
cells: list[RichText] = []
for content in text_list:
cells.append([asdict(RichTextText(text=Text(content)))]) # noqa: PERF401
await self.append_table_row(table_id=table_id, cells=cells, after=after)
async def append_table_row(self, table_id: str, cells: list[RichText], after: str | None = None) -> None:
if not await self.block_is_table(table_id):
return
table_width = await self.get_table_width(table_id=table_id)
if table_width != len(cells):
return
children = [
{
"object": "block",
"type": "table_row",
"table_row": {
"cells": cells,
},
},
]
if after is None:
await self._client.blocks.children.append(
block_id=table_id,
children=children,
)
else:
await self._client.blocks.children.append(
block_id=table_id,
children=children,
after=after,
)

62
src/util/notion/notion.go Normal file
View File

@@ -0,0 +1,62 @@
package notion
import (
"context"
"github.com/jomei/notionapi"
)
var client *notionapi.Client
var authToken string
func Init(token string) {
authToken = token
client = notionapi.NewClient(notionapi.Token(token))
}
func GetClient() *notionapi.Client {
return client
}
func AGetBlock(blockId string) chan struct {
Block notionapi.Block
Error error
} {
retval := make(chan struct {
Block notionapi.Block
Error error
})
go func() {
block, err := client.Block.Get(context.Background(), notionapi.BlockID(blockId))
retval <- struct {
Block notionapi.Block
Error error
}{block, err}
}()
return retval
}
func AGetBlockChildren(blockId string, startCursor string, pageSize int) chan struct {
Children []notionapi.Block
NextCursor string
HasMore bool
Error error
} {
retval := make(chan struct {
Children []notionapi.Block
NextCursor string
HasMore bool
Error error
})
pagination := notionapi.Pagination{StartCursor: notionapi.Cursor(startCursor), PageSize: pageSize}
go func() {
children, err := client.Block.GetChildren(context.Background(), notionapi.BlockID(blockId), &pagination)
retval <- struct {
Children []notionapi.Block
NextCursor string
HasMore bool
Error error
}{children.Results, children.NextCursor, children.HasMore, err}
}()
return retval
}

View File

@@ -0,0 +1,27 @@
package notion_test
import (
"testing"
"github.com/t-liu93/home-automation-backend/util/notion"
)
func TestGetBlockBasic(t *testing.T) {
notion.Init("")
block := <-notion.AGetBlock("")
actObj := block.Block.GetObject()
expObj := "block"
if string(actObj) != expObj {
t.Errorf("Expected %s, but got %s", expObj, actObj)
}
}
func TestGetBlockChildrenBasic(t *testing.T) {
notion.Init("")
blockChildren := <-notion.AGetBlockChildren("", "", 100)
actLen := len(blockChildren.Children)
expLen := 100
if actLen != expLen {
t.Errorf("Expected %d, but got %d", expLen, actLen)
}
}

View File

@@ -1,354 +0,0 @@
import asyncio
import sqlite3
from datetime import UTC, datetime
from pathlib import Path
from zoneinfo import ZoneInfo
import pytest
from sqlalchemy import INTEGER, REAL, TEXT, create_engine, text
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from src.util.location_recorder import LocationData, LocationRecorder
DB_PATH = Path(__file__).resolve().parent / "test.db"
DB_PATH_STR = str(DB_PATH)
@pytest.fixture
def _reset_event_loop() -> any:
try:
loop = asyncio.get_event_loop()
if loop.is_running():
loop.stop()
loop.close()
except RuntimeError:
pass
asyncio.set_event_loop(asyncio.new_event_loop())
@pytest.fixture
def _teardown() -> any:
yield
DB_PATH.unlink()
@pytest.fixture
def _create_v1_db() -> None:
db = "sqlite:///" + DB_PATH_STR
class Base(DeclarativeBase):
pass
class Version(Base):
__tablename__ = "version"
version_type: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
version: Mapped[int] = mapped_column(type_=INTEGER)
class Location(Base):
__tablename__ = "location"
people: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
latitude: Mapped[float] = mapped_column(type_=REAL)
longitude: Mapped[float] = mapped_column(type_=REAL)
altitude: Mapped[float] = mapped_column(type_=REAL)
engine = create_engine(db)
Base.metadata.create_all(engine)
with engine.begin() as conn:
conn.execute(text("PRAGMA user_version = 1"))
@pytest.fixture
def _create_v2_db() -> None:
db = "sqlite:///" + DB_PATH_STR
class Base(DeclarativeBase):
pass
class Location(Base):
__tablename__ = "location"
person: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
latitude: Mapped[float] = mapped_column(type_=REAL)
longitude: Mapped[float] = mapped_column(type_=REAL)
altitude: Mapped[float] = mapped_column(type_=REAL)
engine = create_engine(db)
Base.metadata.create_all(engine)
with engine.begin() as conn:
conn.execute(text("PRAGMA user_version = 2"))
@pytest.fixture
def _create_latest_db() -> None:
db = "sqlite:///" + DB_PATH_STR
class Base(DeclarativeBase):
pass
class Location(Base):
__tablename__ = "location"
person: Mapped[str] = mapped_column(type_=TEXT, primary_key=True, nullable=False)
datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True, nullable=False)
latitude: Mapped[float] = mapped_column(type_=REAL, nullable=False)
longitude: Mapped[float] = mapped_column(type_=REAL, nullable=False)
altitude: Mapped[float] = mapped_column(type_=REAL, nullable=True)
engine = create_engine(db)
Base.metadata.create_all(engine)
@pytest.mark.usefixtures("_create_v1_db")
@pytest.mark.usefixtures("_teardown")
def test_migration_1_latest() -> None:
nr_tables_ver_1 = 2
table_ver_1_0 = "version"
nr_column_ver_1_version = 2
table_ver_1_1 = "location"
nr_column_ver_1_location = 5
nr_tables_ver_2 = 1
table_ver_2_0 = "location"
sqlite3_db = sqlite3.connect(DB_PATH_STR)
sqlite3_cursor = sqlite3_db.cursor()
sqlite3_cursor.execute("PRAGMA user_version")
assert sqlite3_cursor.fetchone()[0] == 1
sqlite3_cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';")
tables = sqlite3_cursor.fetchall()
assert len(tables) == nr_tables_ver_1
assert tables[0][0] == table_ver_1_0
assert tables[1][0] == table_ver_1_1
sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_1_0})")
table_info_version = sqlite3_cursor.fetchall()
assert len(table_info_version) == nr_column_ver_1_version
assert table_info_version[0] == (0, "version_type", "TEXT", 1, None, 1)
assert table_info_version[1] == (1, "version", "INTEGER", 1, None, 0)
sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_1_1})")
table_info_location = sqlite3_cursor.fetchall()
assert len(table_info_location) == nr_column_ver_1_location
assert table_info_location[0] == (0, "people", "TEXT", 1, None, 1)
assert table_info_location[1] == (1, "datetime", "TEXT", 1, None, 2)
assert table_info_location[2] == (2, "latitude", "REAL", 1, None, 0)
assert table_info_location[3] == (3, "longitude", "REAL", 1, None, 0)
assert table_info_location[4] == (4, "altitude", "REAL", 1, None, 0)
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
asyncio.run(location_recorder.create_db_engine())
sqlite3_cursor = sqlite3_db.cursor()
sqlite3_cursor.execute("PRAGMA user_version")
assert sqlite3_cursor.fetchone()[0] == LocationRecorder.USER_VERSION
sqlite3_cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';")
tables = sqlite3_cursor.fetchall()
assert len(tables) == nr_tables_ver_2
sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_2_0})")
table_info_location = sqlite3_cursor.fetchall()
assert len(table_info_location) == nr_column_ver_1_location
assert table_info_location[0] == (0, "person", "TEXT", 1, None, 1)
assert table_info_location[1] == (1, "datetime", "TEXT", 1, None, 2)
assert table_info_location[2] == (2, "latitude", "REAL", 1, None, 0)
assert table_info_location[3] == (3, "longitude", "REAL", 1, None, 0)
assert table_info_location[4] == (4, "altitude", "REAL", 0, None, 0)
sqlite3_cursor.close()
@pytest.mark.usefixtures("_create_v2_db")
@pytest.mark.usefixtures("_teardown")
def test_migration_2_latest() -> None:
nr_tables_ver_2 = 1
table_ver_2_0 = "location"
nr_column_ver_2_location = 5
nr_tables_ver_3 = 1
table_ver_3_0 = "location"
nr_column_ver_3_location = 5
sqlite3_db = sqlite3.connect(DB_PATH_STR)
sqlite3_cursor = sqlite3_db.cursor()
sqlite3_cursor.execute("PRAGMA user_version")
assert sqlite3_cursor.fetchone()[0] == 2 # noqa: PLR2004
sqlite3_cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';")
tables = sqlite3_cursor.fetchall()
assert len(tables) == nr_tables_ver_2
assert tables[0][0] == table_ver_2_0
sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_2_0})")
table_info_location = sqlite3_cursor.fetchall()
assert len(table_info_location) == nr_column_ver_2_location
assert table_info_location[0] == (0, "person", "TEXT", 1, None, 1)
assert table_info_location[1] == (1, "datetime", "TEXT", 1, None, 2)
assert table_info_location[2] == (2, "latitude", "REAL", 1, None, 0)
assert table_info_location[3] == (3, "longitude", "REAL", 1, None, 0)
assert table_info_location[4] == (4, "altitude", "REAL", 1, None, 0)
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
asyncio.run(location_recorder.create_db_engine())
sqlite3_cursor = sqlite3_db.cursor()
sqlite3_cursor.execute("PRAGMA user_version")
assert sqlite3_cursor.fetchone()[0] == LocationRecorder.USER_VERSION
sqlite3_cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';")
tables = sqlite3_cursor.fetchall()
assert len(tables) == nr_tables_ver_3
sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_3_0})")
table_info_location = sqlite3_cursor.fetchall()
assert len(table_info_location) == nr_column_ver_3_location
assert table_info_location[0] == (0, "person", "TEXT", 1, None, 1)
assert table_info_location[1] == (1, "datetime", "TEXT", 1, None, 2)
assert table_info_location[2] == (2, "latitude", "REAL", 1, None, 0)
assert table_info_location[3] == (3, "longitude", "REAL", 1, None, 0)
assert table_info_location[4] == (4, "altitude", "REAL", 0, None, 0)
sqlite3_cursor.close()
@pytest.mark.usefixtures("_reset_event_loop")
@pytest.mark.usefixtures("_teardown")
def test_create_db() -> None:
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
event_loop = asyncio.get_event_loop()
event_loop.run_until_complete(location_recorder.create_db_engine())
event_loop.run_until_complete(location_recorder.dispose_db_engine())
assert DB_PATH.exists()
sqlite3_db = sqlite3.connect(DB_PATH_STR)
sqlite3_cursor = sqlite3_db.cursor()
sqlite3_cursor.execute("PRAGMA user_version")
assert sqlite3_cursor.fetchone()[0] == LocationRecorder.USER_VERSION
@pytest.mark.usefixtures("_reset_event_loop")
@pytest.mark.usefixtures("_create_latest_db")
@pytest.mark.usefixtures("_teardown")
def test_inser_location_utc() -> None:
latitude = 1.0
longitude = 2.0
altitude = 3.0
person = "test_person"
date_time = datetime.now(tz=UTC)
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
event_loop = asyncio.get_event_loop()
event_loop.run_until_complete(location_recorder.create_db_engine())
location_data = LocationData(latitude=latitude, longitude=longitude, altitude=altitude)
event_loop.run_until_complete(
location_recorder.insert_location(
person=person,
date_time=date_time,
location=location_data,
),
)
event_loop.run_until_complete(location_recorder.dispose_db_engine())
sqlite3_db = sqlite3.connect(DB_PATH_STR)
sqlite3_cursor = sqlite3_db.cursor()
sqlite3_cursor.execute("SELECT * FROM location")
location = sqlite3_cursor.fetchone()
assert location[0] == person
assert location[1] == date_time.strftime("%Y-%m-%dT%H:%M:%S%z")
assert location[2] == latitude
assert location[3] == longitude
assert location[4] == altitude
sqlite3_cursor.close()
@pytest.mark.usefixtures("_reset_event_loop")
@pytest.mark.usefixtures("_create_latest_db")
@pytest.mark.usefixtures("_teardown")
def test_inser_location_other() -> None:
latitude = 1.0
longitude = 2.0
altitude = 3.0
person = "test_person"
tz = ZoneInfo("Asia/Shanghai")
date_time = datetime.now(tz=tz)
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
event_loop = asyncio.get_event_loop()
event_loop.run_until_complete(location_recorder.create_db_engine())
location_data = LocationData(latitude=latitude, longitude=longitude, altitude=altitude)
event_loop.run_until_complete(
location_recorder.insert_location(
person=person,
date_time=date_time,
location=location_data,
),
)
event_loop.run_until_complete(location_recorder.dispose_db_engine())
sqlite3_db = sqlite3.connect(DB_PATH_STR)
sqlite3_cursor = sqlite3_db.cursor()
sqlite3_cursor.execute("SELECT * FROM location")
location = sqlite3_cursor.fetchone()
assert location[0] == person
assert location[1] == date_time.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%S%z")
assert location[2] == latitude
assert location[3] == longitude
assert location[4] == altitude
sqlite3_cursor.close()
@pytest.mark.usefixtures("_reset_event_loop")
@pytest.mark.usefixtures("_create_latest_db")
@pytest.mark.usefixtures("_teardown")
def test_insert_location_now() -> None:
latitude = 1.0
longitude = 2.0
altitude = 3.0
person = "test_person"
date_time = datetime.now(tz=UTC)
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
event_loop = asyncio.get_event_loop()
event_loop.run_until_complete(location_recorder.create_db_engine())
location_data = LocationData(latitude=latitude, longitude=longitude, altitude=altitude)
event_loop.run_until_complete(
location_recorder.insert_location_now(
person=person,
location=location_data,
),
)
event_loop.run_until_complete(location_recorder.dispose_db_engine())
sqlite3_db = sqlite3.connect(DB_PATH_STR)
sqlite3_cursor = sqlite3_db.cursor()
sqlite3_cursor.execute("SELECT * FROM location")
location = sqlite3_cursor.fetchone()
assert location[0] == person
date_time_act = datetime.strptime(location[1], "%Y-%m-%dT%H:%M:%S%z")
assert date_time.date() == date_time_act.date()
assert location[2] == latitude
assert location[3] == longitude
assert location[4] == altitude
sqlite3_cursor.close()
@pytest.mark.usefixtures("_reset_event_loop")
@pytest.mark.usefixtures("_create_latest_db")
@pytest.mark.usefixtures("_teardown")
def test_insert_locations() -> None:
locations: dict[datetime, LocationData] = {}
person = "Tianyu"
time_0 = datetime.now(tz=UTC)
lat_0 = 1.0
lon_0 = 2.0
alt_0 = 3.0
time_1 = datetime(2021, 8, 30, 10, 20, 15, tzinfo=UTC)
lat_1 = 155.0
lon_1 = 33.36
alt_1 = 1058
locations[time_0] = LocationData(lat_0, lon_0, alt_0)
locations[time_1] = LocationData(lat_1, lon_1, alt_1)
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
event_loop = asyncio.get_event_loop()
event_loop.run_until_complete(location_recorder.create_db_engine())
event_loop.run_until_complete(
location_recorder.insert_locations(person=person, locations=locations),
)
sqlite3_db = sqlite3.connect(DB_PATH_STR)
sqlite3_cursor = sqlite3_db.cursor()
sqlite3_cursor.execute("SELECT * FROM location")
locations = sqlite3_cursor.fetchall()
assert len(locations) == 2 # noqa: PLR2004
assert locations[0][0] == person
assert locations[0][1] == time_0.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%S%z")
assert locations[0][2] == lat_0
assert locations[0][3] == lon_0
assert locations[0][4] == alt_0
assert locations[1][0] == person
assert locations[1][1] == time_1.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%S%z")
assert locations[1][2] == lat_1
assert locations[1][3] == lon_1
assert locations[1][4] == alt_1
sqlite3_cursor.close()

View File

@@ -1,84 +0,0 @@
from __future__ import annotations
import urllib.parse
from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from datetime import datetime
import httpx
from src.config import Config
class TickTick:
@dataclass
class Task:
projectId: str # noqa: N815
title: str
dueDate: str | None = None # noqa: N815
content: str | None = None
desc: str | None = None
def __init__(self) -> None:
print("Initializing TickTick...")
if Config.get_env("TICKTICK_ACCESS_TOKEN") is None:
self._begin_auth()
else:
self._access_token = Config.get_env("TICKTICK_ACCESS_TOKEN")
def _begin_auth(self) -> None:
ticktick_code_auth_url = "https://ticktick.com/oauth/authorize?"
ticktick_code_auth_params = {
"client_id": Config.get_env("TICKTICK_CLIENT_ID"),
"scope": "tasks:read tasks:write",
"state": "begin_auth",
"redirect_uri": Config.get_env("TICKTICK_CODE_REDIRECT_URI"),
"response_type": "code",
}
ticktick_auth_url_encoded = urllib.parse.urlencode(ticktick_code_auth_params)
print("Visit: ", ticktick_code_auth_url + ticktick_auth_url_encoded, " to authenticate.")
async def retrieve_access_token(self, code: str, state: str) -> bool:
if state != "begin_auth":
print("Invalid state.")
return False
ticktick_token_url = "https://ticktick.com/oauth/token" # noqa: S105
ticktick_token_auth_params: dict[str, str] = {
"code": code,
"grant_type": "authorization_code",
"scope": "tasks:write tasks:read",
"redirect_uri": Config.get_env("TICKTICK_CODE_REDIRECT_URI"),
}
client_id = Config.get_env("TICKTICK_CLIENT_ID")
client_secret = Config.get_env("TICKTICK_CLIENT_SECRET")
response = await httpx.AsyncClient().post(
ticktick_token_url,
data=ticktick_token_auth_params,
auth=httpx.BasicAuth(username=client_id, password=client_secret),
timeout=10,
)
Config.update_env("TICKTICK_ACCESS_TOKEN", response.json()["access_token"])
return True
async def get_tasks(self, project_id: str) -> list[dict]:
ticktick_get_tasks_url = "https://api.ticktick.com/open/v1/project/" + project_id + "/data"
header: dict[str, str] = {"Authorization": f"Bearer {self._access_token}"}
response = await httpx.AsyncClient().get(ticktick_get_tasks_url, headers=header, timeout=10)
return response.json()["tasks"]
async def has_duplicate_task(self, project_id: str, task_title: str) -> bool:
tasks = await self.get_tasks(project_id=project_id)
return any(task["title"] == task_title for task in tasks)
async def create_task(self, task: TickTick.Task) -> dict[str, str]:
if not await self.has_duplicate_task(project_id=task.projectId, task_title=task.title):
ticktick_task_creation_url = "https://api.ticktick.com/open/v1/task"
header: dict[str, str] = {"Authorization": f"Bearer {self._access_token}"}
await httpx.AsyncClient().post(ticktick_task_creation_url, headers=header, json=asdict(task), timeout=10)
return {"title": task.title}
@staticmethod
def datetime_to_ticktick_format(datetime: datetime) -> str:
return datetime.strftime("%Y-%m-%dT%H:%M:%S") + "+0000"