Start with go version

This commit is contained in:
2024-09-10 10:08:12 +02:00
parent af8b4db718
commit ea9c650f82
47 changed files with 272 additions and 1182 deletions
Binary file not shown.
Binary file not shown.
Binary file not shown.
View File
Binary file not shown.
-78
View File
@@ -1,78 +0,0 @@
from datetime import datetime
from pydantic import BaseModel
from src.config import Config
from src.util.homeassistant import HomeAssistant
from src.util.mqtt import MQTT
from src.util.notion import NotionAsync
class PooRecorder:
CONFIG_TOPIC = "homeassistant/text/poo_recorder/config"
AVAILABILITY_TOPIC = "studiotj/poo_recorder/status"
COMMAND_TOPIC = "studiotj/poo_recorder/update_text"
STATE_TOPIC = "studiotj/poo_recorder/text"
JSON_TOPIC = "studiotj/poo_recorder/attributes"
ONLINE = "online"
OFFLINE = "offline"
class RecordField(BaseModel):
status: str
latitude: str
longitude: str
def __init__(self, mqtt: MQTT, notion: NotionAsync, homeassistant: HomeAssistant) -> None:
print("Poo Recorder Initialization...")
self._notion = notion
self._table_id = Config.get_env("POO_RECORD_NOTION_TABLE_ID")
self._mqtt = mqtt
self._mqtt.publish(PooRecorder.CONFIG_TOPIC, PooRecorder.compose_config(), retain=True)
self._mqtt.publish(PooRecorder.AVAILABILITY_TOPIC, PooRecorder.ONLINE, retain=True)
self._homeassistant = homeassistant
async def _note(self, now: datetime, status: str, latitude: str, longitude: str) -> None:
formatted_date = now.strftime("%Y-%m-%d")
formatted_time = now.strftime("%H:%M")
status.strip()
await self._notion.append_table_row_text(self._table_id, [formatted_date, formatted_time, status, latitude + "," + longitude])
async def record(self, record_detail: RecordField) -> None:
webhook_id: str = Config.get_env("HOMEASSISTANT_POO_TRIGGER_ID")
self._publish_text(record_detail.status)
now = datetime.now(tz=datetime.now().astimezone().tzinfo)
self._publish_time(now)
await self._note(now, record_detail.status, record_detail.latitude, record_detail.longitude)
await self._homeassistant.trigger_webhook(payload={"status": record_detail.status}, webhook_id=webhook_id)
def _publish_text(self, new_text: str) -> None:
self._mqtt.publish(PooRecorder.AVAILABILITY_TOPIC, PooRecorder.ONLINE, retain=True)
self._mqtt.publish(PooRecorder.STATE_TOPIC, new_text, retain=True)
def _publish_time(self, time: datetime) -> None:
formatted_time = time.strftime("%a | %Y-%m-%d | %H:%M")
self._mqtt.publish(PooRecorder.AVAILABILITY_TOPIC, PooRecorder.ONLINE, retain=True)
json_string = {"last_poo": formatted_time}
self._mqtt.publish(PooRecorder.JSON_TOPIC, json_string, retain=True)
@staticmethod
def compose_config() -> dict:
return {
"device": {
"name": "Dog Poop Recorder",
"model": "poop-recorder-backend",
"sw_version": Config.VERSION,
"identifiers": ["poo_recorder"],
"manufacturer": "Studio TJ",
},
"unique_id": "poo_recorder",
"name": "Poo Status",
"availability_topic": PooRecorder.AVAILABILITY_TOPIC,
"availability_template": "{{ value_json.availability }}",
"json_attributes_topic": PooRecorder.JSON_TOPIC,
"min": 0,
"max": 255,
"mode": "text",
"command_topic": PooRecorder.COMMAND_TOPIC,
"state_topic": PooRecorder.STATE_TOPIC,
}
-40
View File
@@ -1,40 +0,0 @@
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING, ClassVar
from dotenv import dotenv_values, set_key, unset_key
if TYPE_CHECKING:
from collections import OrderedDict
config_path = Path(__file__).parent.resolve()
DOT_ENV_PATH = Path(config_path, ".env")
DOT_ENV_PATH.touch(mode=0o600, exist_ok=True)
class Config:
env_dict: ClassVar[OrderedDict[str, str]] = {}
dot_env_path = DOT_ENV_PATH
VERSION = "2.0"
@staticmethod
def init(dotenv_path: str = DOT_ENV_PATH) -> None:
Config.dot_env_path = dotenv_path
Config.env_dict = dotenv_values(dotenv_path=dotenv_path)
@staticmethod
def get_env(key: str) -> str | None:
if key in Config.env_dict:
return Config.env_dict[key]
return None
@staticmethod
def update_env(key: str, value: str) -> None:
set_key(Config.dot_env_path, key, value)
Config.env_dict = dotenv_values(dotenv_path=Config.dot_env_path)
@staticmethod
def remove_env(key: str) -> None:
unset_key(Config.dot_env_path, key)
Config.env_dict = dotenv_values(dotenv_path=Config.dot_env_path)
+11
View File
@@ -0,0 +1,11 @@
module github.com/t-liu93/home-automation-backend
go 1.23.0
require github.com/jomei/notionapi v1.13.2
require (
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/spf13/cobra v1.8.1 // indirect
github.com/spf13/pflag v1.0.5 // indirect
)
+12
View File
@@ -0,0 +1,12 @@
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jomei/notionapi v1.13.2 h1:YpHKNpkoTMlUfWTlVIodOmQDgRKjfwmtSNVa6/6yC9E=
github.com/jomei/notionapi v1.13.2/go.mod h1:BqzP6JBddpBnXvMSIxiR5dCoCjKngmz5QNl1ONDlDoM=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
View File
@@ -1,61 +0,0 @@
import argparse
import asyncio
import json
import sys
from datetime import datetime
from pathlib import Path
current_file_path = Path(__file__).resolve().parent
sys.path.append(str(current_file_path / ".." / ".." / ".."))
from src.util.location_recorder import LocationData, LocationRecorder # noqa: E402
# Create an argument parser
parser = argparse.ArgumentParser(description="Google Location Reader")
# Add an argument for the JSON file path
parser.add_argument("--json-file", type=str, help="Path to the JSON file")
# Parse the command-line arguments
args = parser.parse_args()
json_file_path: str = args.json_file
db_path = current_file_path / ".." / ".." / ".." / "temp_data" / "test.db"
location_recorder = LocationRecorder(db_path=str(db_path))
# Open the JSON file
with Path.open(json_file_path) as json_file:
data = json.load(json_file)
locations: list[dict] = data["locations"]
print(type(locations), len(locations))
async def insert() -> None:
nr_waypoints = 0
await location_recorder.create_db_engine()
locations_dict: dict[datetime, LocationData] = {}
for location in locations:
nr_waypoints += 1
try:
latitude: float = location["latitudeE7"] / 1e7
longitude: float = location["longitudeE7"] / 1e7
except KeyError:
continue
altitude: float = location.get("altitude", None)
try:
date_time = datetime.strptime(location["timestamp"], "%Y-%m-%dT%H:%M:%S.%f%z")
except ValueError:
date_time = datetime.strptime(location["timestamp"], "%Y-%m-%dT%H:%M:%S%z")
locations_dict[date_time] = LocationData(
latitude=latitude,
longitude=longitude,
altitude=altitude,
)
await location_recorder.insert_locations("Tianyu", locations=locations_dict)
print(nr_waypoints)
await location_recorder.dispose_db_engine()
asyncio.run(insert())
@@ -1,44 +0,0 @@
import argparse
import asyncio
import sys
from datetime import UTC
from pathlib import Path
import gpxpy
import gpxpy.gpx
current_file_path = Path(__file__).resolve().parent
sys.path.append(str(current_file_path / ".." / ".." / ".."))
from src.util.location_recorder import LocationData, LocationRecorder # noqa: E402
parser = argparse.ArgumentParser(description="GPX Location Reader")
parser.add_argument("--gpx-file", type=str, help="Path to the GPX file")
args = parser.parse_args()
gpx_location = args.gpx_file
gpx_file = Path.open(gpx_location)
gpx = gpxpy.parse(gpx_file)
db_path = current_file_path / ".." / ".." / ".." / "temp_data" / "test.db"
location_recorder = LocationRecorder(db_path=str(db_path))
async def iterate_and_insert() -> None:
nr_waypoints = 0
await location_recorder.create_db_engine()
for track in gpx.tracks:
for segment in track.segments:
for point in segment.points:
nr_waypoints += 1
print(f"Point at ({point.latitude},{point.longitude}) -> {point.time}")
point.time = point.time.replace(tzinfo=UTC)
location_data = LocationData(latitude=point.latitude, longitude=point.longitude, altitude=point.elevation)
await location_recorder.insert_location(person="Tianyu", date_time=point.time, location=location_data)
await location_recorder.dispose_db_engine()
print(nr_waypoints)
asyncio.run(iterate_and_insert())
View File
@@ -1,41 +0,0 @@
import asyncio
import datetime
from pathlib import Path
from src.config import Config
from src.util.notion import NotionAsync
Config.init()
notion = NotionAsync(token=Config.get_env("NOTION_TOKEN"))
current_file_path = Path(__file__).resolve()
current_dir = str(current_file_path.parent)
rows: list[str] = []
async def update_rows() -> None:
header: dict = await notion.get_block_children(block_id=Config.get_env("POO_RECORD_NOTION_TABLE_ID"), page_size=1)
header_id = header["results"][0]["id"]
with Path.open(current_dir + "/../../../temp_data/old_poo_record.txt") as file:
content = file.read()
rows = content.split("\n")
rows.reverse()
for row in rows:
t = row[0:5]
date = row[8:19]
formatted_date = datetime.datetime.strptime(date, "%a, %d %b").astimezone().replace(year=2024).strftime("%Y-%m-%d")
status = row[20:]
print(f"{formatted_date} {t} {status}")
await notion.append_table_row_text(
table_id=Config.get_env("POO_RECORD_NOTION_TABLE_ID"),
text_list=[formatted_date, t, status, "0,0"],
after=header_id,
)
asyncio.run(update_rows())
@@ -0,0 +1,87 @@
/*
Copyright © 2024 Tianyu Liu
*/
package cmd
import (
"context"
"fmt"
"time"
"github.com/jomei/notionapi"
"github.com/spf13/cobra"
"github.com/t-liu93/home-automation-backend/util/notion"
)
var notionToken string
var notionTableId string
// reverseCmd represents the reverse command
var reverseCmd = &cobra.Command{
Use: "reverse",
Short: "Reverse given poo recording table",
Run: reverseRun,
}
func reverseRun(cmd *cobra.Command, args []string) {
rows := []notionapi.Block{}
fmt.Println("Reverse table ID: ", notionTableId)
notion.Init(notionToken)
headerBlock := <-notion.AGetBlockChildren(notionTableId, "", 100)
headerId := headerBlock.Children[0].GetID()
nextCursor := headerId.String()
hasMore := true
for hasMore {
blockChildren := <-notion.AGetBlockChildren(notionTableId, nextCursor, 100)
rows = append(rows, blockChildren.Children...)
hasMore = blockChildren.HasMore
nextCursor = blockChildren.NextCursor
}
rows = rows[1:]
rowsR := reverseTable(rows)
for _, row := range rowsR {
notion.GetClient().Block.Delete(context.Background(), row.GetID())
fmt.Println("Deleted row: ", row.GetID())
time.Sleep(400 * time.Millisecond)
}
after := headerId
for len(rowsR) > 0 {
var rowsToWrite []notionapi.Block
if len(rowsR) > 100 {
rowsToWrite = rowsR[:100]
} else {
rowsToWrite = rowsR
}
res, err := notion.GetClient().Block.AppendChildren(context.Background(), notionapi.BlockID(notionTableId), &notionapi.AppendBlockChildrenRequest{
After: after,
Children: rowsToWrite,
})
after = rowsToWrite[len(rowsToWrite)-1].GetID()
rowsR = rowsR[len(rowsToWrite):]
fmt.Println(res, err)
}
}
func reverseTable[T any](rows []T) []T {
for i, j := 0, len(rows)-1; i < j; i, j = i+1, j-1 {
rows[i], rows[j] = rows[j], rows[i]
}
return rows
}
func init() {
rootCmd.AddCommand(reverseCmd)
// Here you will define your flags and configuration settings.
// Cobra supports Persistent Flags which will work for this command
// and all subcommands, e.g.:
// reverseCmd.PersistentFlags().String("foo", "", "A help for foo")
// Cobra supports local flags which will only run when this command
// is called directly, e.g.:
// reverseCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
reverseCmd.Flags().StringVar(&notionToken, "token", "", "Notion API token")
reverseCmd.Flags().StringVar(&notionTableId, "table-id", "", "Notion table id to reverse")
}
@@ -0,0 +1,51 @@
/*
Copyright © 2024 Tianyu Liu
*/
package cmd
import (
"os"
"github.com/spf13/cobra"
)
// rootCmd represents the base command when called without any subcommands
var rootCmd = &cobra.Command{
Use: "poo_recorder_helper",
Short: "A brief description of your application",
Long: `A longer description that spans multiple lines and likely contains
examples and usage of using your application. For example:
Cobra is a CLI library for Go that empowers applications.
This application is a tool to generate the needed files
to quickly create a Cobra application.`,
// Uncomment the following line if your bare application
// has an action associated with it:
// Run: func(cmd *cobra.Command, args []string) { },
}
// Execute adds all child commands to the root command and sets flags appropriately.
// This is called by main.main(). It only needs to happen once to the rootCmd.
func Execute() {
err := rootCmd.Execute()
if err != nil {
os.Exit(1)
}
}
func init() {
// Here you will define your flags and configuration settings.
// Cobra supports persistent flags, which, if defined here,
// will be global for your application.
// rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.poo_recorder_helper.yaml)")
// Cobra also supports local flags, which will only run
// when this action is called directly.
rootCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
}
+11
View File
@@ -0,0 +1,11 @@
/*
Copyright © 2024 Tianyu Liu
*/
package main
import "github.com/t-liu93/home-automation-backend/helper/poo_recorder_helper/cmd"
func main() {
cmd.Execute()
}
-60
View File
@@ -1,60 +0,0 @@
from contextlib import asynccontextmanager
from pathlib import Path
from fastapi import FastAPI
from src.components.poo_recorder import PooRecorder
from src.config import Config
from src.util.homeassistant import HomeAssistant
from src.util.location_recorder import LocationRecorder
from src.util.mqtt import MQTT
from src.util.notion import NotionAsync
from src.util.ticktick import TickTick
Config.init()
location_recorder_db = str(Path(__file__).resolve().parent / ".." / "location_recorder.db")
ticktick = TickTick()
notion = NotionAsync(token=Config.get_env(key="NOTION_TOKEN"))
mqtt = MQTT()
location_recorder = LocationRecorder(db_path=location_recorder_db)
homeassistant = HomeAssistant(ticktick=ticktick, location_recorder=location_recorder)
poo_recorder = PooRecorder(mqtt=mqtt, notion=notion, homeassistant=homeassistant)
@asynccontextmanager
async def _lifespan(_app: FastAPI): # noqa: ANN202
await mqtt.start()
await location_recorder.create_db_engine()
yield
await mqtt.stop()
await location_recorder.dispose_db_engine()
app = FastAPI(lifespan=_lifespan)
@app.get("/homeassistant/status")
async def get_status() -> dict:
return {"Status": "Ok"}
@app.post("/homeassistant/publish")
async def homeassistant_publish(payload: HomeAssistant.Message) -> dict:
return await homeassistant.process_message(message=payload)
# Poo recorder
@app.post("/poo/record")
async def record(record_detail: PooRecorder.RecordField) -> PooRecorder.RecordField:
await poo_recorder.record(record_detail)
return record_detail
# ticktick
@app.get("/ticktick/auth/code")
async def ticktick_auth(code: str, state: str) -> dict:
if await ticktick.retrieve_access_token(code, state):
return {"State": "Token Retrieved"}
return {"State": "Token Retrieval Failed"}
View File
Binary file not shown.
-70
View File
@@ -1,70 +0,0 @@
from collections import OrderedDict
from pathlib import Path
import pytest
from dotenv import dotenv_values, set_key
from src.config import Config
CONFIG_PATH = Path(__file__).parent.resolve()
TEST_DOT_ENV_PATH = Path(CONFIG_PATH, ".env_test")
EXPECTED_ENV_DICT: OrderedDict[str, str] = OrderedDict(
{
"KEY_1": "VALUE_1",
"KEY_2": "VALUE_2",
"NOTION_TOKEN": "1234454_234324",
},
)
@pytest.fixture
def _prepare_test_dot_env() -> any:
TEST_DOT_ENV_PATH.touch(mode=0o600, exist_ok=True)
for key, value in EXPECTED_ENV_DICT.items():
set_key(TEST_DOT_ENV_PATH, key, value)
yield
TEST_DOT_ENV_PATH.unlink()
@pytest.fixture
def _load_test_dot_env(_prepare_test_dot_env: any) -> None:
Config.init(dotenv_path=TEST_DOT_ENV_PATH)
@pytest.mark.usefixtures("_prepare_test_dot_env")
def test_init_config() -> None:
assert Config.env_dict == {}
Config.init(dotenv_path=TEST_DOT_ENV_PATH)
assert Config.env_dict == EXPECTED_ENV_DICT
dict_from_file = dotenv_values(dotenv_path=TEST_DOT_ENV_PATH)
assert dict_from_file == EXPECTED_ENV_DICT
@pytest.mark.usefixtures("_load_test_dot_env")
def test_get_config() -> None:
assert Config.get_env("NON_EXISTING_KEY") is None
key_1 = "KEY_1"
assert Config.get_env(key_1) == EXPECTED_ENV_DICT[key_1]
@pytest.mark.usefixtures("_load_test_dot_env")
def test_update_config() -> None:
key = "KEY_1"
value = EXPECTED_ENV_DICT[key]
new_value = "NEW_VALUE"
assert Config.get_env(key) == value
Config.update_env(key, new_value)
assert Config.get_env(key) == new_value
dict_from_file = dotenv_values(dotenv_path=TEST_DOT_ENV_PATH)
assert dict_from_file[key] == new_value
@pytest.mark.usefixtures("_load_test_dot_env")
def test_remove_config() -> None:
key = "KEY_1"
assert Config.get_env(key) == EXPECTED_ENV_DICT[key]
Config.remove_env(key)
assert Config.get_env(key) is None
dict_from_file = dotenv_values(dotenv_path=TEST_DOT_ENV_PATH)
assert key not in dict_from_file
-1
View File
@@ -1 +0,0 @@
View File
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
-71
View File
@@ -1,71 +0,0 @@
import ast
from datetime import datetime, timedelta, timezone
import httpx
from pydantic import BaseModel
from src.config import Config
from src.util.location_recorder import LocationData, LocationRecorder
from src.util.ticktick import TickTick
class HomeAssistant:
class Message(BaseModel):
target: str
action: str
content: str
def __init__(self, ticktick: TickTick, location_recorder: LocationRecorder) -> None:
self._ticktick = ticktick
self._location_recorder = location_recorder
async def process_message(self, message: Message) -> dict[str, str]:
if message.target == "ticktick":
return await self._process_ticktick_message(message=message)
if message.target == "location_recorder":
return await self._process_location(message=message)
return {"Status": "Unknown target"}
async def trigger_webhook(self, payload: dict[str, str], webhook_id: str) -> None:
token: str = Config.get_env("HOMEASSISTANT_TOKEN")
webhook_url: str = Config.get_env("HOMEASSISTANT_URL") + "/api/webhook/" + webhook_id
headers: dict[str, str] = {"Authorization": f"Bearer {token}"}
await httpx.AsyncClient().post(webhook_url, json=payload, headers=headers)
async def _process_ticktick_message(self, message: Message) -> dict[str, str]:
if message.action == "create_shopping_list":
return await self._create_shopping_list(content=message.content)
if message.action == "create_action_task":
return await self._create_action_task(content=message.content)
return {"Status": "Unknown action"}
async def _process_location(self, message: Message) -> dict[str, str]:
if message.action == "record":
location: dict[str, str] = ast.literal_eval(message.content)
await self._location_recorder.insert_location_now(
person=location["person"],
location=LocationData(
latitude=float(location["latitude"]),
longitude=float(location["longitude"]),
altitude=float(location["altitude"]),
),
)
return {"Status": "Location recorded"}
return {"Status": "Unknown action"}
async def _create_shopping_list(self, content: str) -> dict[str, str]:
project_id = Config.get_env("TICKTICK_SHOPPING_LIST")
item: dict[str, str] = ast.literal_eval(content)
task = TickTick.Task(projectId=project_id, title=item["item"])
return await self._ticktick.create_task(task=task)
async def _create_action_task(self, content: str) -> dict[str, str]:
detail: dict[str, str] = ast.literal_eval(content)
project_id = Config.get_env("TICKTICK_HOME_TASK_LIST")
due_hour = detail["due_hour"]
due = datetime.now(tz=datetime.now().astimezone().tzinfo) + timedelta(hours=due_hour)
due = (due + timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0)
due = due.astimezone(timezone.utc)
task = TickTick.Task(projectId=project_id, title=detail["action"], dueDate=TickTick.datetime_to_ticktick_format(due))
return await self._ticktick.create_task(task=task)
-120
View File
@@ -1,120 +0,0 @@
from dataclasses import dataclass
from datetime import UTC, datetime
from sqlalchemy import REAL, TEXT, insert, text
from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(DeclarativeBase):
pass
class Location(Base):
__tablename__ = "location"
person: Mapped[str] = mapped_column(type_=TEXT, primary_key=True, nullable=False)
datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True, nullable=False)
latitude: Mapped[float] = mapped_column(type_=REAL, nullable=False)
longitude: Mapped[float] = mapped_column(type_=REAL, nullable=False)
altitude: Mapped[float] = mapped_column(type_=REAL, nullable=True)
@dataclass
class LocationData:
latitude: float
longitude: float
altitude: float
class LocationRecorder:
USER_VERSION = 3
def __init__(self, db_path: str) -> None:
self._db_path = "sqlite+aiosqlite:///" + db_path
async def create_db_engine(self) -> None:
self._engine = create_async_engine(self._db_path)
async with self._engine.begin() as conn:
user_version = await self._get_user_version(conn=conn)
if user_version == 0:
await conn.run_sync(Base.metadata.create_all)
await self._set_user_version(conn=conn, user_version=2)
if user_version != LocationRecorder.USER_VERSION:
await self._migrate(conn=conn)
async def dispose_db_engine(self) -> None:
await self._engine.dispose()
async def insert_location(self, person: str, date_time: datetime, location: LocationData) -> None:
if date_time.tzinfo != UTC:
date_time = date_time.astimezone(UTC)
date_time_str = date_time.strftime("%Y-%m-%dT%H:%M:%S%z")
async with self._engine.begin() as conn:
await conn.execute(
insert(Location)
.prefix_with("OR IGNORE")
.values(
person=person,
datetime=date_time_str,
latitude=location.latitude,
longitude=location.longitude,
altitude=location.altitude,
),
)
async def insert_locations(self, person: str, locations: dict[datetime, LocationData]) -> None:
async with self._engine.begin() as conn:
for k, v in locations.items():
dt = k
if k.tzinfo != UTC:
dt = k.astimezone(UTC)
date_time_str = dt.strftime("%Y-%m-%dT%H:%M:%S%z")
await conn.execute(
insert(Location)
.prefix_with("OR IGNORE")
.values(
person=person,
datetime=date_time_str,
latitude=v.latitude,
longitude=v.longitude,
altitude=v.altitude,
),
)
async def insert_location_now(self, person: str, location: LocationData) -> None:
now_utc = datetime.now(tz=UTC)
await self.insert_location(person, now_utc, location)
async def _get_user_version(self, conn: AsyncConnection) -> int:
return (await conn.execute(text("PRAGMA user_version"))).first()[0]
async def _set_user_version(self, conn: AsyncConnection, user_version: int) -> None:
await conn.execute(text("PRAGMA user_version = " + str(user_version)))
async def _migrate(self, conn: AsyncConnection) -> None:
user_version = (await conn.execute(text("PRAGMA user_version"))).first()[0]
if user_version == 1:
await self._migrate_1_2(conn=conn)
user_version = (await conn.execute(text("PRAGMA user_version"))).first()[0]
if user_version == 2: # noqa: PLR2004
await self._migrate_2_3(conn=conn)
user_version = (await conn.execute(text("PRAGMA user_version"))).first()[0]
async def _migrate_1_2(self, conn: AsyncConnection) -> None:
print("Location Recorder: migrate from db ver 1 to 2.")
await conn.execute(text("DROP TABLE version"))
await conn.execute(text("ALTER TABLE location RENAME people TO person"))
await self._set_user_version(conn=conn, user_version=2)
async def _migrate_2_3(self, conn: AsyncConnection) -> None:
print("Location Recorder: migrate from db ver 2 to 3.")
await conn.execute(text("ALTER TABLE location RENAME TO location_old"))
await conn.run_sync(Base.metadata.create_all)
await conn.execute(
text("""
INSERT INTO location (person, datetime, latitude, longitude, altitude)
SELECT person, datetime, latitude, longitude, altitude FROM location_old;
"""),
)
await conn.execute(text("DROP TABLE location_old"))
await self._set_user_version(conn=conn, user_version=3)
-73
View File
@@ -1,73 +0,0 @@
import queue
from dataclasses import dataclass
from fastapi_mqtt import FastMQTT, MQTTConfig
@dataclass
class MQTTSubscription:
topic: str
callback: callable
subscribed: bool
@dataclass
class MQTTPendingMessage:
topic: str
payload: dict
retain: bool
class MQTT:
_instance = None
def __new__(cls, *args, **kwargs): # noqa: ANN002, ANN003, ANN204
if not cls._instance:
cls._instance = super().__new__(cls, *args, **kwargs)
return cls._instance
def __init__(self) -> None:
self._mqtt_config = MQTTConfig(username="mqtt", password="mqtt", reconnect_retries=-1) # noqa: S106
self._mqtt = FastMQTT(config=self._mqtt_config, client_id="home_automation_backend")
self._mqtt.mqtt_handlers.user_connect_handler = self.on_connect
self._mqtt.mqtt_handlers.user_message_handler = self.on_message
self._connected = False
self._subscribed_topic: dict[str, MQTTSubscription] = {}
self._queued_message: queue.Queue[MQTTPendingMessage] = queue.Queue()
async def start(self) -> None:
print("MQTT Starting...")
await self._mqtt.mqtt_startup()
async def stop(self) -> None:
print("MQTT Stopping...")
await self._mqtt.mqtt_shutdown()
def on_connect(self, client, flags, rc, properties) -> None: # noqa: ANN001, ARG002
print("Connected")
self._connected = True
while not self._queued_message.empty():
msg = self._queued_message.get(block=False)
self.publish(msg.topic, msg.payload, retain=msg.retain)
for topic, subscription in self._subscribed_topic.items():
if subscription.subscribed is False:
self.subscribe(topic, subscription.callback)
async def on_message(self, client, topic: str, payload: bytes, qos: int, properties: any) -> any: # noqa: ANN001, ARG002
print("On message")
if topic in self._subscribed_topic and self._subscribed_topic[topic].callback is not None:
await self._subscribed_topic[topic].callback(payload)
def subscribe(self, topic: str, callback: callable) -> None:
if self._connected:
print("Subscribe to topic: ", topic)
self._mqtt.client.subscribe(topic)
self._subscribed_topic[topic] = MQTTSubscription(topic, callback, subscribed=True)
else:
self._subscribed_topic[topic] = MQTTSubscription(topic, callback, subscribed=False)
def publish(self, topic: str, payload: dict, *, retain: bool) -> None:
if self._connected:
self._mqtt.publish(topic, payload=payload, retain=retain)
else:
self._queued_message.put(MQTTPendingMessage(topic, payload, retain=retain))
-82
View File
@@ -1,82 +0,0 @@
from __future__ import annotations
from dataclasses import asdict, dataclass, field
from notion_client import AsyncClient as Client
@dataclass
class Text:
content: str
@dataclass
class RichText:
type: str
href: str | None = None
@dataclass
class RichTextText(RichText):
type: str = "text"
text: Text = field(default_factory=lambda: Text(content=""))
class NotionAsync:
def __init__(self, token: str) -> None:
self._client = Client(auth=token)
def update_token(self, token: str) -> None:
self._client.aclose()
self._client = Client(auth=token)
async def get_block(self, block_id: str) -> dict:
return await self._client.blocks.retrieve(block_id=block_id)
async def get_block_children(self, block_id: str, start_cursor: str | None = None, page_size: int = 100) -> dict:
return await self._client.blocks.children.list(
block_id=block_id,
start_cursor=start_cursor,
page_size=page_size,
)
async def block_is_table(self, block_id: str) -> bool:
block: dict = await self.get_block(block_id=block_id)
return block["type"] == "table"
async def get_table_width(self, table_id: str) -> int:
table = await self._client.blocks.retrieve(block_id=table_id)
return table["table"]["table_width"]
async def append_table_row_text(self, table_id: str, text_list: list[str], after: str | None = None) -> None:
cells: list[RichText] = []
for content in text_list:
cells.append([asdict(RichTextText(text=Text(content)))]) # noqa: PERF401
await self.append_table_row(table_id=table_id, cells=cells, after=after)
async def append_table_row(self, table_id: str, cells: list[RichText], after: str | None = None) -> None:
if not await self.block_is_table(table_id):
return
table_width = await self.get_table_width(table_id=table_id)
if table_width != len(cells):
return
children = [
{
"object": "block",
"type": "table_row",
"table_row": {
"cells": cells,
},
},
]
if after is None:
await self._client.blocks.children.append(
block_id=table_id,
children=children,
)
else:
await self._client.blocks.children.append(
block_id=table_id,
children=children,
after=after,
)
+62
View File
@@ -0,0 +1,62 @@
package notion
import (
"context"
"github.com/jomei/notionapi"
)
var client *notionapi.Client
var authToken string
func Init(token string) {
authToken = token
client = notionapi.NewClient(notionapi.Token(token))
}
func GetClient() *notionapi.Client {
return client
}
func AGetBlock(blockId string) chan struct {
Block notionapi.Block
Error error
} {
retval := make(chan struct {
Block notionapi.Block
Error error
})
go func() {
block, err := client.Block.Get(context.Background(), notionapi.BlockID(blockId))
retval <- struct {
Block notionapi.Block
Error error
}{block, err}
}()
return retval
}
func AGetBlockChildren(blockId string, startCursor string, pageSize int) chan struct {
Children []notionapi.Block
NextCursor string
HasMore bool
Error error
} {
retval := make(chan struct {
Children []notionapi.Block
NextCursor string
HasMore bool
Error error
})
pagination := notionapi.Pagination{StartCursor: notionapi.Cursor(startCursor), PageSize: pageSize}
go func() {
children, err := client.Block.GetChildren(context.Background(), notionapi.BlockID(blockId), &pagination)
retval <- struct {
Children []notionapi.Block
NextCursor string
HasMore bool
Error error
}{children.Results, children.NextCursor, children.HasMore, err}
}()
return retval
}
+27
View File
@@ -0,0 +1,27 @@
package notion_test
import (
"testing"
"github.com/t-liu93/home-automation-backend/util/notion"
)
func TestGetBlockBasic(t *testing.T) {
notion.Init("")
block := <-notion.AGetBlock("")
actObj := block.Block.GetObject()
expObj := "block"
if string(actObj) != expObj {
t.Errorf("Expected %s, but got %s", expObj, actObj)
}
}
func TestGetBlockChildrenBasic(t *testing.T) {
notion.Init("")
blockChildren := <-notion.AGetBlockChildren("", "", 100)
actLen := len(blockChildren.Children)
expLen := 100
if actLen != expLen {
t.Errorf("Expected %d, but got %d", expLen, actLen)
}
}
View File
Binary file not shown.
-354
View File
@@ -1,354 +0,0 @@
import asyncio
import sqlite3
from datetime import UTC, datetime
from pathlib import Path
from zoneinfo import ZoneInfo
import pytest
from sqlalchemy import INTEGER, REAL, TEXT, create_engine, text
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from src.util.location_recorder import LocationData, LocationRecorder
DB_PATH = Path(__file__).resolve().parent / "test.db"
DB_PATH_STR = str(DB_PATH)
@pytest.fixture
def _reset_event_loop() -> any:
try:
loop = asyncio.get_event_loop()
if loop.is_running():
loop.stop()
loop.close()
except RuntimeError:
pass
asyncio.set_event_loop(asyncio.new_event_loop())
@pytest.fixture
def _teardown() -> any:
yield
DB_PATH.unlink()
@pytest.fixture
def _create_v1_db() -> None:
db = "sqlite:///" + DB_PATH_STR
class Base(DeclarativeBase):
pass
class Version(Base):
__tablename__ = "version"
version_type: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
version: Mapped[int] = mapped_column(type_=INTEGER)
class Location(Base):
__tablename__ = "location"
people: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
latitude: Mapped[float] = mapped_column(type_=REAL)
longitude: Mapped[float] = mapped_column(type_=REAL)
altitude: Mapped[float] = mapped_column(type_=REAL)
engine = create_engine(db)
Base.metadata.create_all(engine)
with engine.begin() as conn:
conn.execute(text("PRAGMA user_version = 1"))
@pytest.fixture
def _create_v2_db() -> None:
db = "sqlite:///" + DB_PATH_STR
class Base(DeclarativeBase):
pass
class Location(Base):
__tablename__ = "location"
person: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
latitude: Mapped[float] = mapped_column(type_=REAL)
longitude: Mapped[float] = mapped_column(type_=REAL)
altitude: Mapped[float] = mapped_column(type_=REAL)
engine = create_engine(db)
Base.metadata.create_all(engine)
with engine.begin() as conn:
conn.execute(text("PRAGMA user_version = 2"))
@pytest.fixture
def _create_latest_db() -> None:
db = "sqlite:///" + DB_PATH_STR
class Base(DeclarativeBase):
pass
class Location(Base):
__tablename__ = "location"
person: Mapped[str] = mapped_column(type_=TEXT, primary_key=True, nullable=False)
datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True, nullable=False)
latitude: Mapped[float] = mapped_column(type_=REAL, nullable=False)
longitude: Mapped[float] = mapped_column(type_=REAL, nullable=False)
altitude: Mapped[float] = mapped_column(type_=REAL, nullable=True)
engine = create_engine(db)
Base.metadata.create_all(engine)
@pytest.mark.usefixtures("_create_v1_db")
@pytest.mark.usefixtures("_teardown")
def test_migration_1_latest() -> None:
nr_tables_ver_1 = 2
table_ver_1_0 = "version"
nr_column_ver_1_version = 2
table_ver_1_1 = "location"
nr_column_ver_1_location = 5
nr_tables_ver_2 = 1
table_ver_2_0 = "location"
sqlite3_db = sqlite3.connect(DB_PATH_STR)
sqlite3_cursor = sqlite3_db.cursor()
sqlite3_cursor.execute("PRAGMA user_version")
assert sqlite3_cursor.fetchone()[0] == 1
sqlite3_cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';")
tables = sqlite3_cursor.fetchall()
assert len(tables) == nr_tables_ver_1
assert tables[0][0] == table_ver_1_0
assert tables[1][0] == table_ver_1_1
sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_1_0})")
table_info_version = sqlite3_cursor.fetchall()
assert len(table_info_version) == nr_column_ver_1_version
assert table_info_version[0] == (0, "version_type", "TEXT", 1, None, 1)
assert table_info_version[1] == (1, "version", "INTEGER", 1, None, 0)
sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_1_1})")
table_info_location = sqlite3_cursor.fetchall()
assert len(table_info_location) == nr_column_ver_1_location
assert table_info_location[0] == (0, "people", "TEXT", 1, None, 1)
assert table_info_location[1] == (1, "datetime", "TEXT", 1, None, 2)
assert table_info_location[2] == (2, "latitude", "REAL", 1, None, 0)
assert table_info_location[3] == (3, "longitude", "REAL", 1, None, 0)
assert table_info_location[4] == (4, "altitude", "REAL", 1, None, 0)
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
asyncio.run(location_recorder.create_db_engine())
sqlite3_cursor = sqlite3_db.cursor()
sqlite3_cursor.execute("PRAGMA user_version")
assert sqlite3_cursor.fetchone()[0] == LocationRecorder.USER_VERSION
sqlite3_cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';")
tables = sqlite3_cursor.fetchall()
assert len(tables) == nr_tables_ver_2
sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_2_0})")
table_info_location = sqlite3_cursor.fetchall()
assert len(table_info_location) == nr_column_ver_1_location
assert table_info_location[0] == (0, "person", "TEXT", 1, None, 1)
assert table_info_location[1] == (1, "datetime", "TEXT", 1, None, 2)
assert table_info_location[2] == (2, "latitude", "REAL", 1, None, 0)
assert table_info_location[3] == (3, "longitude", "REAL", 1, None, 0)
assert table_info_location[4] == (4, "altitude", "REAL", 0, None, 0)
sqlite3_cursor.close()
@pytest.mark.usefixtures("_create_v2_db")
@pytest.mark.usefixtures("_teardown")
def test_migration_2_latest() -> None:
nr_tables_ver_2 = 1
table_ver_2_0 = "location"
nr_column_ver_2_location = 5
nr_tables_ver_3 = 1
table_ver_3_0 = "location"
nr_column_ver_3_location = 5
sqlite3_db = sqlite3.connect(DB_PATH_STR)
sqlite3_cursor = sqlite3_db.cursor()
sqlite3_cursor.execute("PRAGMA user_version")
assert sqlite3_cursor.fetchone()[0] == 2 # noqa: PLR2004
sqlite3_cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';")
tables = sqlite3_cursor.fetchall()
assert len(tables) == nr_tables_ver_2
assert tables[0][0] == table_ver_2_0
sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_2_0})")
table_info_location = sqlite3_cursor.fetchall()
assert len(table_info_location) == nr_column_ver_2_location
assert table_info_location[0] == (0, "person", "TEXT", 1, None, 1)
assert table_info_location[1] == (1, "datetime", "TEXT", 1, None, 2)
assert table_info_location[2] == (2, "latitude", "REAL", 1, None, 0)
assert table_info_location[3] == (3, "longitude", "REAL", 1, None, 0)
assert table_info_location[4] == (4, "altitude", "REAL", 1, None, 0)
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
asyncio.run(location_recorder.create_db_engine())
sqlite3_cursor = sqlite3_db.cursor()
sqlite3_cursor.execute("PRAGMA user_version")
assert sqlite3_cursor.fetchone()[0] == LocationRecorder.USER_VERSION
sqlite3_cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';")
tables = sqlite3_cursor.fetchall()
assert len(tables) == nr_tables_ver_3
sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_3_0})")
table_info_location = sqlite3_cursor.fetchall()
assert len(table_info_location) == nr_column_ver_3_location
assert table_info_location[0] == (0, "person", "TEXT", 1, None, 1)
assert table_info_location[1] == (1, "datetime", "TEXT", 1, None, 2)
assert table_info_location[2] == (2, "latitude", "REAL", 1, None, 0)
assert table_info_location[3] == (3, "longitude", "REAL", 1, None, 0)
assert table_info_location[4] == (4, "altitude", "REAL", 0, None, 0)
sqlite3_cursor.close()
@pytest.mark.usefixtures("_reset_event_loop")
@pytest.mark.usefixtures("_teardown")
def test_create_db() -> None:
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
event_loop = asyncio.get_event_loop()
event_loop.run_until_complete(location_recorder.create_db_engine())
event_loop.run_until_complete(location_recorder.dispose_db_engine())
assert DB_PATH.exists()
sqlite3_db = sqlite3.connect(DB_PATH_STR)
sqlite3_cursor = sqlite3_db.cursor()
sqlite3_cursor.execute("PRAGMA user_version")
assert sqlite3_cursor.fetchone()[0] == LocationRecorder.USER_VERSION
@pytest.mark.usefixtures("_reset_event_loop")
@pytest.mark.usefixtures("_create_latest_db")
@pytest.mark.usefixtures("_teardown")
def test_inser_location_utc() -> None:
latitude = 1.0
longitude = 2.0
altitude = 3.0
person = "test_person"
date_time = datetime.now(tz=UTC)
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
event_loop = asyncio.get_event_loop()
event_loop.run_until_complete(location_recorder.create_db_engine())
location_data = LocationData(latitude=latitude, longitude=longitude, altitude=altitude)
event_loop.run_until_complete(
location_recorder.insert_location(
person=person,
date_time=date_time,
location=location_data,
),
)
event_loop.run_until_complete(location_recorder.dispose_db_engine())
sqlite3_db = sqlite3.connect(DB_PATH_STR)
sqlite3_cursor = sqlite3_db.cursor()
sqlite3_cursor.execute("SELECT * FROM location")
location = sqlite3_cursor.fetchone()
assert location[0] == person
assert location[1] == date_time.strftime("%Y-%m-%dT%H:%M:%S%z")
assert location[2] == latitude
assert location[3] == longitude
assert location[4] == altitude
sqlite3_cursor.close()
@pytest.mark.usefixtures("_reset_event_loop")
@pytest.mark.usefixtures("_create_latest_db")
@pytest.mark.usefixtures("_teardown")
def test_inser_location_other() -> None:
latitude = 1.0
longitude = 2.0
altitude = 3.0
person = "test_person"
tz = ZoneInfo("Asia/Shanghai")
date_time = datetime.now(tz=tz)
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
event_loop = asyncio.get_event_loop()
event_loop.run_until_complete(location_recorder.create_db_engine())
location_data = LocationData(latitude=latitude, longitude=longitude, altitude=altitude)
event_loop.run_until_complete(
location_recorder.insert_location(
person=person,
date_time=date_time,
location=location_data,
),
)
event_loop.run_until_complete(location_recorder.dispose_db_engine())
sqlite3_db = sqlite3.connect(DB_PATH_STR)
sqlite3_cursor = sqlite3_db.cursor()
sqlite3_cursor.execute("SELECT * FROM location")
location = sqlite3_cursor.fetchone()
assert location[0] == person
assert location[1] == date_time.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%S%z")
assert location[2] == latitude
assert location[3] == longitude
assert location[4] == altitude
sqlite3_cursor.close()
@pytest.mark.usefixtures("_reset_event_loop")
@pytest.mark.usefixtures("_create_latest_db")
@pytest.mark.usefixtures("_teardown")
def test_insert_location_now() -> None:
latitude = 1.0
longitude = 2.0
altitude = 3.0
person = "test_person"
date_time = datetime.now(tz=UTC)
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
event_loop = asyncio.get_event_loop()
event_loop.run_until_complete(location_recorder.create_db_engine())
location_data = LocationData(latitude=latitude, longitude=longitude, altitude=altitude)
event_loop.run_until_complete(
location_recorder.insert_location_now(
person=person,
location=location_data,
),
)
event_loop.run_until_complete(location_recorder.dispose_db_engine())
sqlite3_db = sqlite3.connect(DB_PATH_STR)
sqlite3_cursor = sqlite3_db.cursor()
sqlite3_cursor.execute("SELECT * FROM location")
location = sqlite3_cursor.fetchone()
assert location[0] == person
date_time_act = datetime.strptime(location[1], "%Y-%m-%dT%H:%M:%S%z")
assert date_time.date() == date_time_act.date()
assert location[2] == latitude
assert location[3] == longitude
assert location[4] == altitude
sqlite3_cursor.close()
@pytest.mark.usefixtures("_reset_event_loop")
@pytest.mark.usefixtures("_create_latest_db")
@pytest.mark.usefixtures("_teardown")
def test_insert_locations() -> None:
locations: dict[datetime, LocationData] = {}
person = "Tianyu"
time_0 = datetime.now(tz=UTC)
lat_0 = 1.0
lon_0 = 2.0
alt_0 = 3.0
time_1 = datetime(2021, 8, 30, 10, 20, 15, tzinfo=UTC)
lat_1 = 155.0
lon_1 = 33.36
alt_1 = 1058
locations[time_0] = LocationData(lat_0, lon_0, alt_0)
locations[time_1] = LocationData(lat_1, lon_1, alt_1)
location_recorder = LocationRecorder(db_path=DB_PATH_STR)
event_loop = asyncio.get_event_loop()
event_loop.run_until_complete(location_recorder.create_db_engine())
event_loop.run_until_complete(
location_recorder.insert_locations(person=person, locations=locations),
)
sqlite3_db = sqlite3.connect(DB_PATH_STR)
sqlite3_cursor = sqlite3_db.cursor()
sqlite3_cursor.execute("SELECT * FROM location")
locations = sqlite3_cursor.fetchall()
assert len(locations) == 2 # noqa: PLR2004
assert locations[0][0] == person
assert locations[0][1] == time_0.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%S%z")
assert locations[0][2] == lat_0
assert locations[0][3] == lon_0
assert locations[0][4] == alt_0
assert locations[1][0] == person
assert locations[1][1] == time_1.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%S%z")
assert locations[1][2] == lat_1
assert locations[1][3] == lon_1
assert locations[1][4] == alt_1
sqlite3_cursor.close()
-84
View File
@@ -1,84 +0,0 @@
from __future__ import annotations
import urllib.parse
from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from datetime import datetime
import httpx
from src.config import Config
class TickTick:
@dataclass
class Task:
projectId: str # noqa: N815
title: str
dueDate: str | None = None # noqa: N815
content: str | None = None
desc: str | None = None
def __init__(self) -> None:
print("Initializing TickTick...")
if Config.get_env("TICKTICK_ACCESS_TOKEN") is None:
self._begin_auth()
else:
self._access_token = Config.get_env("TICKTICK_ACCESS_TOKEN")
def _begin_auth(self) -> None:
ticktick_code_auth_url = "https://ticktick.com/oauth/authorize?"
ticktick_code_auth_params = {
"client_id": Config.get_env("TICKTICK_CLIENT_ID"),
"scope": "tasks:read tasks:write",
"state": "begin_auth",
"redirect_uri": Config.get_env("TICKTICK_CODE_REDIRECT_URI"),
"response_type": "code",
}
ticktick_auth_url_encoded = urllib.parse.urlencode(ticktick_code_auth_params)
print("Visit: ", ticktick_code_auth_url + ticktick_auth_url_encoded, " to authenticate.")
async def retrieve_access_token(self, code: str, state: str) -> bool:
if state != "begin_auth":
print("Invalid state.")
return False
ticktick_token_url = "https://ticktick.com/oauth/token" # noqa: S105
ticktick_token_auth_params: dict[str, str] = {
"code": code,
"grant_type": "authorization_code",
"scope": "tasks:write tasks:read",
"redirect_uri": Config.get_env("TICKTICK_CODE_REDIRECT_URI"),
}
client_id = Config.get_env("TICKTICK_CLIENT_ID")
client_secret = Config.get_env("TICKTICK_CLIENT_SECRET")
response = await httpx.AsyncClient().post(
ticktick_token_url,
data=ticktick_token_auth_params,
auth=httpx.BasicAuth(username=client_id, password=client_secret),
timeout=10,
)
Config.update_env("TICKTICK_ACCESS_TOKEN", response.json()["access_token"])
return True
async def get_tasks(self, project_id: str) -> list[dict]:
ticktick_get_tasks_url = "https://api.ticktick.com/open/v1/project/" + project_id + "/data"
header: dict[str, str] = {"Authorization": f"Bearer {self._access_token}"}
response = await httpx.AsyncClient().get(ticktick_get_tasks_url, headers=header, timeout=10)
return response.json()["tasks"]
async def has_duplicate_task(self, project_id: str, task_title: str) -> bool:
tasks = await self.get_tasks(project_id=project_id)
return any(task["title"] == task_title for task in tasks)
async def create_task(self, task: TickTick.Task) -> dict[str, str]:
if not await self.has_duplicate_task(project_id=task.projectId, task_title=task.title):
ticktick_task_creation_url = "https://api.ticktick.com/open/v1/task"
header: dict[str, str] = {"Authorization": f"Bearer {self._access_token}"}
await httpx.AsyncClient().post(ticktick_task_creation_url, headers=header, json=asdict(task), timeout=10)
return {"title": task.title}
@staticmethod
def datetime_to_ticktick_format(datetime: datetime) -> str:
return datetime.strftime("%Y-%m-%dT%H:%M:%S") + "+0000"