From e1e6e0f7d779133bc6d01acf64d967e318cabe98 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Wed, 7 Aug 2024 16:58:06 +0200 Subject: [PATCH] Drag out mqtt and make a new class --- src/cloud_util/mqtt.py | 73 ++++++++++++++++++++++++++++++++++++++++++ src/main.py | 11 ++++--- src/recorder/poo.py | 48 +++++++++------------------ 3 files changed, 94 insertions(+), 38 deletions(-) create mode 100644 src/cloud_util/mqtt.py diff --git a/src/cloud_util/mqtt.py b/src/cloud_util/mqtt.py new file mode 100644 index 0000000..f651608 --- /dev/null +++ b/src/cloud_util/mqtt.py @@ -0,0 +1,73 @@ +import queue +from dataclasses import dataclass + +from fastapi_mqtt import FastMQTT, MQTTConfig + + +@dataclass +class MQTTSubscription: + topic: str + callback: callable + subscribed: bool + + +@dataclass +class MQTTPendingMessage: + topic: str + payload: dict + retain: bool + + +class MQTT: + _instance = None + + def __new__(cls, *args, **kwargs): # noqa: ANN002, ANN003, ANN204 + if not cls._instance: + cls._instance = super().__new__(cls, *args, **kwargs) + return cls._instance + + def __init__(self) -> None: + self._mqtt_config = MQTTConfig(username="mqtt", password="mqtt", reconnect_retries=-1) # noqa: S106 + self._mqtt = FastMQTT(config=self._mqtt_config, client_id="home_automation_backend") + self._mqtt.mqtt_handlers.user_connect_handler = self.on_connect + self._mqtt.mqtt_handlers.user_message_handler = self.on_message + self._connected = False + self._subscribed_topic: dict[str, MQTTSubscription] = {} + self._queued_message: queue.Queue[MQTTPendingMessage] = queue.Queue() + + async def start(self) -> None: + print("MQTT Starting...") + await self._mqtt.mqtt_startup() + + async def stop(self) -> None: + print("MQTT Stopping...") + await self._mqtt.mqtt_shutdown() + + def on_connect(self, client, flags, rc, properties) -> None: # noqa: ANN001, ARG002 + print("Connected") + self._connected = True + while not self._queued_message.empty(): + msg = self._queued_message.get(block=False) + self.publish(msg.topic, msg.payload, retain=msg.retain) + for topic, subscription in self._subscribed_topic.items(): + if subscription.subscribed is False: + self.subscribe(topic, subscription.callback) + + async def on_message(self, client, topic: str, payload: bytes, qos: int, properties: any) -> any: # noqa: ANN001, ARG002 + print("On message") + if topic in self._subscribed_topic and self._subscribed_topic[topic].callback is not None: + await self._subscribed_topic[topic].callback(payload) + + def subscribe(self, topic: str, callback: callable) -> None: + if self._connected: + print("Subscribe to topic: ", topic) + self._mqtt.client.subscribe(topic) + self._subscribed_topic[topic] = MQTTSubscription(topic, callback, subscribed=True) + else: + self._subscribed_topic[topic] = MQTTSubscription(topic, callback, subscribed=False) + + def publish(self, topic: str, payload: dict, *, retain: bool) -> None: + if self._connected: + self._mqtt.publish(topic, payload=payload, retain=retain) + else: + self._queued_message.put(MQTTPendingMessage(topic, payload, retain=retain)) diff --git a/src/main.py b/src/main.py index 06ef850..d6f1b0f 100644 --- a/src/main.py +++ b/src/main.py @@ -3,16 +3,18 @@ from contextlib import asynccontextmanager from fastapi import FastAPI from pydantic import BaseModel +from src.cloud_util.mqtt import MQTT from src.recorder.poo import PooRecorder -recorder = PooRecorder() +mqtt = MQTT() +poo_recorder = PooRecorder(mqtt) @asynccontextmanager async def _lifespan(_app: FastAPI): # noqa: ANN202 - await recorder.start() + await mqtt.start() yield - await recorder.stop() + await mqtt.stop() class PooRecordField(BaseModel): @@ -24,6 +26,5 @@ app = FastAPI(lifespan=_lifespan) @app.post("/poo/record") async def record(record_detail: PooRecordField) -> PooRecordField: - print(record_detail.status) - await recorder.record(record_detail.status) + await poo_recorder.record(record_detail.status) return record_detail diff --git a/src/recorder/poo.py b/src/recorder/poo.py index 7fc91a6..400c124 100644 --- a/src/recorder/poo.py +++ b/src/recorder/poo.py @@ -1,14 +1,11 @@ from datetime import datetime -from fastapi_mqtt import FastMQTT, MQTTConfig - +from src.cloud_util.mqtt import MQTT from src.recorder.notion_handle import NotionClient class PooRecorder: - mqtt_config = MQTTConfig(username="mqtt", password="mqtt", reconnect_retries=-1) # noqa: S106 notion = NotionClient() - mqtt = FastMQTT(config=mqtt_config, client_id="poo_recorder") CONFIG_TOPIC = "homeassistant/text/poo_recorder/config" AVAILABILITY_TOPIC = "studiotj/poo_recorder/status" COMMAND_TOPIC = "studiotj/poo_recorder/update_text" @@ -17,42 +14,27 @@ class PooRecorder: ONLINE = "online" OFFLINE = "offline" - def __init__(self) -> None: - print("Initialization.") - - async def start(self) -> None: - print("Starting...") - await PooRecorder.mqtt.mqtt_startup() - - async def stop(self) -> None: - print("Stopping...") - await PooRecorder.mqtt.mqtt_shutdown() + def __init__(self, mqtt: MQTT) -> None: + print("Poo Recorder Initialization...") + self._mqtt = mqtt + self._mqtt.publish(PooRecorder.CONFIG_TOPIC, PooRecorder.compose_config(), retain=True) + self._mqtt.publish(PooRecorder.AVAILABILITY_TOPIC, PooRecorder.ONLINE, retain=True) async def record(self, status: str) -> None: - PooRecorder.publish_text(status) + self._publish_text(status) now = datetime.now(tz=datetime.now().astimezone().tzinfo) - PooRecorder.publish_time(now) + self._publish_time(now) await PooRecorder.notion.note(now, status) - @staticmethod - @mqtt.on_connect() - def on_connect(client, flags, rc, properties) -> None: # noqa: ANN001, ARG004 - print("Connected") - config = PooRecorder.compose_config() - PooRecorder.mqtt.publish(PooRecorder.CONFIG_TOPIC, config, retain=True) - PooRecorder.mqtt.publish(PooRecorder.AVAILABILITY_TOPIC, PooRecorder.ONLINE, retain=True) + def _publish_text(self, new_text: str) -> None: + self._mqtt.publish(PooRecorder.AVAILABILITY_TOPIC, PooRecorder.ONLINE, retain=True) + self._mqtt.publish(PooRecorder.STATE_TOPIC, new_text, retain=True) - @staticmethod - def publish_text(new_text: str) -> None: - PooRecorder.mqtt.publish(PooRecorder.AVAILABILITY_TOPIC, PooRecorder.ONLINE, retain=True) - PooRecorder.mqtt.publish(PooRecorder.STATE_TOPIC, new_text, retain=True) - - @staticmethod - def publish_time(time: datetime) -> None: + def _publish_time(self, time: datetime) -> None: formatted_time = time.strftime("%a | %Y-%m-%d | %H:%M") - PooRecorder.mqtt.publish(PooRecorder.AVAILABILITY_TOPIC, PooRecorder.ONLINE, retain=True) + self._mqtt.publish(PooRecorder.AVAILABILITY_TOPIC, PooRecorder.ONLINE, retain=True) json_string = {"last_poo": formatted_time} - PooRecorder.mqtt.publish(PooRecorder.JSON_TOPIC, json_string, retain=True) + self._mqtt.publish(PooRecorder.JSON_TOPIC, json_string, retain=True) @staticmethod def compose_config() -> dict: @@ -60,7 +42,7 @@ class PooRecorder: "device": { "name": "Dog Poop Recorder", "model": "poop-recorder-backend", - "sw_version": "1.2", + "sw_version": "1.3", "identifiers": ["poo_recorder"], "manufacturer": "Studio TJ", },