mqttdevicemanager/mqtthandler/handler.py
2026-03-20 00:05:05 +10:30

253 lines
8.2 KiB
Python

import aiomqtt
import asyncio
import inspect
import paho
import signal
import json
import secrets
import os
import socket
from pathlib import Path
from enum import Enum, auto
from .command import (
Command,
CommandResponse,
CommandArgumentError,
CommandExecutionError,
enumerate_commands,
)
from .property import Property
def get_identifier(cache_path: Path) -> str:
    """
    Determine the MQTT client identifier for this device.

    Sources are tried in the following order:
    1. The ``IDENTIFIER`` environment variable.
    2. The (whitespace-stripped) contents of ``cache_path``.
    3. A new random identifier from ``generate_identifier``.

    The resolved identifier is always written back to ``cache_path`` so that
    later runs reuse it.

    Args:
        cache_path: File used to persist the identifier between runs.

    Returns:
        The non-empty identifier string.
    """
    client_id = os.environ.get("IDENTIFIER")
    if not client_id and cache_path.exists():
        client_id = cache_path.read_text().strip()
    if not client_id:
        # Reached when there is no cache file, or when the cache file is empty
        # or whitespace-only; an empty client id must never be returned.
        client_id = generate_identifier()
    cache_path.write_text(client_id)
    return client_id
def generate_identifier() -> str:
    """Return a new random, URL-safe identifier built from 6 random bytes."""
    entropy_bytes = 6
    return secrets.token_urlsafe(entropy_bytes)
class Status(Enum):
    """Connection state of a device; member names are what gets published."""

    # Order matters: the member names are listed as the status schema's enum.
    ONLINE = auto()
    OFFLINE = auto()
class MQTTHandler:
DEVICE = "device"
META = "meta"
PROPERTY = "property"
COMMAND = "command"
STATUS = "status"
def __init__(self, name: str):
self.name = name
self.identifier = get_identifier(Path(f"/tmp/{self.name}.tmp"))
self.topic_base = lambda: f"{MQTTHandler.DEVICE}/{self.identifier}"
self.meta_topic = lambda: f"{self.topic_base()}/{MQTTHandler.META}"
self.command_topic = lambda: f"{self.topic_base()}/{MQTTHandler.COMMAND}"
self.property_topic = lambda: f"{self.topic_base()}/{MQTTHandler.PROPERTY}"
self._shutdown_event = asyncio.Event()
self._mqtt_client = None
self._commands = enumerate_commands(self)
self._properties = {}
self._meta = {}
async def set_property(self, name: str, value, **kwargs):
if name in self._properties:
await self._properties[name](value, **kwargs)
else:
# print(f"Warning: proeprty {name} is unregistered")
await self._publish(f"{self.property_topic()}/{name}", value, **kwargs)
async def register_property(
self, name: str, description: str | None = None, schema: dict | None = None
):
property = self._register_property(
f"{self.property_topic()}/{name}", description, schema
)
self._properties[name] = property
async def _register_property(
self, name: str, description: str | None = None, schema: dict | None = None
):
property = Property(name, description, schema, self._publish)
data = {
"schema": json.dumps(schema),
"description": description,
}
for k, v in {k: v for k, v in data.items() if v is not None}.items():
await self._mqtt_client.publish(
f"{name}/${k}",
str(v),
qos=1,
retain=True,
)
return property
async def _publish(self, name: str, value, **kwargs):
await self._mqtt_client.publish(f"{name}", value, **kwargs)
async def _register_commands(self):
for name, command in self._commands.items():
for k, v in {
"schema": json.dumps(command.schema),
"description": command.description,
**command.additional_properties,
}.items():
await self._mqtt_client.publish(
f"{self.command_topic()}/{command.name}/${k}",
str(v),
qos=1,
retain=True,
)
async def _announce(self):
# announce that we are online
await self._register_commands()
self._meta[MQTTHandler.STATUS] = await self._register_property(
f"{self.meta_topic()}/{MQTTHandler.STATUS}",
"Indicates the status of the device.",
{"type": "string", "enum": list(Status.__members__.keys())},
)
await self._meta[MQTTHandler.STATUS](
self, json.dumps(Status.ONLINE.name), qos=1, retain=True
)
await self._publish(f"{self.meta_topic()}/name", self.name, qos=1, retain=True)
await self._publish(
f"{self.meta_topic()}/type", type(self).__name__, qos=1, retain=True
)
await self._publish(
f"{self.meta_topic()}/host", socket.gethostname(), qos=1, retain=True
)
async def _execute_command(
self,
command_name: str,
payload: str,
properties: paho.mqtt.properties.Properties = None,
):
async def respond(success: bool, message: str = None):
if (
properties is not None
and hasattr(properties, "ResponseTopic")
and properties.ResponseTopic is not None
):
correlation = (
properties.CorrelationData.decode("utf-8")
if hasattr(properties, "CorrelationData")
else None
)
await self._mqtt_client.publish(
properties.ResponseTopic,
str(CommandResponse(success, str(message), correlation)),
qos=1,
retain=False,
)
try:
command = self._commands[command_name]
result = await command(self, payload)
await respond(True, result)
except (CommandArgumentError, CommandExecutionError) as e:
await respond(False, f"{e}")
except Exception as e:
print(f"Failed to execute command {command_name} with unknown cause: ", e)
await respond(False, "Unexpected error")
async def _command_executor(self):
await self._mqtt_client.subscribe(f"{self.command_topic()}/+")
async for message in self._mqtt_client.messages:
topic = str(message.topic)
payload = message.payload.decode("utf-8")
if topic.startswith(self.command_topic()):
command_name = topic.removeprefix(f"{self.command_topic()}/")
await self._execute_command(command_name, payload, message.properties)
async def _shutdown_watcher(self):
await self._shutdown_event.wait()
await self._meta[MQTTHandler.STATUS](
self, json.dumps(Status.OFFLINE.name), qos=1, retain=True
)
def stop(self):
self._shutdown_event.set()
signal.signal(signal.SIGINT, signal.SIG_DFL)
async def run(self, host: str, **kwargs):
INTERVAL = 5
will = aiomqtt.Will(
topic=f"{self.meta_topic()}/{MQTTHandler.STATUS}",
payload=json.dumps(Status.OFFLINE.name),
qos=1,
retain=True,
)
while True:
try:
async with aiomqtt.Client(
host,
protocol=paho.mqtt.client.MQTTv5,
will=will,
identifier=self.identifier,
**kwargs,
) as client:
self._mqtt_client = client
tasks = [
self._command_executor(),
self._shutdown_watcher(),
self._announce(),
]
# Inspect instance methods
for attr_name in dir(self):
attr = getattr(self, attr_name)
if callable(attr) and getattr(attr, "_is_task", False):
if not inspect.iscoroutinefunction(attr):
raise TypeError(
f"@task can only decorate async methods: {attr_name}"
)
tasks.append(attr()) # call it, returns coroutine
await asyncio.gather(*tasks)
except aiomqtt.MqttError as e:
print(f"MQTT connection error: {e}. Reconnecting in {INTERVAL}s...")
await asyncio.sleep(INTERVAL)
finally:
self._mqtt_client = None
def task(func):
    """
    Mark ``func`` so that ``MQTTHandler.run`` starts it automatically.

    The function is returned unchanged apart from the ``_is_task`` marker
    attribute; ``run`` requires marked methods to be async.
    """
    setattr(func, "_is_task", True)
    return func