Drag out mqtt and make a new class

This commit is contained in:
2024-08-07 16:58:06 +02:00
parent bc365ffe7a
commit e1e6e0f7d7
3 changed files with 94 additions and 38 deletions

73
src/cloud_util/mqtt.py Normal file
View File

@@ -0,0 +1,73 @@
import queue
from dataclasses import dataclass
from fastapi_mqtt import FastMQTT, MQTTConfig
@dataclass
class MQTTSubscription:
topic: str
callback: callable
subscribed: bool
@dataclass
class MQTTPendingMessage:
topic: str
payload: dict
retain: bool
class MQTT:
_instance = None
def __new__(cls, *args, **kwargs): # noqa: ANN002, ANN003, ANN204
if not cls._instance:
cls._instance = super().__new__(cls, *args, **kwargs)
return cls._instance
def __init__(self) -> None:
self._mqtt_config = MQTTConfig(username="mqtt", password="mqtt", reconnect_retries=-1) # noqa: S106
self._mqtt = FastMQTT(config=self._mqtt_config, client_id="home_automation_backend")
self._mqtt.mqtt_handlers.user_connect_handler = self.on_connect
self._mqtt.mqtt_handlers.user_message_handler = self.on_message
self._connected = False
self._subscribed_topic: dict[str, MQTTSubscription] = {}
self._queued_message: queue.Queue[MQTTPendingMessage] = queue.Queue()
async def start(self) -> None:
print("MQTT Starting...")
await self._mqtt.mqtt_startup()
async def stop(self) -> None:
print("MQTT Stopping...")
await self._mqtt.mqtt_shutdown()
def on_connect(self, client, flags, rc, properties) -> None: # noqa: ANN001, ARG002
print("Connected")
self._connected = True
while not self._queued_message.empty():
msg = self._queued_message.get(block=False)
self.publish(msg.topic, msg.payload, retain=msg.retain)
for topic, subscription in self._subscribed_topic.items():
if subscription.subscribed is False:
self.subscribe(topic, subscription.callback)
async def on_message(self, client, topic: str, payload: bytes, qos: int, properties: any) -> any: # noqa: ANN001, ARG002
print("On message")
if topic in self._subscribed_topic and self._subscribed_topic[topic].callback is not None:
await self._subscribed_topic[topic].callback(payload)
def subscribe(self, topic: str, callback: callable) -> None:
if self._connected:
print("Subscribe to topic: ", topic)
self._mqtt.client.subscribe(topic)
self._subscribed_topic[topic] = MQTTSubscription(topic, callback, subscribed=True)
else:
self._subscribed_topic[topic] = MQTTSubscription(topic, callback, subscribed=False)
def publish(self, topic: str, payload: dict, *, retain: bool) -> None:
if self._connected:
self._mqtt.publish(topic, payload=payload, retain=retain)
else:
self._queued_message.put(MQTTPendingMessage(topic, payload, retain=retain))

View File

@@ -3,16 +3,18 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI from fastapi import FastAPI
from pydantic import BaseModel from pydantic import BaseModel
from src.cloud_util.mqtt import MQTT
from src.recorder.poo import PooRecorder from src.recorder.poo import PooRecorder
recorder = PooRecorder() mqtt = MQTT()
poo_recorder = PooRecorder(mqtt)
@asynccontextmanager @asynccontextmanager
async def _lifespan(_app: FastAPI): # noqa: ANN202 async def _lifespan(_app: FastAPI): # noqa: ANN202
await recorder.start() await mqtt.start()
yield yield
await recorder.stop() await mqtt.stop()
class PooRecordField(BaseModel): class PooRecordField(BaseModel):
@@ -24,6 +26,5 @@ app = FastAPI(lifespan=_lifespan)
@app.post("/poo/record") @app.post("/poo/record")
async def record(record_detail: PooRecordField) -> PooRecordField: async def record(record_detail: PooRecordField) -> PooRecordField:
print(record_detail.status) await poo_recorder.record(record_detail.status)
await recorder.record(record_detail.status)
return record_detail return record_detail

View File

@@ -1,14 +1,11 @@
from datetime import datetime from datetime import datetime
from fastapi_mqtt import FastMQTT, MQTTConfig from src.cloud_util.mqtt import MQTT
from src.recorder.notion_handle import NotionClient from src.recorder.notion_handle import NotionClient
class PooRecorder: class PooRecorder:
mqtt_config = MQTTConfig(username="mqtt", password="mqtt", reconnect_retries=-1) # noqa: S106
notion = NotionClient() notion = NotionClient()
mqtt = FastMQTT(config=mqtt_config, client_id="poo_recorder")
CONFIG_TOPIC = "homeassistant/text/poo_recorder/config" CONFIG_TOPIC = "homeassistant/text/poo_recorder/config"
AVAILABILITY_TOPIC = "studiotj/poo_recorder/status" AVAILABILITY_TOPIC = "studiotj/poo_recorder/status"
COMMAND_TOPIC = "studiotj/poo_recorder/update_text" COMMAND_TOPIC = "studiotj/poo_recorder/update_text"
@@ -17,42 +14,27 @@ class PooRecorder:
ONLINE = "online" ONLINE = "online"
OFFLINE = "offline" OFFLINE = "offline"
def __init__(self) -> None: def __init__(self, mqtt: MQTT) -> None:
print("Initialization.") print("Poo Recorder Initialization...")
self._mqtt = mqtt
async def start(self) -> None: self._mqtt.publish(PooRecorder.CONFIG_TOPIC, PooRecorder.compose_config(), retain=True)
print("Starting...") self._mqtt.publish(PooRecorder.AVAILABILITY_TOPIC, PooRecorder.ONLINE, retain=True)
await PooRecorder.mqtt.mqtt_startup()
async def stop(self) -> None:
print("Stopping...")
await PooRecorder.mqtt.mqtt_shutdown()
async def record(self, status: str) -> None: async def record(self, status: str) -> None:
PooRecorder.publish_text(status) self._publish_text(status)
now = datetime.now(tz=datetime.now().astimezone().tzinfo) now = datetime.now(tz=datetime.now().astimezone().tzinfo)
PooRecorder.publish_time(now) self._publish_time(now)
await PooRecorder.notion.note(now, status) await PooRecorder.notion.note(now, status)
@staticmethod def _publish_text(self, new_text: str) -> None:
@mqtt.on_connect() self._mqtt.publish(PooRecorder.AVAILABILITY_TOPIC, PooRecorder.ONLINE, retain=True)
def on_connect(client, flags, rc, properties) -> None: # noqa: ANN001, ARG004 self._mqtt.publish(PooRecorder.STATE_TOPIC, new_text, retain=True)
print("Connected")
config = PooRecorder.compose_config()
PooRecorder.mqtt.publish(PooRecorder.CONFIG_TOPIC, config, retain=True)
PooRecorder.mqtt.publish(PooRecorder.AVAILABILITY_TOPIC, PooRecorder.ONLINE, retain=True)
@staticmethod def _publish_time(self, time: datetime) -> None:
def publish_text(new_text: str) -> None:
PooRecorder.mqtt.publish(PooRecorder.AVAILABILITY_TOPIC, PooRecorder.ONLINE, retain=True)
PooRecorder.mqtt.publish(PooRecorder.STATE_TOPIC, new_text, retain=True)
@staticmethod
def publish_time(time: datetime) -> None:
formatted_time = time.strftime("%a | %Y-%m-%d | %H:%M") formatted_time = time.strftime("%a | %Y-%m-%d | %H:%M")
PooRecorder.mqtt.publish(PooRecorder.AVAILABILITY_TOPIC, PooRecorder.ONLINE, retain=True) self._mqtt.publish(PooRecorder.AVAILABILITY_TOPIC, PooRecorder.ONLINE, retain=True)
json_string = {"last_poo": formatted_time} json_string = {"last_poo": formatted_time}
PooRecorder.mqtt.publish(PooRecorder.JSON_TOPIC, json_string, retain=True) self._mqtt.publish(PooRecorder.JSON_TOPIC, json_string, retain=True)
@staticmethod @staticmethod
def compose_config() -> dict: def compose_config() -> dict:
@@ -60,7 +42,7 @@ class PooRecorder:
"device": { "device": {
"name": "Dog Poop Recorder", "name": "Dog Poop Recorder",
"model": "poop-recorder-backend", "model": "poop-recorder-backend",
"sw_version": "1.2", "sw_version": "1.3",
"identifiers": ["poo_recorder"], "identifiers": ["poo_recorder"],
"manufacturer": "Studio TJ", "manufacturer": "Studio TJ",
}, },