# mqttdevicemanager/handler.py
#
# 167 lines
# 5.4 KiB
# Python

import aiomqtt
import asyncio
import inspect
import paho
import signal
from dataclasses import dataclass
from command import (
Command,
CommandResponse,
CommandArgumentError,
CommandExecutionError,
)
@dataclass
class MQTTConfig:
host: str
port: int = 1883
username: str | None = None
password: str | None = None
keepalive: int = 60
class MQTTHandler:
    """Base class for a device handler that is controlled over MQTT.

    All topics live under ``device/<handler_id>``:

    * ``.../status``   -- retained ``ONLINE`` / ``OFFLINE`` marker (also the
      broker-side last will).
    * ``.../command``  -- one sub-topic per command; schemas and additional
      properties are published retained below it.
    * ``.../property`` -- topic base stored for subclasses (not used here).

    Commands are ``Command`` instances found as class attributes anywhere in
    the MRO. Background work is declared by marking async methods with
    ``_is_task = True`` (see the ``task`` decorator).
    """

    def __init__(
        self,
        mqtt_config: MQTTConfig,
        handler_id: str,
    ):
        """Build the topic names and the (not yet connected) MQTT client.

        Args:
            mqtt_config: Broker address, credentials and keep-alive.
            handler_id: Identifier used both as MQTT client identifier and
                as the ``device/<handler_id>`` topic prefix.
        """
        self.handler_id = handler_id
        self.mqtt_config = mqtt_config
        self.topic_base = f"device/{handler_id}"
        self.status_topic = f"{self.topic_base}/status"
        self.command_topic = f"{self.topic_base}/command"
        self.property_topic = f"{self.topic_base}/property"
        self._shutdown_event = asyncio.Event()
        # Last will: the broker publishes OFFLINE if we vanish ungracefully.
        will = aiomqtt.Will(
            topic=self.status_topic, payload="OFFLINE", qos=1, retain=True
        )
        self.mqtt_client = aiomqtt.Client(
            self.mqtt_config.host,
            port=self.mqtt_config.port,
            # Credentials and keep-alive were previously dropped, which made
            # authenticated brokers unusable.
            username=self.mqtt_config.username,
            password=self.mqtt_config.password,
            keepalive=self.mqtt_config.keepalive,
            identifier=handler_id,
            protocol=paho.mqtt.client.MQTTv5,
            will=will,
        )

    def get_available_commands(self):
        """Return ``{attribute_name: Command}`` for this handler's class.

        The MRO is walked from the most-derived class to the base classes and
        the first definition of a name wins, so a subclass can override a
        command inherited from a base class.
        """
        commands = {}
        for base in self.__class__.__mro__:
            for name, attr in vars(base).items():
                if isinstance(attr, Command):
                    # setdefault: never let a base class replace an override.
                    commands.setdefault(name, attr)
        return commands

    async def publish_commands(self) -> None:
        """Publish every command's schema and additional properties (retained)."""
        for command in self.get_available_commands().values():
            command_base = f"{self.command_topic}/{command.name}"
            await self.mqtt_client.publish(
                f"{command_base}/schema",
                str(command.schema),
                qos=1,
                retain=True,
            )
            for key, value in command.additional_properties.items():
                await self.mqtt_client.publish(
                    f"{command_base}/{key}",
                    str(value),
                    qos=1,
                    retain=True,
                )

    async def execute_command(
        self,
        command_name: str,
        payload: str,
        properties: "paho.mqtt.properties.Properties | None" = None,
    ) -> None:
        """Run one command and, if requested, publish a ``CommandResponse``.

        A response is only sent when ``properties`` carries a
        ``ResponseTopic``; ``CorrelationData`` is echoed back when present.

        Args:
            command_name: Key into ``get_available_commands()``.
                NOTE(review): the lookup uses the attribute name while topics
                are published under ``command.name`` -- presumably identical;
                confirm against ``Command``.
            payload: Decoded message payload handed to the command.
            properties: MQTT v5 properties of the request, if any.
        """

        async def respond(success: bool, message: str | None = None) -> None:
            response_topic = getattr(properties, "ResponseTopic", None)
            if response_topic is None:
                return
            raw_correlation = getattr(properties, "CorrelationData", None)
            # NOTE(review): correlation data is assumed to be UTF-8 text;
            # arbitrary binary data would raise here -- confirm with clients.
            correlation = (
                raw_correlation.decode("utf-8")
                if raw_correlation is not None
                else None
            )
            await self.mqtt_client.publish(
                response_topic,
                str(CommandResponse(success, str(message), correlation)),
                qos=1,
                retain=False,
            )

        command = self.get_available_commands().get(command_name)
        if command is None:
            print(f"Unknown command {command_name}")
            await respond(False, f"Unknown command: {command_name}")
            return

        # Only the command itself is guarded: a failure while publishing the
        # response must propagate as-is instead of being reported (and
        # retried) as a command failure.
        try:
            result = await command(self, payload)
        except (CommandArgumentError, CommandExecutionError) as e:
            await respond(False, f"{e}")
        except Exception as e:
            print(f"Failed to execute command {command_name} with unknown cause: ", e)
            await respond(False, "Unexpected error")
        else:
            await respond(True, result)

    async def mqtt_command_writer_task(self) -> None:
        """Dispatch incoming ``<command_topic>/<name>`` messages to commands.

        Messages on other topics are ignored. A payload that is not valid
        UTF-8 is logged and skipped so one malformed message cannot stop the
        handler.
        """
        prefix = f"{self.command_topic}/"
        async for message in self.mqtt_client.messages:
            topic = str(message.topic)
            if not topic.startswith(prefix):
                continue
            try:
                payload = message.payload.decode("utf-8")
            except UnicodeDecodeError as e:
                print(
                    f"[{self.handler_id}] Ignoring message on {topic}: "
                    f"payload is not valid UTF-8 ({e})"
                )
                continue
            command_name = topic.removeprefix(prefix)
            await self.execute_command(command_name, payload, message.properties)

    async def shutdown_watcher(self) -> None:
        """Wait for ``stop()`` and then publish the retained OFFLINE status."""
        await self._shutdown_event.wait()
        await self.mqtt_client.publish(self.status_topic, "OFFLINE", qos=1, retain=True)

    def stop(self) -> None:
        """Request shutdown of ``run()``.

        Also restores the default SIGINT handler, so a second Ctrl-C
        interrupts the process the usual way.
        """
        self._shutdown_event.set()
        signal.signal(signal.SIGINT, signal.SIG_DFL)

    def _collect_task_methods(self) -> list:
        """Return the bound methods marked with ``_is_task``.

        Raises:
            TypeError: If a marked attribute is not an async function.
        """
        methods = []
        for attr_name in dir(self):
            attr = getattr(self, attr_name)
            if callable(attr) and getattr(attr, "_is_task", False):
                if not inspect.iscoroutinefunction(attr):
                    raise TypeError(
                        f"@task can only decorate async methods: {attr_name}"
                    )
                methods.append(attr)
        return methods

    async def _run_until_shutdown(self, task_methods) -> None:
        """Run message loop, user tasks and shutdown watcher for one session.

        Returns once the shutdown watcher has finished; re-raises the first
        exception of any task. In every case all tasks started here are
        cancelled and awaited before returning, so nothing survives into the
        next connection attempt.
        """
        shutdown = asyncio.create_task(self.shutdown_watcher())
        running = [shutdown, asyncio.create_task(self.mqtt_command_writer_task())]
        running.extend(asyncio.create_task(method()) for method in task_methods)
        pending = set(running)
        try:
            while pending:
                done, pending = await asyncio.wait(
                    pending, return_when=asyncio.FIRST_COMPLETED
                )
                for finished in done:
                    finished.result()  # re-raise a task failure, if any
                if shutdown in done:
                    break
        finally:
            for running_task in running:
                running_task.cancel()
            # Await everything so cancellations complete and no exception is
            # left unretrieved.
            await asyncio.gather(*running, return_exceptions=True)

    async def run(self) -> None:
        """Connect, announce ONLINE and serve commands until ``stop()``.

        On an MQTT error the connection is re-established after a fixed
        delay. The method returns after ``stop()`` has been called and the
        OFFLINE status has been published.

        Raises:
            TypeError: If a ``_is_task``-marked attribute is not async. This
                is checked before connecting so no ONLINE status is left
                behind.
        """
        reconnect_interval = 5
        task_methods = self._collect_task_methods()
        while not self._shutdown_event.is_set():
            try:
                async with self.mqtt_client as client:
                    await client.subscribe(f"{self.command_topic}/+")
                    await self.publish_commands()
                    # announce that we are online
                    await client.publish(
                        self.status_topic, "ONLINE", qos=1, retain=True
                    )
                    await self._run_until_shutdown(task_methods)
            except aiomqtt.MqttError as e:
                if self._shutdown_event.is_set():
                    # Shutting down anyway: do not wait or reconnect.
                    break
                print(
                    f"[{self.handler_id}] MQTT connection error: {e}. Reconnecting in {reconnect_interval}s..."
                )
                await asyncio.sleep(reconnect_interval)
def task(func):
    """Flag *func* as a background task.

    Sets ``_is_task = True`` on the function object and hands the very same
    object back, so it can be used as a plain decorator. Methods carrying
    this flag are the ones ``MQTTHandler.run`` picks up and runs.
    """
    setattr(func, "_is_task", True)
    return func