rename
This commit is contained in:
0
src/util/__init__.py
Normal file
0
src/util/__init__.py
Normal file
56
src/util/homeassistant.py
Normal file
56
src/util/homeassistant.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from src.config import Config
|
||||
from src.util.ticktick import TickTick
|
||||
|
||||
|
||||
class HomeAssistant:
|
||||
class Message(BaseModel):
|
||||
target: str
|
||||
action: str
|
||||
content: str
|
||||
|
||||
def __init__(self, ticktick: TickTick) -> None:
|
||||
self._ticktick = ticktick
|
||||
|
||||
async def process_message(self, message: Message) -> dict[str, str]:
|
||||
if message.target == "ticktick":
|
||||
return await self._process_ticktick_message(message=message)
|
||||
|
||||
return {"Status": "Unknown target"}
|
||||
|
||||
async def trigger_webhook(self, payload: dict[str, str], webhook_id: str) -> None:
|
||||
token: str = Config.get_env("HOMEASSISTANT_TOKEN")
|
||||
webhook_url: str = Config.get_env("HOMEASSISTANT_URL") + "/api/webhook/" + webhook_id
|
||||
headers: dict[str, str] = {"Authorization": f"Bearer {token}"}
|
||||
await httpx.AsyncClient().post(webhook_url, json=payload, headers=headers)
|
||||
|
||||
async def _process_ticktick_message(self, message: Message) -> dict[str, str]:
|
||||
if message.action == "create_shopping_list":
|
||||
return await self._create_shopping_list(content=message.content)
|
||||
if message.action == "create_action_task":
|
||||
return await self._create_action_task(content=message.content)
|
||||
|
||||
return {"Status": "Unknown action"}
|
||||
|
||||
async def _create_shopping_list(self, content: str) -> dict[str, str]:
|
||||
project_id = Config.get_env("TICKTICK_SHOPPING_LIST")
|
||||
item: dict[str, str] = ast.literal_eval(content)
|
||||
task = TickTick.Task(projectId=project_id, title=item["item"])
|
||||
return await self._ticktick.create_task(task=task)
|
||||
|
||||
async def _create_action_task(self, content: str) -> dict[str, str]:
|
||||
detail: dict[str, str] = ast.literal_eval(content)
|
||||
project_id = Config.get_env("TICKTICK_HOME_TASK_LIST")
|
||||
due_hour = detail["due_hour"]
|
||||
due = datetime.now(tz=datetime.now().astimezone().tzinfo) + timedelta(hours=due_hour)
|
||||
due = (due + timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
due = due.astimezone(timezone.utc)
|
||||
task = TickTick.Task(projectId=project_id, title=detail["action"], dueDate=TickTick.datetime_to_ticktick_format(due))
|
||||
return await self._ticktick.create_task(task=task)
|
||||
73
src/util/mqtt.py
Normal file
73
src/util/mqtt.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import queue
|
||||
from dataclasses import dataclass
|
||||
|
||||
from fastapi_mqtt import FastMQTT, MQTTConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class MQTTSubscription:
|
||||
topic: str
|
||||
callback: callable
|
||||
subscribed: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class MQTTPendingMessage:
|
||||
topic: str
|
||||
payload: dict
|
||||
retain: bool
|
||||
|
||||
|
||||
class MQTT:
|
||||
_instance = None
|
||||
|
||||
def __new__(cls, *args, **kwargs): # noqa: ANN002, ANN003, ANN204
|
||||
if not cls._instance:
|
||||
cls._instance = super().__new__(cls, *args, **kwargs)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._mqtt_config = MQTTConfig(username="mqtt", password="mqtt", reconnect_retries=-1) # noqa: S106
|
||||
self._mqtt = FastMQTT(config=self._mqtt_config, client_id="home_automation_backend")
|
||||
self._mqtt.mqtt_handlers.user_connect_handler = self.on_connect
|
||||
self._mqtt.mqtt_handlers.user_message_handler = self.on_message
|
||||
self._connected = False
|
||||
self._subscribed_topic: dict[str, MQTTSubscription] = {}
|
||||
self._queued_message: queue.Queue[MQTTPendingMessage] = queue.Queue()
|
||||
|
||||
async def start(self) -> None:
|
||||
print("MQTT Starting...")
|
||||
await self._mqtt.mqtt_startup()
|
||||
|
||||
async def stop(self) -> None:
|
||||
print("MQTT Stopping...")
|
||||
await self._mqtt.mqtt_shutdown()
|
||||
|
||||
def on_connect(self, client, flags, rc, properties) -> None: # noqa: ANN001, ARG002
|
||||
print("Connected")
|
||||
self._connected = True
|
||||
while not self._queued_message.empty():
|
||||
msg = self._queued_message.get(block=False)
|
||||
self.publish(msg.topic, msg.payload, retain=msg.retain)
|
||||
for topic, subscription in self._subscribed_topic.items():
|
||||
if subscription.subscribed is False:
|
||||
self.subscribe(topic, subscription.callback)
|
||||
|
||||
async def on_message(self, client, topic: str, payload: bytes, qos: int, properties: any) -> any: # noqa: ANN001, ARG002
|
||||
print("On message")
|
||||
if topic in self._subscribed_topic and self._subscribed_topic[topic].callback is not None:
|
||||
await self._subscribed_topic[topic].callback(payload)
|
||||
|
||||
def subscribe(self, topic: str, callback: callable) -> None:
|
||||
if self._connected:
|
||||
print("Subscribe to topic: ", topic)
|
||||
self._mqtt.client.subscribe(topic)
|
||||
self._subscribed_topic[topic] = MQTTSubscription(topic, callback, subscribed=True)
|
||||
else:
|
||||
self._subscribed_topic[topic] = MQTTSubscription(topic, callback, subscribed=False)
|
||||
|
||||
def publish(self, topic: str, payload: dict, *, retain: bool) -> None:
|
||||
if self._connected:
|
||||
self._mqtt.publish(topic, payload=payload, retain=retain)
|
||||
else:
|
||||
self._queued_message.put(MQTTPendingMessage(topic, payload, retain=retain))
|
||||
82
src/util/notion.py
Normal file
82
src/util/notion.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
|
||||
from notion_client import AsyncClient as Client
|
||||
|
||||
|
||||
@dataclass
|
||||
class Text:
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class RichText:
|
||||
type: str
|
||||
href: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RichTextText(RichText):
|
||||
type: str = "text"
|
||||
text: Text = field(default_factory=lambda: Text(content=""))
|
||||
|
||||
|
||||
class NotionAsync:
|
||||
def __init__(self, token: str) -> None:
|
||||
self._client = Client(auth=token)
|
||||
|
||||
def update_token(self, token: str) -> None:
|
||||
self._client.aclose()
|
||||
self._client = Client(auth=token)
|
||||
|
||||
async def get_block(self, block_id: str) -> dict:
|
||||
return await self._client.blocks.retrieve(block_id=block_id)
|
||||
|
||||
async def get_block_children(self, block_id: str, start_cursor: str | None = None, page_size: int = 100) -> dict:
|
||||
return await self._client.blocks.children.list(
|
||||
block_id=block_id,
|
||||
start_cursor=start_cursor,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
async def block_is_table(self, block_id: str) -> bool:
|
||||
block: dict = await self.get_block(block_id=block_id)
|
||||
return block["type"] == "table"
|
||||
|
||||
async def get_table_width(self, table_id: str) -> int:
|
||||
table = await self._client.blocks.retrieve(block_id=table_id)
|
||||
return table["table"]["table_width"]
|
||||
|
||||
async def append_table_row_text(self, table_id: str, text_list: list[str], after: str | None = None) -> None:
|
||||
cells: list[RichText] = []
|
||||
for content in text_list:
|
||||
cells.append([asdict(RichTextText(text=Text(content)))]) # noqa: PERF401
|
||||
await self.append_table_row(table_id=table_id, cells=cells, after=after)
|
||||
|
||||
async def append_table_row(self, table_id: str, cells: list[RichText], after: str | None = None) -> None:
|
||||
if not await self.block_is_table(table_id):
|
||||
return
|
||||
table_width = await self.get_table_width(table_id=table_id)
|
||||
if table_width != len(cells):
|
||||
return
|
||||
children = [
|
||||
{
|
||||
"object": "block",
|
||||
"type": "table_row",
|
||||
"table_row": {
|
||||
"cells": cells,
|
||||
},
|
||||
},
|
||||
]
|
||||
if after is None:
|
||||
await self._client.blocks.children.append(
|
||||
block_id=table_id,
|
||||
children=children,
|
||||
)
|
||||
else:
|
||||
await self._client.blocks.children.append(
|
||||
block_id=table_id,
|
||||
children=children,
|
||||
after=after,
|
||||
)
|
||||
0
src/util/tests/__init__.py
Normal file
0
src/util/tests/__init__.py
Normal file
7
src/util/tests/test_ticktick.py
Normal file
7
src/util/tests/test_ticktick.py
Normal file
@@ -0,0 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from util.ticktick import TickTick
|
||||
|
||||
|
||||
def test_ticktick_begin_auth() -> None:
|
||||
auth_url = TickTick.begin_auth()
|
||||
84
src/util/ticktick.py
Normal file
84
src/util/ticktick.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import urllib.parse
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import datetime
|
||||
|
||||
import httpx
|
||||
|
||||
from src.config import Config
|
||||
|
||||
|
||||
class TickTick:
|
||||
@dataclass
|
||||
class Task:
|
||||
projectId: str # noqa: N815
|
||||
title: str
|
||||
dueDate: str | None = None # noqa: N815
|
||||
content: str | None = None
|
||||
desc: str | None = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
print("Initializing TickTick...")
|
||||
if Config.get_env("TICKTICK_ACCESS_TOKEN") is None:
|
||||
self._begin_auth()
|
||||
else:
|
||||
self._access_token = Config.get_env("TICKTICK_ACCESS_TOKEN")
|
||||
|
||||
def _begin_auth(self) -> None:
|
||||
ticktick_code_auth_url = "https://ticktick.com/oauth/authorize?"
|
||||
ticktick_code_auth_params = {
|
||||
"client_id": Config.get_env("TICKTICK_CLIENT_ID"),
|
||||
"scope": "tasks:read tasks:write",
|
||||
"state": "begin_auth",
|
||||
"redirect_uri": Config.get_env("TICKTICK_CODE_REDIRECT_URI"),
|
||||
"response_type": "code",
|
||||
}
|
||||
ticktick_auth_url_encoded = urllib.parse.urlencode(ticktick_code_auth_params)
|
||||
print("Visit: ", ticktick_code_auth_url + ticktick_auth_url_encoded, " to authenticate.")
|
||||
|
||||
async def retrieve_access_token(self, code: str, state: str) -> bool:
|
||||
if state != "begin_auth":
|
||||
print("Invalid state.")
|
||||
return False
|
||||
ticktick_token_url = "https://ticktick.com/oauth/token" # noqa: S105
|
||||
ticktick_token_auth_params: dict[str, str] = {
|
||||
"code": code,
|
||||
"grant_type": "authorization_code",
|
||||
"scope": "tasks:write tasks:read",
|
||||
"redirect_uri": Config.get_env("TICKTICK_CODE_REDIRECT_URI"),
|
||||
}
|
||||
client_id = Config.get_env("TICKTICK_CLIENT_ID")
|
||||
client_secret = Config.get_env("TICKTICK_CLIENT_SECRET")
|
||||
response = await httpx.AsyncClient().post(
|
||||
ticktick_token_url,
|
||||
data=ticktick_token_auth_params,
|
||||
auth=httpx.BasicAuth(username=client_id, password=client_secret),
|
||||
timeout=10,
|
||||
)
|
||||
Config.update_env("TICKTICK_ACCESS_TOKEN", response.json()["access_token"])
|
||||
return True
|
||||
|
||||
async def get_tasks(self, project_id: str) -> list[dict]:
|
||||
ticktick_get_tasks_url = "https://api.ticktick.com/open/v1/project/" + project_id + "/data"
|
||||
header: dict[str, str] = {"Authorization": f"Bearer {self._access_token}"}
|
||||
response = await httpx.AsyncClient().get(ticktick_get_tasks_url, headers=header, timeout=10)
|
||||
return response.json()["tasks"]
|
||||
|
||||
async def has_duplicate_task(self, project_id: str, task_title: str) -> bool:
|
||||
tasks = await self.get_tasks(project_id=project_id)
|
||||
return any(task["title"] == task_title for task in tasks)
|
||||
|
||||
async def create_task(self, task: TickTick.Task) -> dict[str, str]:
|
||||
if not await self.has_duplicate_task(project_id=task.projectId, task_title=task.title):
|
||||
ticktick_task_creation_url = "https://api.ticktick.com/open/v1/task"
|
||||
header: dict[str, str] = {"Authorization": f"Bearer {self._access_token}"}
|
||||
await httpx.AsyncClient().post(ticktick_task_creation_url, headers=header, json=asdict(task), timeout=10)
|
||||
return {"title": task.title}
|
||||
|
||||
@staticmethod
|
||||
def datetime_to_ticktick_format(datetime: datetime) -> str:
|
||||
return datetime.strftime("%Y-%m-%dT%H:%M:%S") + "+0000"
|
||||
Reference in New Issue
Block a user