Start with go version
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,78 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from src.config import Config
|
||||
from src.util.homeassistant import HomeAssistant
|
||||
from src.util.mqtt import MQTT
|
||||
from src.util.notion import NotionAsync
|
||||
|
||||
|
||||
class PooRecorder:
    """Records dog-poo events: mirrors state to Home Assistant over MQTT
    (discovery-based text entity) and appends each event to a Notion table.
    """

    # Home Assistant MQTT discovery/config and state topics.
    CONFIG_TOPIC = "homeassistant/text/poo_recorder/config"
    AVAILABILITY_TOPIC = "studiotj/poo_recorder/status"
    COMMAND_TOPIC = "studiotj/poo_recorder/update_text"
    STATE_TOPIC = "studiotj/poo_recorder/text"
    JSON_TOPIC = "studiotj/poo_recorder/attributes"
    ONLINE = "online"
    OFFLINE = "offline"

    class RecordField(BaseModel):
        """Request payload for one event: free-form status plus coordinates."""

        status: str
        latitude: str
        longitude: str

    def __init__(self, mqtt: MQTT, notion: NotionAsync, homeassistant: HomeAssistant) -> None:
        """Publish the discovery config and mark the entity online.

        Both publishes are retained so Home Assistant re-discovers the entity
        after a restart.
        """
        print("Poo Recorder Initialization...")
        self._notion = notion
        self._table_id = Config.get_env("POO_RECORD_NOTION_TABLE_ID")
        self._mqtt = mqtt
        self._mqtt.publish(PooRecorder.CONFIG_TOPIC, PooRecorder.compose_config(), retain=True)
        self._mqtt.publish(PooRecorder.AVAILABILITY_TOPIC, PooRecorder.ONLINE, retain=True)
        self._homeassistant = homeassistant

    async def _note(self, now: datetime, status: str, latitude: str, longitude: str) -> None:
        """Append one row (date, time, status, "lat,lon") to the Notion table."""
        formatted_date = now.strftime("%Y-%m-%d")
        formatted_time = now.strftime("%H:%M")
        # BUG FIX: str.strip() returns a new string; the original discarded
        # the result, so untrimmed status text reached Notion.
        status = status.strip()
        await self._notion.append_table_row_text(self._table_id, [formatted_date, formatted_time, status, latitude + "," + longitude])

    async def record(self, record_detail: RecordField) -> None:
        """Record one event: publish MQTT state, log to Notion, fire the webhook."""
        webhook_id: str = Config.get_env("HOMEASSISTANT_POO_TRIGGER_ID")
        self._publish_text(record_detail.status)
        # Timezone-aware "now" in the host's local zone.
        now = datetime.now(tz=datetime.now().astimezone().tzinfo)
        self._publish_time(now)
        await self._note(now, record_detail.status, record_detail.latitude, record_detail.longitude)
        await self._homeassistant.trigger_webhook(payload={"status": record_detail.status}, webhook_id=webhook_id)

    def _publish_text(self, new_text: str) -> None:
        """Publish the current status text (refreshing availability first)."""
        self._mqtt.publish(PooRecorder.AVAILABILITY_TOPIC, PooRecorder.ONLINE, retain=True)
        self._mqtt.publish(PooRecorder.STATE_TOPIC, new_text, retain=True)

    def _publish_time(self, time: datetime) -> None:
        """Publish the last-event timestamp as a JSON attribute."""
        formatted_time = time.strftime("%a | %Y-%m-%d | %H:%M")
        self._mqtt.publish(PooRecorder.AVAILABILITY_TOPIC, PooRecorder.ONLINE, retain=True)
        json_string = {"last_poo": formatted_time}
        self._mqtt.publish(PooRecorder.JSON_TOPIC, json_string, retain=True)

    @staticmethod
    def compose_config() -> dict:
        """Build the Home Assistant MQTT-discovery config for the text entity."""
        return {
            "device": {
                "name": "Dog Poop Recorder",
                "model": "poop-recorder-backend",
                "sw_version": Config.VERSION,
                "identifiers": ["poo_recorder"],
                "manufacturer": "Studio TJ",
            },
            "unique_id": "poo_recorder",
            "name": "Poo Status",
            "availability_topic": PooRecorder.AVAILABILITY_TOPIC,
            # NOTE(review): availability is published as a plain "online"/"offline"
            # string, but this template expects JSON ({"availability": ...}) —
            # confirm against the Home Assistant entity behavior.
            "availability_template": "{{ value_json.availability }}",
            "json_attributes_topic": PooRecorder.JSON_TOPIC,
            "min": 0,
            "max": 255,
            "mode": "text",
            "command_topic": PooRecorder.COMMAND_TOPIC,
            "state_topic": PooRecorder.STATE_TOPIC,
        }
|
||||
@@ -1,40 +0,0 @@
|
||||
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, ClassVar

from dotenv import dotenv_values, set_key, unset_key

if TYPE_CHECKING:
    from collections import OrderedDict

# The .env file lives next to this module; create it up-front (owner
# read/write only) so later dotenv calls never hit a missing file.
config_path = Path(__file__).parent.resolve()
DOT_ENV_PATH = Path(config_path, ".env")
DOT_ENV_PATH.touch(mode=0o600, exist_ok=True)


class Config:
    """Process-wide configuration backed by a .env file.

    All state is class-level: call Config.init() once at startup, then use
    the static getters/mutators everywhere else.
    """

    # In-memory cache of the .env contents; reloaded after every mutation.
    env_dict: ClassVar[OrderedDict[str, str]] = {}
    dot_env_path = DOT_ENV_PATH
    VERSION = "2.0"

    @staticmethod
    def init(dotenv_path: str | Path = DOT_ENV_PATH) -> None:
        """Load (or reload) the environment cache from *dotenv_path*.

        BUG FIX: the annotation said `str` but the default is a `Path`;
        both are accepted by python-dotenv, so document the union.
        """
        Config.dot_env_path = dotenv_path
        Config.env_dict = dotenv_values(dotenv_path=dotenv_path)

    @staticmethod
    def get_env(key: str) -> str | None:
        """Return the value for *key*, or None when it is not set."""
        # dict.get replaces the original membership-test-then-index pair.
        return Config.env_dict.get(key)

    @staticmethod
    def update_env(key: str, value: str) -> None:
        """Persist *key=value* to the .env file and refresh the cache."""
        set_key(Config.dot_env_path, key, value)
        Config.env_dict = dotenv_values(dotenv_path=Config.dot_env_path)

    @staticmethod
    def remove_env(key: str) -> None:
        """Delete *key* from the .env file and refresh the cache."""
        unset_key(Config.dot_env_path, key)
        Config.env_dict = dotenv_values(dotenv_path=Config.dot_env_path)
|
||||
11
src/go.mod
Normal file
11
src/go.mod
Normal file
@@ -0,0 +1,11 @@
|
||||
module github.com/t-liu93/home-automation-backend
|
||||
|
||||
go 1.23.0
|
||||
|
||||
require github.com/jomei/notionapi v1.13.2
|
||||
|
||||
require (
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/spf13/cobra v1.8.1 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
)
|
||||
12
src/go.sum
Normal file
12
src/go.sum
Normal file
@@ -0,0 +1,12 @@
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jomei/notionapi v1.13.2 h1:YpHKNpkoTMlUfWTlVIodOmQDgRKjfwmtSNVa6/6yC9E=
|
||||
github.com/jomei/notionapi v1.13.2/go.mod h1:BqzP6JBddpBnXvMSIxiR5dCoCjKngmz5QNl1ONDlDoM=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
|
||||
github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y=
|
||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
@@ -1,61 +0,0 @@
|
||||
import argparse
import asyncio
import json
import sys
from datetime import datetime
from pathlib import Path

# Make the project root importable when this helper runs directly.
current_file_path = Path(__file__).resolve().parent
sys.path.append(str(current_file_path / ".." / ".." / ".."))
from src.util.location_recorder import LocationData, LocationRecorder  # noqa: E402

# Create an argument parser
parser = argparse.ArgumentParser(description="Google Location Reader")

# Add an argument for the JSON file path
parser.add_argument("--json-file", type=str, help="Path to the JSON file")

# Parse the command-line arguments
args = parser.parse_args()

json_file_path: str = args.json_file

db_path = current_file_path / ".." / ".." / ".." / "temp_data" / "test.db"
location_recorder = LocationRecorder(db_path=str(db_path))

# BUG FIX: Path.open is an instance method; the original called it unbound
# with a plain str. Build a Path object first.
with Path(json_file_path).open() as json_file:
    data = json.load(json_file)


locations: list[dict] = data["locations"]
print(type(locations), len(locations))


async def insert() -> None:
    """Convert each Google Takeout location entry and bulk-insert them."""
    nr_waypoints = 0
    await location_recorder.create_db_engine()
    locations_dict: dict[datetime, LocationData] = {}
    for location in locations:
        nr_waypoints += 1
        try:
            # Takeout stores coordinates as integer degrees * 1e7.
            latitude: float = location["latitudeE7"] / 1e7
            longitude: float = location["longitudeE7"] / 1e7
        except KeyError:
            # Entry without coordinates — nothing to record.
            continue
        # BUG FIX annotation: altitude is optional in the source data and
        # may be None; the original claimed plain `float`.
        altitude: float | None = location.get("altitude", None)
        try:
            date_time = datetime.strptime(location["timestamp"], "%Y-%m-%dT%H:%M:%S.%f%z")
        except ValueError:
            # Some timestamps carry no fractional seconds.
            date_time = datetime.strptime(location["timestamp"], "%Y-%m-%dT%H:%M:%S%z")
        locations_dict[date_time] = LocationData(
            latitude=latitude,
            longitude=longitude,
            altitude=altitude,
        )
    await location_recorder.insert_locations("Tianyu", locations=locations_dict)
    print(nr_waypoints)
    await location_recorder.dispose_db_engine()


asyncio.run(insert())
|
||||
@@ -1,44 +0,0 @@
|
||||
import argparse
import asyncio
import sys
from datetime import UTC
from pathlib import Path

import gpxpy
import gpxpy.gpx

# Make the project root importable when this helper runs directly.
current_file_path = Path(__file__).resolve().parent
sys.path.append(str(current_file_path / ".." / ".." / ".."))
from src.util.location_recorder import LocationData, LocationRecorder  # noqa: E402

parser = argparse.ArgumentParser(description="GPX Location Reader")

parser.add_argument("--gpx-file", type=str, help="Path to the GPX file")

args = parser.parse_args()

gpx_location = args.gpx_file

# BUG FIX: the original opened the file via unbound Path.open on a str and
# never closed it. Parse inside a context manager instead.
with Path(gpx_location).open() as gpx_file:
    gpx = gpxpy.parse(gpx_file)

db_path = current_file_path / ".." / ".." / ".." / "temp_data" / "test.db"
location_recorder = LocationRecorder(db_path=str(db_path))


async def iterate_and_insert() -> None:
    """Walk every track/segment/point in the GPX file and insert each fix."""
    nr_waypoints = 0
    await location_recorder.create_db_engine()
    for track in gpx.tracks:
        for segment in track.segments:
            for point in segment.points:
                nr_waypoints += 1
                print(f"Point at ({point.latitude},{point.longitude}) -> {point.time}")
                # Force the point time to be explicitly UTC-aware.
                point.time = point.time.replace(tzinfo=UTC)
                location_data = LocationData(latitude=point.latitude, longitude=point.longitude, altitude=point.elevation)
                await location_recorder.insert_location(person="Tianyu", date_time=point.time, location=location_data)
    await location_recorder.dispose_db_engine()
    print(nr_waypoints)


asyncio.run(iterate_and_insert())
|
||||
@@ -1,41 +0,0 @@
|
||||
import asyncio
import datetime
from pathlib import Path

from src.config import Config
from src.util.notion import NotionAsync

Config.init()

notion = NotionAsync(token=Config.get_env("NOTION_TOKEN"))

current_file_path = Path(__file__).resolve()

current_dir = str(current_file_path.parent)


async def update_rows() -> None:
    """Re-import the legacy text log into the Notion table, newest first.

    Rows are appended right after the table header so the final order in
    Notion matches the reversed file order.
    """
    header: dict = await notion.get_block_children(block_id=Config.get_env("POO_RECORD_NOTION_TABLE_ID"), page_size=1)
    header_id = header["results"][0]["id"]
    # BUG FIX: build a Path before opening (Path.open was called unbound on a
    # str) and let the context manager close the file.
    with (Path(current_dir) / ".." / ".." / ".." / "temp_data" / "old_poo_record.txt").open() as file:
        content = file.read()
        rows = content.split("\n")
        rows.reverse()

    for row in rows:
        # ROBUSTNESS: split("\n") yields an empty string for a trailing
        # newline; the original would crash in strptime on it.
        if not row.strip():
            continue
        t = row[0:5]  # "HH:MM" — fixed-width legacy format
        date = row[8:19]  # e.g. "Mon, 01 Jan"
        # NOTE: the legacy log carries no year; every row is assumed 2024.
        formatted_date = datetime.datetime.strptime(date, "%a, %d %b").astimezone().replace(year=2024).strftime("%Y-%m-%d")
        status = row[20:]
        print(f"{formatted_date} {t} {status}")
        await notion.append_table_row_text(
            table_id=Config.get_env("POO_RECORD_NOTION_TABLE_ID"),
            text_list=[formatted_date, t, status, "0,0"],
            after=header_id,
        )


asyncio.run(update_rows())
|
||||
87
src/helper/poo_recorder_helper/cmd/reverse.go
Normal file
87
src/helper/poo_recorder_helper/cmd/reverse.go
Normal file
@@ -0,0 +1,87 @@
|
||||
/*
Copyright © 2024 Tianyu Liu
*/
package cmd

import (
	"context"
	"fmt"
	"time"

	"github.com/jomei/notionapi"
	"github.com/spf13/cobra"
	"github.com/t-liu93/home-automation-backend/util/notion"
)

// CLI flag values, bound in init().
var notionToken string
var notionTableId string

// reverseCmd represents the reverse command
var reverseCmd = &cobra.Command{
	Use:   "reverse",
	Short: "Reverse given poo recording table",
	Run:   reverseRun,
}

// reverseRun fetches every row of the given Notion table, deletes them all,
// then re-appends them in reverse order directly after the header row.
func reverseRun(cmd *cobra.Command, args []string) {
	rows := []notionapi.Block{}
	fmt.Println("Reverse table ID: ", notionTableId)
	notion.Init(notionToken)
	// The first child is treated as the table header; pagination for the
	// data rows starts from its ID.
	headerBlock := <-notion.AGetBlockChildren(notionTableId, "", 100)
	headerId := headerBlock.Children[0].GetID()
	nextCursor := headerId.String()
	hasMore := true
	for hasMore {
		blockChildren := <-notion.AGetBlockChildren(notionTableId, nextCursor, 100)
		rows = append(rows, blockChildren.Children...)
		hasMore = blockChildren.HasMore
		nextCursor = blockChildren.NextCursor
	}
	// Drop the first collected block so only data rows are reversed.
	rows = rows[1:]
	rowsR := reverseTable(rows)
	for _, row := range rowsR {
		// NOTE(review): the Delete error is ignored — confirm best-effort
		// deletion is acceptable here.
		notion.GetClient().Block.Delete(context.Background(), row.GetID())
		fmt.Println("Deleted row: ", row.GetID())
		// Pause between calls, presumably to respect Notion API rate limits.
		time.Sleep(400 * time.Millisecond)
	}
	after := headerId
	for len(rowsR) > 0 {
		// Re-append in batches of at most 100 blocks per request.
		var rowsToWrite []notionapi.Block
		if len(rowsR) > 100 {
			rowsToWrite = rowsR[:100]
		} else {
			rowsToWrite = rowsR
		}
		res, err := notion.GetClient().Block.AppendChildren(context.Background(), notionapi.BlockID(notionTableId), &notionapi.AppendBlockChildrenRequest{
			After:    after,
			Children: rowsToWrite,
		})
		// Continue inserting after the last block just written.
		after = rowsToWrite[len(rowsToWrite)-1].GetID()
		rowsR = rowsR[len(rowsToWrite):]
		fmt.Println(res, err)
	}

}

// reverseTable reverses rows in place (and returns the same slice).
func reverseTable[T any](rows []T) []T {
	for i, j := 0, len(rows)-1; i < j; i, j = i+1, j-1 {
		rows[i], rows[j] = rows[j], rows[i]
	}
	return rows
}

func init() {
	rootCmd.AddCommand(reverseCmd)

	// Here you will define your flags and configuration settings.

	// Cobra supports Persistent Flags which will work for this command
	// and all subcommands, e.g.:
	// reverseCmd.PersistentFlags().String("foo", "", "A help for foo")

	// Cobra supports local flags which will only run when this command
	// is called directly, e.g.:
	// reverseCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
	reverseCmd.Flags().StringVar(&notionToken, "token", "", "Notion API token")
	reverseCmd.Flags().StringVar(&notionTableId, "table-id", "", "Notion table id to reverse")
}
|
||||
51
src/helper/poo_recorder_helper/cmd/root.go
Normal file
51
src/helper/poo_recorder_helper/cmd/root.go
Normal file
@@ -0,0 +1,51 @@
|
||||
/*
Copyright © 2024 Tianyu Liu

*/
package cmd

import (
	"os"

	"github.com/spf13/cobra"
)

// rootCmd represents the base command when called without any subcommands
var rootCmd = &cobra.Command{
	Use:   "poo_recorder_helper",
	Short: "A brief description of your application",
	Long: `A longer description that spans multiple lines and likely contains
examples and usage of using your application. For example:

Cobra is a CLI library for Go that empowers applications.
This application is a tool to generate the needed files
to quickly create a Cobra application.`,
	// Uncomment the following line if your bare application
	// has an action associated with it:
	// Run: func(cmd *cobra.Command, args []string) { },
}

// Execute adds all child commands to the root command and sets flags appropriately.
// This is called by main.main(). It only needs to happen once to the rootCmd.
func Execute() {
	err := rootCmd.Execute()
	if err != nil {
		// Cobra already printed the error; exit non-zero for callers/scripts.
		os.Exit(1)
	}
}

func init() {
	// Here you will define your flags and configuration settings.
	// Cobra supports persistent flags, which, if defined here,
	// will be global for your application.

	// rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.poo_recorder_helper.yaml)")

	// Cobra also supports local flags, which will only run
	// when this action is called directly.
	rootCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
}
11
src/helper/poo_recorder_helper/main.go
Normal file
11
src/helper/poo_recorder_helper/main.go
Normal file
@@ -0,0 +1,11 @@
|
||||
/*
Copyright © 2024 Tianyu Liu

*/
package main

import "github.com/t-liu93/home-automation-backend/helper/poo_recorder_helper/cmd"

// main delegates straight to the cobra root command defined in cmd.
func main() {
	cmd.Execute()
}
|
||||
60
src/main.py
60
src/main.py
@@ -1,60 +0,0 @@
|
||||
from contextlib import asynccontextmanager
from pathlib import Path

from fastapi import FastAPI

from src.components.poo_recorder import PooRecorder
from src.config import Config
from src.util.homeassistant import HomeAssistant
from src.util.location_recorder import LocationRecorder
from src.util.mqtt import MQTT
from src.util.notion import NotionAsync
from src.util.ticktick import TickTick

# Load the .env configuration before any component reads it.
Config.init()

# SQLite database file lives one directory above this module.
location_recorder_db = str(Path(__file__).resolve().parent / ".." / "location_recorder.db")

# Shared singletons wired together at import time; construction order matters
# because later components take the earlier ones as constructor arguments.
ticktick = TickTick()
notion = NotionAsync(token=Config.get_env(key="NOTION_TOKEN"))
mqtt = MQTT()
location_recorder = LocationRecorder(db_path=location_recorder_db)
homeassistant = HomeAssistant(ticktick=ticktick, location_recorder=location_recorder)
poo_recorder = PooRecorder(mqtt=mqtt, notion=notion, homeassistant=homeassistant)


@asynccontextmanager
async def _lifespan(_app: FastAPI):  # noqa: ANN202
    """Start MQTT and the DB engine with the app; tear both down on shutdown."""
    await mqtt.start()
    await location_recorder.create_db_engine()
    yield
    await mqtt.stop()
    await location_recorder.dispose_db_engine()


app = FastAPI(lifespan=_lifespan)


@app.get("/homeassistant/status")
async def get_status() -> dict:
    """Liveness probe."""
    return {"Status": "Ok"}


@app.post("/homeassistant/publish")
async def homeassistant_publish(payload: HomeAssistant.Message) -> dict:
    """Forward a Home Assistant message to the matching subsystem."""
    return await homeassistant.process_message(message=payload)


# Poo recorder
@app.post("/poo/record")
async def record(record_detail: PooRecorder.RecordField) -> PooRecorder.RecordField:
    """Record one poo event and echo the payload back."""
    await poo_recorder.record(record_detail)
    return record_detail


# ticktick
@app.get("/ticktick/auth/code")
async def ticktick_auth(code: str, state: str) -> dict:
    """OAuth redirect endpoint: exchange the authorization code for a token."""
    if await ticktick.retrieve_access_token(code, state):
        return {"State": "Token Retrieved"}
    return {"State": "Token Retrieval Failed"}
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,70 +0,0 @@
|
||||
from collections import OrderedDict
from collections.abc import Generator
from pathlib import Path

import pytest
from dotenv import dotenv_values, set_key

from src.config import Config

CONFIG_PATH = Path(__file__).parent.resolve()
TEST_DOT_ENV_PATH = Path(CONFIG_PATH, ".env_test")

# Fixture contents written to the temporary .env file before each test.
EXPECTED_ENV_DICT: OrderedDict[str, str] = OrderedDict(
    {
        "KEY_1": "VALUE_1",
        "KEY_2": "VALUE_2",
        "NOTION_TOKEN": "1234454_234324",
    },
)


@pytest.fixture
def _prepare_test_dot_env() -> Generator[None, None, None]:
    """Create the test .env file, yield to the test, then remove it.

    BUG FIX: the original annotated both fixtures with the builtin `any`
    (a function, not a type); a yielding fixture is a generator.
    """
    TEST_DOT_ENV_PATH.touch(mode=0o600, exist_ok=True)
    for key, value in EXPECTED_ENV_DICT.items():
        set_key(TEST_DOT_ENV_PATH, key, value)
    yield
    TEST_DOT_ENV_PATH.unlink()


@pytest.fixture
def _load_test_dot_env(_prepare_test_dot_env: None) -> None:
    """Initialize Config from the prepared test .env file."""
    Config.init(dotenv_path=TEST_DOT_ENV_PATH)


@pytest.mark.usefixtures("_prepare_test_dot_env")
def test_init_config() -> None:
    # NOTE(review): asserts the class-level cache starts empty — this makes
    # the test order-sensitive; confirm it runs first / in isolation.
    assert Config.env_dict == {}
    Config.init(dotenv_path=TEST_DOT_ENV_PATH)
    assert Config.env_dict == EXPECTED_ENV_DICT
    dict_from_file = dotenv_values(dotenv_path=TEST_DOT_ENV_PATH)
    assert dict_from_file == EXPECTED_ENV_DICT


@pytest.mark.usefixtures("_load_test_dot_env")
def test_get_config() -> None:
    assert Config.get_env("NON_EXISTING_KEY") is None
    key_1 = "KEY_1"
    assert Config.get_env(key_1) == EXPECTED_ENV_DICT[key_1]


@pytest.mark.usefixtures("_load_test_dot_env")
def test_update_config() -> None:
    key = "KEY_1"
    value = EXPECTED_ENV_DICT[key]
    new_value = "NEW_VALUE"
    assert Config.get_env(key) == value
    Config.update_env(key, new_value)
    assert Config.get_env(key) == new_value
    # The change must also be persisted to the file, not just the cache.
    dict_from_file = dotenv_values(dotenv_path=TEST_DOT_ENV_PATH)
    assert dict_from_file[key] == new_value


@pytest.mark.usefixtures("_load_test_dot_env")
def test_remove_config() -> None:
    key = "KEY_1"
    assert Config.get_env(key) == EXPECTED_ENV_DICT[key]
    Config.remove_env(key)
    assert Config.get_env(key) is None
    dict_from_file = dotenv_values(dotenv_path=TEST_DOT_ENV_PATH)
    assert key not in dict_from_file
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,71 +0,0 @@
|
||||
import ast
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from src.config import Config
|
||||
from src.util.location_recorder import LocationData, LocationRecorder
|
||||
from src.util.ticktick import TickTick
|
||||
|
||||
|
||||
class HomeAssistant:
    """Routes messages from Home Assistant to TickTick / the location DB,
    and fires outbound webhooks back into Home Assistant.
    """

    class Message(BaseModel):
        """Inbound message: target subsystem, action, and a payload string."""

        target: str
        action: str
        # Stringified Python literal; parsed with ast.literal_eval below.
        content: str

    def __init__(self, ticktick: TickTick, location_recorder: LocationRecorder) -> None:
        self._ticktick = ticktick
        self._location_recorder = location_recorder

    async def process_message(self, message: Message) -> dict[str, str]:
        """Dispatch *message* by target; returns a status dict."""
        if message.target == "ticktick":
            return await self._process_ticktick_message(message=message)
        if message.target == "location_recorder":
            return await self._process_location(message=message)
        return {"Status": "Unknown target"}

    async def trigger_webhook(self, payload: dict[str, str], webhook_id: str) -> None:
        """POST *payload* to the Home Assistant webhook *webhook_id*."""
        token: str = Config.get_env("HOMEASSISTANT_TOKEN")
        webhook_url: str = Config.get_env("HOMEASSISTANT_URL") + "/api/webhook/" + webhook_id
        headers: dict[str, str] = {"Authorization": f"Bearer {token}"}
        # BUG FIX: the original created an AsyncClient and never closed it,
        # leaking a connection pool on every call.
        async with httpx.AsyncClient() as client:
            await client.post(webhook_url, json=payload, headers=headers)

    async def _process_ticktick_message(self, message: Message) -> dict[str, str]:
        """Handle TickTick actions."""
        if message.action == "create_shopping_list":
            return await self._create_shopping_list(content=message.content)
        if message.action == "create_action_task":
            return await self._create_action_task(content=message.content)

        return {"Status": "Unknown action"}

    async def _process_location(self, message: Message) -> dict[str, str]:
        """Handle location-recorder actions."""
        if message.action == "record":
            # literal_eval only parses Python literals (no code execution);
            # malformed content raises ValueError/SyntaxError upstream.
            location: dict[str, str] = ast.literal_eval(message.content)
            await self._location_recorder.insert_location_now(
                person=location["person"],
                location=LocationData(
                    latitude=float(location["latitude"]),
                    longitude=float(location["longitude"]),
                    altitude=float(location["altitude"]),
                ),
            )
            return {"Status": "Location recorded"}
        return {"Status": "Unknown action"}

    async def _create_shopping_list(self, content: str) -> dict[str, str]:
        """Add one item to the TickTick shopping list."""
        project_id = Config.get_env("TICKTICK_SHOPPING_LIST")
        item: dict[str, str] = ast.literal_eval(content)
        task = TickTick.Task(projectId=project_id, title=item["item"])
        return await self._ticktick.create_task(task=task)

    async def _create_action_task(self, content: str) -> dict[str, str]:
        """Create a TickTick task due at local midnight after now + due_hour."""
        detail: dict[str, str] = ast.literal_eval(content)
        project_id = Config.get_env("TICKTICK_HOME_TASK_LIST")
        due_hour = detail["due_hour"]
        due = datetime.now(tz=datetime.now().astimezone().tzinfo) + timedelta(hours=due_hour)
        # Round up to 00:00 of the following local day, then convert to UTC.
        due = (due + timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0)
        due = due.astimezone(timezone.utc)
        task = TickTick.Task(projectId=project_id, title=detail["action"], dueDate=TickTick.datetime_to_ticktick_format(due))
        return await self._ticktick.create_task(task=task)
||||
@@ -1,120 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import REAL, TEXT, insert, text
|
||||
from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
    """Declarative base shared by the ORM models in this module."""

    pass
||||
|
||||
|
||||
class Location(Base):
    """ORM model: one GPS fix per (person, datetime) pair."""

    __tablename__ = "location"
    # Composite primary key (person, datetime) — duplicate fixes for the same
    # person at the same instant are rejected/ignored by the inserts.
    person: Mapped[str] = mapped_column(type_=TEXT, primary_key=True, nullable=False)
    # Stored as a UTC string formatted "%Y-%m-%dT%H:%M:%S%z" (see LocationRecorder).
    datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True, nullable=False)
    latitude: Mapped[float] = mapped_column(type_=REAL, nullable=False)
    longitude: Mapped[float] = mapped_column(type_=REAL, nullable=False)
    # Altitude is optional in the source data, hence nullable.
    altitude: Mapped[float] = mapped_column(type_=REAL, nullable=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LocationData:
|
||||
latitude: float
|
||||
longitude: float
|
||||
altitude: float
|
||||
|
||||
|
||||
class LocationRecorder:
    """Async SQLite-backed store of per-person GPS fixes.

    Schema versioning uses SQLite's PRAGMA user_version; create_db_engine()
    brings older databases up to USER_VERSION via stepwise migrations.
    """

    # Current schema version.
    USER_VERSION = 3

    def __init__(self, db_path: str) -> None:
        # Async SQLite via the aiosqlite driver; engine is created lazily.
        self._db_path = "sqlite+aiosqlite:///" + db_path

    async def create_db_engine(self) -> None:
        """Create the engine; initialize a fresh DB or migrate an old one."""
        self._engine = create_async_engine(self._db_path)
        async with self._engine.begin() as conn:
            user_version = await self._get_user_version(conn=conn)
            if user_version == 0:
                # Fresh database: create the current tables.
                # NOTE(review): the version is set to 2 here (not USER_VERSION),
                # so the 2->3 migration below also runs on a brand-new DB —
                # confirm this is intended.
                await conn.run_sync(Base.metadata.create_all)
                await self._set_user_version(conn=conn, user_version=2)
            if user_version != LocationRecorder.USER_VERSION:
                await self._migrate(conn=conn)

    async def dispose_db_engine(self) -> None:
        """Release all pooled connections."""
        await self._engine.dispose()

    async def insert_location(self, person: str, date_time: datetime, location: LocationData) -> None:
        """Insert one fix, normalizing the timestamp to UTC.

        Duplicate (person, datetime) rows are silently skipped (INSERT OR IGNORE).
        """
        if date_time.tzinfo != UTC:
            date_time = date_time.astimezone(UTC)
        date_time_str = date_time.strftime("%Y-%m-%dT%H:%M:%S%z")
        async with self._engine.begin() as conn:
            await conn.execute(
                insert(Location)
                .prefix_with("OR IGNORE")
                .values(
                    person=person,
                    datetime=date_time_str,
                    latitude=location.latitude,
                    longitude=location.longitude,
                    altitude=location.altitude,
                ),
            )

    async def insert_locations(self, person: str, locations: dict[datetime, LocationData]) -> None:
        """Bulk-insert fixes keyed by timestamp, all in one transaction."""
        async with self._engine.begin() as conn:
            for k, v in locations.items():
                dt = k
                if k.tzinfo != UTC:
                    dt = k.astimezone(UTC)
                date_time_str = dt.strftime("%Y-%m-%dT%H:%M:%S%z")
                await conn.execute(
                    insert(Location)
                    .prefix_with("OR IGNORE")
                    .values(
                        person=person,
                        datetime=date_time_str,
                        latitude=v.latitude,
                        longitude=v.longitude,
                        altitude=v.altitude,
                    ),
                )

    async def insert_location_now(self, person: str, location: LocationData) -> None:
        """Insert a fix stamped with the current UTC time."""
        now_utc = datetime.now(tz=UTC)
        await self.insert_location(person, now_utc, location)

    async def _get_user_version(self, conn: AsyncConnection) -> int:
        # PRAGMA user_version yields a single-row, single-column result.
        return (await conn.execute(text("PRAGMA user_version"))).first()[0]

    async def _set_user_version(self, conn: AsyncConnection, user_version: int) -> None:
        # PRAGMA statements do not take bound parameters; the value is an
        # int supplied by this module, not user input.
        await conn.execute(text("PRAGMA user_version = " + str(user_version)))

    async def _migrate(self, conn: AsyncConnection) -> None:
        """Run stepwise migrations until the schema reaches USER_VERSION."""
        user_version = (await conn.execute(text("PRAGMA user_version"))).first()[0]
        if user_version == 1:
            await self._migrate_1_2(conn=conn)
            user_version = (await conn.execute(text("PRAGMA user_version"))).first()[0]
        if user_version == 2:  # noqa: PLR2004
            await self._migrate_2_3(conn=conn)
            user_version = (await conn.execute(text("PRAGMA user_version"))).first()[0]

    async def _migrate_1_2(self, conn: AsyncConnection) -> None:
        """v1 -> v2: drop the old version table; rename column people -> person."""
        print("Location Recorder: migrate from db ver 1 to 2.")
        await conn.execute(text("DROP TABLE version"))
        await conn.execute(text("ALTER TABLE location RENAME people TO person"))
        await self._set_user_version(conn=conn, user_version=2)

    async def _migrate_2_3(self, conn: AsyncConnection) -> None:
        """v2 -> v3: rebuild the location table with the current schema, copying rows."""
        print("Location Recorder: migrate from db ver 2 to 3.")
        await conn.execute(text("ALTER TABLE location RENAME TO location_old"))
        await conn.run_sync(Base.metadata.create_all)
        await conn.execute(
            text("""
            INSERT INTO location (person, datetime, latitude, longitude, altitude)
            SELECT person, datetime, latitude, longitude, altitude FROM location_old;
            """),
        )
        await conn.execute(text("DROP TABLE location_old"))
        await self._set_user_version(conn=conn, user_version=3)
|
||||
@@ -1,73 +0,0 @@
|
||||
import queue
|
||||
from dataclasses import dataclass
|
||||
|
||||
from fastapi_mqtt import FastMQTT, MQTTConfig
|
||||
|
||||
|
||||
@dataclass
class MQTTSubscription:
    """Bookkeeping entry for one MQTT topic subscription.

    Holds the callback to run on incoming messages, and whether the
    broker-side subscribe has actually been performed yet.
    """

    # Topic filter as passed to the broker.
    topic: str
    # Invoked with the raw payload for each message on this topic.
    callback: callable
    # False while the subscribe is deferred until the connection is up.
    subscribed: bool
||||
|
||||
|
||||
@dataclass
class MQTTPendingMessage:
    """A publish that was requested before the broker connection existed.

    Queued and replayed once the client connects.
    """

    # Destination topic.
    topic: str
    # Message body to publish.
    payload: dict
    # Whether the broker should retain the message.
    retain: bool
|
||||
|
||||
|
||||
class MQTT:
    """Singleton wrapper around FastMQTT that queues publishes and defers
    subscribes made before the broker connection is established.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):  # noqa: ANN002, ANN003, ANN204
        if not cls._instance:
            # BUG FIX: object.__new__ rejects extra arguments when __new__ is
            # overridden; forwarding *args/**kwargs would raise TypeError if
            # MQTT were ever constructed with arguments.
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        # NOTE(review): __init__ re-runs on every MQTT() call even though the
        # instance is shared — confirm callers only construct it once.
        self._mqtt_config = MQTTConfig(username="mqtt", password="mqtt", reconnect_retries=-1)  # noqa: S106
        self._mqtt = FastMQTT(config=self._mqtt_config, client_id="home_automation_backend")
        self._mqtt.mqtt_handlers.user_connect_handler = self.on_connect
        self._mqtt.mqtt_handlers.user_message_handler = self.on_message
        self._connected = False
        self._subscribed_topic: dict[str, MQTTSubscription] = {}
        self._queued_message: queue.Queue[MQTTPendingMessage] = queue.Queue()

    async def start(self) -> None:
        """Connect to the broker."""
        print("MQTT Starting...")
        await self._mqtt.mqtt_startup()

    async def stop(self) -> None:
        """Disconnect from the broker."""
        print("MQTT Stopping...")
        await self._mqtt.mqtt_shutdown()

    def on_connect(self, client, flags, rc, properties) -> None:  # noqa: ANN001, ARG002
        """Flush queued publishes and perform any deferred subscribes."""
        print("Connected")
        self._connected = True
        while not self._queued_message.empty():
            msg = self._queued_message.get(block=False)
            self.publish(msg.topic, msg.payload, retain=msg.retain)
        for topic, subscription in self._subscribed_topic.items():
            if subscription.subscribed is False:
                self.subscribe(topic, subscription.callback)

    async def on_message(self, client, topic: str, payload: bytes, qos: int, properties: any) -> any:  # noqa: ANN001, ARG002
        """Dispatch an incoming message to the registered topic callback."""
        print("On message")
        if topic in self._subscribed_topic and self._subscribed_topic[topic].callback is not None:
            await self._subscribed_topic[topic].callback(payload)

    def subscribe(self, topic: str, callback: callable) -> None:
        """Subscribe now if connected; otherwise remember it for on_connect."""
        if self._connected:
            print("Subscribe to topic: ", topic)
            self._mqtt.client.subscribe(topic)
            self._subscribed_topic[topic] = MQTTSubscription(topic, callback, subscribed=True)
        else:
            self._subscribed_topic[topic] = MQTTSubscription(topic, callback, subscribed=False)

    def publish(self, topic: str, payload: dict, *, retain: bool) -> None:
        """Publish now if connected; otherwise queue until on_connect."""
        if self._connected:
            self._mqtt.publish(topic, payload=payload, retain=retain)
        else:
            self._queued_message.put(MQTTPendingMessage(topic, payload, retain=retain))
||||
@@ -1,82 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
|
||||
from notion_client import AsyncClient as Client
|
||||
|
||||
|
||||
@dataclass
class Text:
    """Plain-text payload of a Notion rich-text object."""

    content: str
|
||||
|
||||
|
||||
@dataclass
class RichText:
    """Base shape of a Notion rich-text object."""

    # Discriminator field, e.g. "text".
    type: str
    href: str | None = None
|
||||
|
||||
|
||||
@dataclass
class RichTextText(RichText):
    """Rich-text object of type "text"."""

    type: str = "text"
    # Mutable default must come from a factory, not a shared instance.
    text: Text = field(default_factory=lambda: Text(content=""))
|
||||
|
||||
|
||||
class NotionAsync:
    """Thin async wrapper around the Notion API client for block/table access."""

    def __init__(self, token: str) -> None:
        self._client = Client(auth=token)

    def update_token(self, token: str) -> None:
        """Replace the client with one authenticated by *token*.

        ``AsyncClient.aclose`` is a coroutine; calling it bare (as before)
        never actually closed the transport and emitted an "unawaited
        coroutine" warning, so it is scheduled/run explicitly here.
        """
        import asyncio

        old_client = self._client
        try:
            # Inside a running event loop we can only schedule the close.
            asyncio.get_running_loop().create_task(old_client.aclose())
        except RuntimeError:
            # No running loop: close synchronously.
            asyncio.run(old_client.aclose())
        self._client = Client(auth=token)

    async def get_block(self, block_id: str) -> dict:
        """Return the raw block object for *block_id*."""
        return await self._client.blocks.retrieve(block_id=block_id)

    async def get_block_children(self, block_id: str, start_cursor: str | None = None, page_size: int = 100) -> dict:
        """Return one page of children of *block_id*."""
        return await self._client.blocks.children.list(
            block_id=block_id,
            start_cursor=start_cursor,
            page_size=page_size,
        )

    async def block_is_table(self, block_id: str) -> bool:
        """Return True when *block_id* refers to a table block."""
        block: dict = await self.get_block(block_id=block_id)
        return block["type"] == "table"

    async def get_table_width(self, table_id: str) -> int:
        """Return the number of columns of the table *table_id*."""
        table = await self._client.blocks.retrieve(block_id=table_id)
        return table["table"]["table_width"]

    async def append_table_row_text(self, table_id: str, text_list: list[str], after: str | None = None) -> None:
        """Append one row of plain-text cells to the table *table_id*."""
        # Each cell is a list of rich-text payloads; here a single text run.
        cells = [[asdict(RichTextText(text=Text(content)))] for content in text_list]
        await self.append_table_row(table_id=table_id, cells=cells, after=after)

    async def append_table_row(self, table_id: str, cells: list[RichText], after: str | None = None) -> None:
        """Append a table_row block with *cells*.

        Silently does nothing when the target block is not a table or when
        the cell count does not match the table width.
        """
        if not await self.block_is_table(table_id):
            return
        table_width = await self.get_table_width(table_id=table_id)
        if table_width != len(cells):
            return
        children = [
            {
                "object": "block",
                "type": "table_row",
                "table_row": {
                    "cells": cells,
                },
            },
        ]
        # Only pass `after` when it is set, folding the previously duplicated
        # append calls into one.
        kwargs = {} if after is None else {"after": after}
        await self._client.blocks.children.append(
            block_id=table_id,
            children=children,
            **kwargs,
        )
|
||||
62
src/util/notion/notion.go
Normal file
62
src/util/notion/notion.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package notion
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jomei/notionapi"
|
||||
)
|
||||
|
||||
// client is the package-level Notion API client, initialized by Init.
var client *notionapi.Client

// authToken stores the token passed to Init for later reference.
var authToken string
|
||||
|
||||
// Init stores the auth token and creates the shared Notion API client.
func Init(token string) {
	authToken = token
	client = notionapi.NewClient(notionapi.Token(token))
}
|
||||
|
||||
// GetClient returns the shared Notion API client (nil before Init is called).
func GetClient() *notionapi.Client {
	return client
}
|
||||
|
||||
func AGetBlock(blockId string) chan struct {
|
||||
Block notionapi.Block
|
||||
Error error
|
||||
} {
|
||||
retval := make(chan struct {
|
||||
Block notionapi.Block
|
||||
Error error
|
||||
})
|
||||
go func() {
|
||||
block, err := client.Block.Get(context.Background(), notionapi.BlockID(blockId))
|
||||
retval <- struct {
|
||||
Block notionapi.Block
|
||||
Error error
|
||||
}{block, err}
|
||||
}()
|
||||
return retval
|
||||
}
|
||||
|
||||
func AGetBlockChildren(blockId string, startCursor string, pageSize int) chan struct {
|
||||
Children []notionapi.Block
|
||||
NextCursor string
|
||||
HasMore bool
|
||||
Error error
|
||||
} {
|
||||
retval := make(chan struct {
|
||||
Children []notionapi.Block
|
||||
NextCursor string
|
||||
HasMore bool
|
||||
Error error
|
||||
})
|
||||
pagination := notionapi.Pagination{StartCursor: notionapi.Cursor(startCursor), PageSize: pageSize}
|
||||
go func() {
|
||||
children, err := client.Block.GetChildren(context.Background(), notionapi.BlockID(blockId), &pagination)
|
||||
retval <- struct {
|
||||
Children []notionapi.Block
|
||||
NextCursor string
|
||||
HasMore bool
|
||||
Error error
|
||||
}{children.Results, children.NextCursor, children.HasMore, err}
|
||||
}()
|
||||
return retval
|
||||
}
|
||||
27
src/util/notion/notion_test.go
Normal file
27
src/util/notion/notion_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package notion_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/t-liu93/home-automation-backend/util/notion"
|
||||
)
|
||||
|
||||
// TestGetBlockBasic checks that AGetBlock yields a result whose object type
// is "block".
// NOTE(review): Init("") uses an empty token and an empty block id, so this
// performs a live, unauthenticated API call — it needs network access and
// will likely fail outside an environment that injects credentials; confirm
// the intended test setup.
func TestGetBlockBasic(t *testing.T) {
	notion.Init("")
	block := <-notion.AGetBlock("")
	actObj := block.Block.GetObject()
	expObj := "block"
	if string(actObj) != expObj {
		t.Errorf("Expected %s, but got %s", expObj, actObj)
	}
}
|
||||
|
||||
// TestGetBlockChildrenBasic checks that AGetBlockChildren yields exactly 100
// children for the queried block.
// NOTE(review): Init("") uses an empty token and an empty block id, so this
// performs a live API call and asserts a specific child count — it needs
// network access and real data to pass; confirm the intended test setup.
func TestGetBlockChildrenBasic(t *testing.T) {
	notion.Init("")
	blockChildren := <-notion.AGetBlockChildren("", "", 100)
	actLen := len(blockChildren.Children)
	expLen := 100
	if actLen != expLen {
		t.Errorf("Expected %d, but got %d", expLen, actLen)
	}
}
|
||||
Binary file not shown.
Binary file not shown.
@@ -1,354 +0,0 @@
|
||||
import asyncio
|
||||
import sqlite3
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import INTEGER, REAL, TEXT, create_engine, text
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
from src.util.location_recorder import LocationData, LocationRecorder
|
||||
|
||||
# Path to the throwaway SQLite database used by every test in this module;
# it is deleted again by the _teardown fixture.
DB_PATH = Path(__file__).resolve().parent / "test.db"
DB_PATH_STR = str(DB_PATH)
|
||||
|
||||
|
||||
@pytest.fixture
def _reset_event_loop() -> any:
    """Close any existing asyncio event loop and install a fresh one.

    Keeps the run_until_complete-based tests independent of loop state left
    behind by earlier tests.
    NOTE(review): asyncio.get_event_loop()/set_event_loop are deprecated for
    this use in recent Python versions — consider asyncio.Runner/asyncio.run.
    """
    try:
        loop = asyncio.get_event_loop()
        if loop.is_running():
            loop.stop()
        loop.close()
    except RuntimeError:
        # No current event loop in this thread — nothing to close.
        pass
    asyncio.set_event_loop(asyncio.new_event_loop())
|
||||
|
||||
|
||||
@pytest.fixture
def _teardown() -> any:
    """Delete the test database file after the test runs.

    missing_ok guards against a test that failed before creating the file;
    previously unlink() would raise FileNotFoundError here and mask the real
    test failure.
    """
    yield
    DB_PATH.unlink(missing_ok=True)
|
||||
|
||||
|
||||
@pytest.fixture
def _create_v1_db() -> None:
    """Create a schema-version-1 database file at DB_PATH.

    v1 layout: a separate ``version`` bookkeeping table plus a ``location``
    table whose person column is still named ``people``.
    """
    db = "sqlite:///" + DB_PATH_STR

    class Base(DeclarativeBase):
        pass

    class Version(Base):
        __tablename__ = "version"
        version_type: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
        version: Mapped[int] = mapped_column(type_=INTEGER)

    class Location(Base):
        __tablename__ = "location"
        people: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
        datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
        latitude: Mapped[float] = mapped_column(type_=REAL)
        longitude: Mapped[float] = mapped_column(type_=REAL)
        altitude: Mapped[float] = mapped_column(type_=REAL)

    engine = create_engine(db)

    Base.metadata.create_all(engine)
    with engine.begin() as conn:
        # Tag the file as schema version 1 via SQLite's user_version pragma.
        conn.execute(text("PRAGMA user_version = 1"))
|
||||
|
||||
|
||||
@pytest.fixture
def _create_v2_db() -> None:
    """Create a schema-version-2 database file at DB_PATH.

    v2 layout: only the ``location`` table (``version`` table dropped, the
    person column renamed from ``people``), with a NOT NULL ``altitude``.
    """
    db = "sqlite:///" + DB_PATH_STR

    class Base(DeclarativeBase):
        pass

    class Location(Base):
        __tablename__ = "location"
        person: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
        datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True)
        latitude: Mapped[float] = mapped_column(type_=REAL)
        longitude: Mapped[float] = mapped_column(type_=REAL)
        altitude: Mapped[float] = mapped_column(type_=REAL)

    engine = create_engine(db)

    Base.metadata.create_all(engine)
    with engine.begin() as conn:
        # Tag the file as schema version 2 via SQLite's user_version pragma.
        conn.execute(text("PRAGMA user_version = 2"))
|
||||
|
||||
|
||||
@pytest.fixture
def _create_latest_db() -> None:
    """Create a database file at DB_PATH with the latest ``location`` schema.

    Latest layout: ``altitude`` is nullable; no user_version pragma is set
    here (the recorder is expected to manage it).
    """
    db = "sqlite:///" + DB_PATH_STR

    class Base(DeclarativeBase):
        pass

    class Location(Base):
        __tablename__ = "location"
        person: Mapped[str] = mapped_column(type_=TEXT, primary_key=True, nullable=False)
        datetime: Mapped[str] = mapped_column(type_=TEXT, primary_key=True, nullable=False)
        latitude: Mapped[float] = mapped_column(type_=REAL, nullable=False)
        longitude: Mapped[float] = mapped_column(type_=REAL, nullable=False)
        altitude: Mapped[float] = mapped_column(type_=REAL, nullable=True)

    engine = create_engine(db)

    Base.metadata.create_all(engine)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_create_v1_db")
@pytest.mark.usefixtures("_teardown")
def test_migration_1_latest() -> None:
    """Migrate a version-1 schema database to the latest schema.

    Verifies the v1 layout first (separate ``version`` table, ``people``
    column), runs the migration by constructing the recorder's engine, then
    checks the migrated layout (``person`` column, nullable ``altitude``).
    """
    nr_tables_ver_1 = 2
    table_ver_1_0 = "version"
    nr_column_ver_1_version = 2
    table_ver_1_1 = "location"
    nr_column_ver_1_location = 5
    nr_tables_ver_2 = 1
    table_ver_2_0 = "location"

    sqlite3_db = sqlite3.connect(DB_PATH_STR)
    sqlite3_cursor = sqlite3_db.cursor()
    sqlite3_cursor.execute("PRAGMA user_version")
    assert sqlite3_cursor.fetchone()[0] == 1
    sqlite3_cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';")
    tables = sqlite3_cursor.fetchall()
    assert len(tables) == nr_tables_ver_1
    assert tables[0][0] == table_ver_1_0
    assert tables[1][0] == table_ver_1_1
    # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk).
    sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_1_0})")
    table_info_version = sqlite3_cursor.fetchall()
    assert len(table_info_version) == nr_column_ver_1_version
    assert table_info_version[0] == (0, "version_type", "TEXT", 1, None, 1)
    assert table_info_version[1] == (1, "version", "INTEGER", 1, None, 0)
    sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_1_1})")
    table_info_location = sqlite3_cursor.fetchall()
    assert len(table_info_location) == nr_column_ver_1_location
    assert table_info_location[0] == (0, "people", "TEXT", 1, None, 1)
    assert table_info_location[1] == (1, "datetime", "TEXT", 1, None, 2)
    assert table_info_location[2] == (2, "latitude", "REAL", 1, None, 0)
    assert table_info_location[3] == (3, "longitude", "REAL", 1, None, 0)
    assert table_info_location[4] == (4, "altitude", "REAL", 1, None, 0)

    # Creating the engine runs the schema migration as a side effect.
    location_recorder = LocationRecorder(db_path=DB_PATH_STR)
    asyncio.run(location_recorder.create_db_engine())
    sqlite3_cursor = sqlite3_db.cursor()
    sqlite3_cursor.execute("PRAGMA user_version")
    assert sqlite3_cursor.fetchone()[0] == LocationRecorder.USER_VERSION
    sqlite3_cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';")
    tables = sqlite3_cursor.fetchall()
    assert len(tables) == nr_tables_ver_2
    sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_2_0})")
    table_info_location = sqlite3_cursor.fetchall()
    assert len(table_info_location) == nr_column_ver_1_location
    assert table_info_location[0] == (0, "person", "TEXT", 1, None, 1)
    assert table_info_location[1] == (1, "datetime", "TEXT", 1, None, 2)
    assert table_info_location[2] == (2, "latitude", "REAL", 1, None, 0)
    assert table_info_location[3] == (3, "longitude", "REAL", 1, None, 0)
    assert table_info_location[4] == (4, "altitude", "REAL", 0, None, 0)
    sqlite3_cursor.close()
    # Close the connection (previously leaked) so _teardown can unlink the
    # file on all platforms.
    sqlite3_db.close()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_create_v2_db")
@pytest.mark.usefixtures("_teardown")
def test_migration_2_latest() -> None:
    """Migrate a version-2 schema database to the latest schema.

    The only schema change verified is ``altitude`` becoming nullable
    (notnull flag 1 -> 0); table set and column names stay the same.
    """
    nr_tables_ver_2 = 1
    table_ver_2_0 = "location"
    nr_column_ver_2_location = 5
    nr_tables_ver_3 = 1
    table_ver_3_0 = "location"
    nr_column_ver_3_location = 5

    sqlite3_db = sqlite3.connect(DB_PATH_STR)
    sqlite3_cursor = sqlite3_db.cursor()
    sqlite3_cursor.execute("PRAGMA user_version")
    assert sqlite3_cursor.fetchone()[0] == 2  # noqa: PLR2004
    sqlite3_cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';")
    tables = sqlite3_cursor.fetchall()
    assert len(tables) == nr_tables_ver_2
    assert tables[0][0] == table_ver_2_0
    # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk).
    sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_2_0})")
    table_info_location = sqlite3_cursor.fetchall()
    assert len(table_info_location) == nr_column_ver_2_location
    assert table_info_location[0] == (0, "person", "TEXT", 1, None, 1)
    assert table_info_location[1] == (1, "datetime", "TEXT", 1, None, 2)
    assert table_info_location[2] == (2, "latitude", "REAL", 1, None, 0)
    assert table_info_location[3] == (3, "longitude", "REAL", 1, None, 0)
    assert table_info_location[4] == (4, "altitude", "REAL", 1, None, 0)

    # Creating the engine runs the schema migration as a side effect.
    location_recorder = LocationRecorder(db_path=DB_PATH_STR)
    asyncio.run(location_recorder.create_db_engine())
    sqlite3_cursor = sqlite3_db.cursor()
    sqlite3_cursor.execute("PRAGMA user_version")
    assert sqlite3_cursor.fetchone()[0] == LocationRecorder.USER_VERSION
    sqlite3_cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';")
    tables = sqlite3_cursor.fetchall()
    assert len(tables) == nr_tables_ver_3
    sqlite3_cursor.execute(f"PRAGMA table_info({table_ver_3_0})")
    table_info_location = sqlite3_cursor.fetchall()
    assert len(table_info_location) == nr_column_ver_3_location
    assert table_info_location[0] == (0, "person", "TEXT", 1, None, 1)
    assert table_info_location[1] == (1, "datetime", "TEXT", 1, None, 2)
    assert table_info_location[2] == (2, "latitude", "REAL", 1, None, 0)
    assert table_info_location[3] == (3, "longitude", "REAL", 1, None, 0)
    assert table_info_location[4] == (4, "altitude", "REAL", 0, None, 0)
    sqlite3_cursor.close()
    # Close the connection (previously leaked) so _teardown can unlink the
    # file on all platforms.
    sqlite3_db.close()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_reset_event_loop")
@pytest.mark.usefixtures("_teardown")
def test_create_db() -> None:
    """Creating the engine creates the DB file at the recorder's version."""
    location_recorder = LocationRecorder(db_path=DB_PATH_STR)
    event_loop = asyncio.get_event_loop()
    event_loop.run_until_complete(location_recorder.create_db_engine())
    event_loop.run_until_complete(location_recorder.dispose_db_engine())
    assert DB_PATH.exists()
    sqlite3_db = sqlite3.connect(DB_PATH_STR)
    sqlite3_cursor = sqlite3_db.cursor()
    sqlite3_cursor.execute("PRAGMA user_version")
    assert sqlite3_cursor.fetchone()[0] == LocationRecorder.USER_VERSION
    # Close cursor and connection (previously leaked) so _teardown can
    # unlink the file on all platforms.
    sqlite3_cursor.close()
    sqlite3_db.close()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_reset_event_loop")
@pytest.mark.usefixtures("_create_latest_db")
@pytest.mark.usefixtures("_teardown")
def test_inser_location_utc() -> None:
    """Insert one UTC-stamped location and verify the stored row.

    NOTE(review): "inser" in the name looks like a typo for "insert"; kept
    to avoid churn in test selection filters.
    """
    latitude = 1.0
    longitude = 2.0
    altitude = 3.0
    person = "test_person"
    date_time = datetime.now(tz=UTC)
    location_recorder = LocationRecorder(db_path=DB_PATH_STR)
    event_loop = asyncio.get_event_loop()
    event_loop.run_until_complete(location_recorder.create_db_engine())
    location_data = LocationData(latitude=latitude, longitude=longitude, altitude=altitude)
    event_loop.run_until_complete(
        location_recorder.insert_location(
            person=person,
            date_time=date_time,
            location=location_data,
        ),
    )
    event_loop.run_until_complete(location_recorder.dispose_db_engine())
    sqlite3_db = sqlite3.connect(DB_PATH_STR)
    sqlite3_cursor = sqlite3_db.cursor()
    sqlite3_cursor.execute("SELECT * FROM location")
    location = sqlite3_cursor.fetchone()
    assert location[0] == person
    # Timestamps are stored as ISO-8601 strings with a UTC offset suffix.
    assert location[1] == date_time.strftime("%Y-%m-%dT%H:%M:%S%z")
    assert location[2] == latitude
    assert location[3] == longitude
    assert location[4] == altitude
    sqlite3_cursor.close()
    # Close the connection (previously leaked) so _teardown can unlink the
    # file on all platforms.
    sqlite3_db.close()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_reset_event_loop")
@pytest.mark.usefixtures("_create_latest_db")
@pytest.mark.usefixtures("_teardown")
def test_inser_location_other() -> None:
    """Insert a non-UTC location and verify it is stored normalized to UTC.

    NOTE(review): "inser" in the name looks like a typo for "insert"; kept
    to avoid churn in test selection filters.
    """
    latitude = 1.0
    longitude = 2.0
    altitude = 3.0
    person = "test_person"
    tz = ZoneInfo("Asia/Shanghai")
    date_time = datetime.now(tz=tz)
    location_recorder = LocationRecorder(db_path=DB_PATH_STR)
    event_loop = asyncio.get_event_loop()
    event_loop.run_until_complete(location_recorder.create_db_engine())
    location_data = LocationData(latitude=latitude, longitude=longitude, altitude=altitude)
    event_loop.run_until_complete(
        location_recorder.insert_location(
            person=person,
            date_time=date_time,
            location=location_data,
        ),
    )
    event_loop.run_until_complete(location_recorder.dispose_db_engine())
    sqlite3_db = sqlite3.connect(DB_PATH_STR)
    sqlite3_cursor = sqlite3_db.cursor()
    sqlite3_cursor.execute("SELECT * FROM location")
    location = sqlite3_cursor.fetchone()
    assert location[0] == person
    # The recorder is expected to store the timestamp converted to UTC.
    assert location[1] == date_time.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%S%z")
    assert location[2] == latitude
    assert location[3] == longitude
    assert location[4] == altitude
    sqlite3_cursor.close()
    # Close the connection (previously leaked) so _teardown can unlink the
    # file on all platforms.
    sqlite3_db.close()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_reset_event_loop")
@pytest.mark.usefixtures("_create_latest_db")
@pytest.mark.usefixtures("_teardown")
def test_insert_location_now() -> None:
    """insert_location_now stamps the row itself; only the date is compared
    since the exact insertion time is not controlled by the test."""
    latitude = 1.0
    longitude = 2.0
    altitude = 3.0
    person = "test_person"
    date_time = datetime.now(tz=UTC)
    location_recorder = LocationRecorder(db_path=DB_PATH_STR)
    event_loop = asyncio.get_event_loop()
    event_loop.run_until_complete(location_recorder.create_db_engine())
    location_data = LocationData(latitude=latitude, longitude=longitude, altitude=altitude)
    event_loop.run_until_complete(
        location_recorder.insert_location_now(
            person=person,
            location=location_data,
        ),
    )
    event_loop.run_until_complete(location_recorder.dispose_db_engine())
    sqlite3_db = sqlite3.connect(DB_PATH_STR)
    sqlite3_cursor = sqlite3_db.cursor()
    sqlite3_cursor.execute("SELECT * FROM location")
    location = sqlite3_cursor.fetchone()
    assert location[0] == person
    date_time_act = datetime.strptime(location[1], "%Y-%m-%dT%H:%M:%S%z")
    # NOTE: can flake if the test straddles midnight UTC.
    assert date_time.date() == date_time_act.date()
    assert location[2] == latitude
    assert location[3] == longitude
    assert location[4] == altitude
    sqlite3_cursor.close()
    # Close the connection (previously leaked) so _teardown can unlink the
    # file on all platforms.
    sqlite3_db.close()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_reset_event_loop")
@pytest.mark.usefixtures("_create_latest_db")
@pytest.mark.usefixtures("_teardown")
def test_insert_locations() -> None:
    """Insert two locations in one batch call and verify both stored rows."""
    locations: dict[datetime, LocationData] = {}
    person = "Tianyu"
    time_0 = datetime.now(tz=UTC)
    lat_0 = 1.0
    lon_0 = 2.0
    alt_0 = 3.0
    time_1 = datetime(2021, 8, 30, 10, 20, 15, tzinfo=UTC)
    lat_1 = 155.0
    lon_1 = 33.36
    alt_1 = 1058
    locations[time_0] = LocationData(lat_0, lon_0, alt_0)
    locations[time_1] = LocationData(lat_1, lon_1, alt_1)
    location_recorder = LocationRecorder(db_path=DB_PATH_STR)
    event_loop = asyncio.get_event_loop()
    event_loop.run_until_complete(location_recorder.create_db_engine())
    event_loop.run_until_complete(
        location_recorder.insert_locations(person=person, locations=locations),
    )
    # Dispose the engine before reading, consistent with the other tests;
    # previously it was left open here.
    event_loop.run_until_complete(location_recorder.dispose_db_engine())
    sqlite3_db = sqlite3.connect(DB_PATH_STR)
    sqlite3_cursor = sqlite3_db.cursor()
    sqlite3_cursor.execute("SELECT * FROM location")
    # Use a separate name; previously this clobbered the `locations` dict.
    rows = sqlite3_cursor.fetchall()
    assert len(rows) == 2  # noqa: PLR2004
    assert rows[0][0] == person
    assert rows[0][1] == time_0.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%S%z")
    assert rows[0][2] == lat_0
    assert rows[0][3] == lon_0
    assert rows[0][4] == alt_0
    assert rows[1][0] == person
    assert rows[1][1] == time_1.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%S%z")
    assert rows[1][2] == lat_1
    assert rows[1][3] == lon_1
    assert rows[1][4] == alt_1
    sqlite3_cursor.close()
    # Close the connection (previously leaked) so _teardown can unlink the
    # file on all platforms.
    sqlite3_db.close()
|
||||
@@ -1,84 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import urllib.parse
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import datetime
|
||||
|
||||
import httpx
|
||||
|
||||
from src.config import Config
|
||||
|
||||
|
||||
class TickTick:
|
||||
@dataclass
|
||||
class Task:
|
||||
projectId: str # noqa: N815
|
||||
title: str
|
||||
dueDate: str | None = None # noqa: N815
|
||||
content: str | None = None
|
||||
desc: str | None = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
print("Initializing TickTick...")
|
||||
if Config.get_env("TICKTICK_ACCESS_TOKEN") is None:
|
||||
self._begin_auth()
|
||||
else:
|
||||
self._access_token = Config.get_env("TICKTICK_ACCESS_TOKEN")
|
||||
|
||||
def _begin_auth(self) -> None:
|
||||
ticktick_code_auth_url = "https://ticktick.com/oauth/authorize?"
|
||||
ticktick_code_auth_params = {
|
||||
"client_id": Config.get_env("TICKTICK_CLIENT_ID"),
|
||||
"scope": "tasks:read tasks:write",
|
||||
"state": "begin_auth",
|
||||
"redirect_uri": Config.get_env("TICKTICK_CODE_REDIRECT_URI"),
|
||||
"response_type": "code",
|
||||
}
|
||||
ticktick_auth_url_encoded = urllib.parse.urlencode(ticktick_code_auth_params)
|
||||
print("Visit: ", ticktick_code_auth_url + ticktick_auth_url_encoded, " to authenticate.")
|
||||
|
||||
async def retrieve_access_token(self, code: str, state: str) -> bool:
|
||||
if state != "begin_auth":
|
||||
print("Invalid state.")
|
||||
return False
|
||||
ticktick_token_url = "https://ticktick.com/oauth/token" # noqa: S105
|
||||
ticktick_token_auth_params: dict[str, str] = {
|
||||
"code": code,
|
||||
"grant_type": "authorization_code",
|
||||
"scope": "tasks:write tasks:read",
|
||||
"redirect_uri": Config.get_env("TICKTICK_CODE_REDIRECT_URI"),
|
||||
}
|
||||
client_id = Config.get_env("TICKTICK_CLIENT_ID")
|
||||
client_secret = Config.get_env("TICKTICK_CLIENT_SECRET")
|
||||
response = await httpx.AsyncClient().post(
|
||||
ticktick_token_url,
|
||||
data=ticktick_token_auth_params,
|
||||
auth=httpx.BasicAuth(username=client_id, password=client_secret),
|
||||
timeout=10,
|
||||
)
|
||||
Config.update_env("TICKTICK_ACCESS_TOKEN", response.json()["access_token"])
|
||||
return True
|
||||
|
||||
async def get_tasks(self, project_id: str) -> list[dict]:
|
||||
ticktick_get_tasks_url = "https://api.ticktick.com/open/v1/project/" + project_id + "/data"
|
||||
header: dict[str, str] = {"Authorization": f"Bearer {self._access_token}"}
|
||||
response = await httpx.AsyncClient().get(ticktick_get_tasks_url, headers=header, timeout=10)
|
||||
return response.json()["tasks"]
|
||||
|
||||
async def has_duplicate_task(self, project_id: str, task_title: str) -> bool:
|
||||
tasks = await self.get_tasks(project_id=project_id)
|
||||
return any(task["title"] == task_title for task in tasks)
|
||||
|
||||
async def create_task(self, task: TickTick.Task) -> dict[str, str]:
|
||||
if not await self.has_duplicate_task(project_id=task.projectId, task_title=task.title):
|
||||
ticktick_task_creation_url = "https://api.ticktick.com/open/v1/task"
|
||||
header: dict[str, str] = {"Authorization": f"Bearer {self._access_token}"}
|
||||
await httpx.AsyncClient().post(ticktick_task_creation_url, headers=header, json=asdict(task), timeout=10)
|
||||
return {"title": task.title}
|
||||
|
||||
@staticmethod
|
||||
def datetime_to_ticktick_format(datetime: datetime) -> str:
|
||||
return datetime.strftime("%Y-%m-%dT%H:%M:%S") + "+0000"
|
||||
Reference in New Issue
Block a user