# File listing metadata: 222 lines, 7.1 KiB, Python
import aiomqtt
|
|
import asyncio
|
|
import inspect
|
|
import paho
|
|
import signal
|
|
import json
|
|
from enum import Enum, auto
|
|
|
|
from .command import (
|
|
Command,
|
|
CommandResponse,
|
|
CommandArgumentError,
|
|
CommandExecutionError,
|
|
enumerate_commands,
|
|
)
|
|
|
|
from .property import Property
|
|
|
|
|
|
class Status(Enum):
    """Availability states a handler publishes on its status topic.

    The member *names* are what gets serialized (``Status.X.name``); the
    integer values match what ``auto()`` would assign in declaration order.
    """

    ONLINE = 1
    OFFLINE = 2
|
|
|
|
|
|
class MQTTHandler:
    """Expose a device's commands and properties over MQTT.

    Topics are laid out under ``device/<handler_id>/`` with the sub-trees
    ``meta``, ``command`` and ``property``.  Descriptive attributes
    (schema, description, ...) are published as retained messages on
    ``<topic>/$<attribute>``.

    The MQTT client only exists while :meth:`run` holds a connection;
    ``self._mqtt_client`` is ``None`` before that and after the connection
    is left, so publishing helpers must only be used from within ``run``
    (e.g. from ``@task`` methods or command implementations).
    """

    # Topic path segments.
    DEVICE = "device"
    META = "meta"
    PROPERTY = "property"
    COMMAND = "command"
    STATUS = "status"

    def __init__(
        self,
        handler_id: str,
    ):
        """Prepare topic names and bookkeeping for the handler *handler_id*.

        No connection is opened here; see :meth:`run`.
        """
        self.handler_id = handler_id

        self.topic_base = f"{MQTTHandler.DEVICE}/{handler_id}"
        self.meta_topic = f"{self.topic_base}/{MQTTHandler.META}"
        self.command_topic = f"{self.topic_base}/{MQTTHandler.COMMAND}"
        self.property_topic = f"{self.topic_base}/{MQTTHandler.PROPERTY}"

        # Set by stop(); awaited by _shutdown_watcher().
        self._shutdown_event = asyncio.Event()

        # Assigned inside run() for the lifetime of one connection.
        self._mqtt_client = None

        self._commands = enumerate_commands(self)
        self._properties = {}
        self._meta = {}

    async def set_property(self, name: str, value, **kwargs):
        """Publish *value* for the property *name*.

        A registered property is updated through its ``Property`` object;
        an unregistered one is published directly on
        ``<property_topic>/<name>``.  ``**kwargs`` are forwarded to the
        publish call (e.g. ``qos``, ``retain``).
        """
        registered = self._properties.get(name)
        if registered is not None:
            # NOTE(review): _announce()/_shutdown_watcher() invoke their
            # Property as ``prop(self, value, ...)`` while this call omits
            # ``self``.  Property is defined outside this file -- confirm
            # which calling convention is the correct one.
            await registered(value, **kwargs)
        else:
            # Unregistered property: publish the raw value without metadata.
            await self._publish(f"{self.property_topic}/{name}", value, **kwargs)

    async def register_property(
        self, name: str, description: str | None = None, schema: dict | None = None
    ):
        """Register the property *name* and publish its metadata.

        The created ``Property`` is stored so that :meth:`set_property` can
        use it afterwards.
        """
        # _register_property is a coroutine function: it must be awaited,
        # otherwise a coroutine object (not a Property) would be stored.
        registered = await self._register_property(
            f"{self.property_topic}/{name}", description, schema
        )
        self._properties[name] = registered

    async def _register_property(
        self, name: str, description: str | None = None, schema: dict | None = None
    ):
        """Create a ``Property`` for the full topic *name* and announce it.

        Each metadata attribute that is not ``None`` is published retained
        with QoS 1 on ``<name>/$<attribute>``.

        Returns:
            The newly created ``Property``.
        """
        created = Property(name, description, schema, self._publish)

        metadata = {
            # json.dumps(None) would yield the string "null" and defeat the
            # None filter below, so only serialize an actual schema.
            "schema": None if schema is None else json.dumps(schema),
            "description": description,
        }
        for key, value in metadata.items():
            if value is None:
                continue
            await self._mqtt_client.publish(
                f"{name}/${key}",
                str(value),
                qos=1,
                retain=True,
            )

        return created

    async def _publish(self, name: str, value, **kwargs):
        """Publish *value* on topic *name*, forwarding ``**kwargs`` to the client."""
        await self._mqtt_client.publish(f"{name}", value, **kwargs)

    async def _register_commands(self):
        """Publish schema, description and extra attributes of every command.

        Attributes are published retained with QoS 1 on
        ``<command_topic>/<command name>/$<attribute>``.
        """
        for command in self._commands.values():
            attributes = {
                "schema": json.dumps(command.schema),
                "description": command.description,
                **command.additional_properties,
            }
            for key, value in attributes.items():
                await self._mqtt_client.publish(
                    f"{self.command_topic}/{command.name}/${key}",
                    str(value),
                    qos=1,
                    retain=True,
                )

    async def _announce(self):
        """Announce commands, register the status property and go ONLINE."""
        await self._register_commands()

        self._meta[MQTTHandler.STATUS] = await self._register_property(
            f"{self.meta_topic}/{MQTTHandler.STATUS}",
            "Indicates the status of the device.",
            {"type": "string", "enum": list(Status.__members__.keys())},
        )
        await self._meta[MQTTHandler.STATUS](
            self, json.dumps(Status.ONLINE.name), qos=1, retain=True
        )

    async def _execute_command(
        self,
        command_name: str,
        payload: str,
        properties: paho.mqtt.properties.Properties = None,
    ):
        """Run the command *command_name* with *payload* and send a response.

        A response is only published when *properties* carries a
        ``ResponseTopic``; ``CorrelationData``, if present, is echoed back.
        Failures are reported to the requester instead of being raised.
        """

        async def respond(success: bool, message: str = None):
            # Without a response topic the requester does not expect a reply.
            response_topic = getattr(properties, "ResponseTopic", None)
            if response_topic is None:
                return

            correlation = (
                properties.CorrelationData.decode("utf-8")
                if hasattr(properties, "CorrelationData")
                else None
            )
            await self._mqtt_client.publish(
                response_topic,
                str(CommandResponse(success, str(message), correlation)),
                qos=1,
                retain=False,
            )

        # An unknown command is a caller error, not an unexpected failure:
        # report it explicitly instead of via the generic handler below.
        try:
            command = self._commands[command_name]
        except KeyError:
            await respond(False, f"Unknown command: {command_name}")
            return

        try:
            result = await command(self, payload)
            await respond(True, result)
        except (CommandArgumentError, CommandExecutionError) as e:
            await respond(False, f"{e}")
        except Exception as e:
            print(f"Failed to execute command {command_name} with unknown cause: ", e)
            await respond(False, "Unexpected error")

    async def _command_executor(self):
        """Subscribe to the command topics and dispatch incoming messages."""
        await self._mqtt_client.subscribe(f"{self.command_topic}/+")

        # Include the separator so that sibling topics merely sharing the
        # textual prefix are not mistaken for commands.
        prefix = f"{self.command_topic}/"

        async for message in self._mqtt_client.messages:
            topic = str(message.topic)
            if not topic.startswith(prefix):
                continue

            # Decode only command payloads; a malformed payload must not
            # terminate the dispatch loop.
            try:
                payload = message.payload.decode("utf-8")
            except UnicodeDecodeError as e:
                print(f"Ignoring command on {topic} with undecodable payload: ", e)
                continue

            command_name = topic.removeprefix(prefix)
            await self._execute_command(command_name, payload, message.properties)

    async def _shutdown_watcher(self):
        """Wait for :meth:`stop` and then publish the OFFLINE status."""
        await self._shutdown_event.wait()
        await self._meta[MQTTHandler.STATUS](
            self, json.dumps(Status.OFFLINE.name), qos=1, retain=True
        )

    def stop(self):
        """Request shutdown and restore the default SIGINT handler."""
        self._shutdown_event.set()
        signal.signal(signal.SIGINT, signal.SIG_DFL)

    def _collect_task_methods(self):
        """Return the bound methods of this instance marked with ``_is_task``.

        Raises:
            TypeError: if a marked attribute is not an async method.
        """
        marked = []
        for attr_name in dir(self):
            attr = getattr(self, attr_name)
            if callable(attr) and getattr(attr, "_is_task", False):
                if not inspect.iscoroutinefunction(attr):
                    raise TypeError(
                        f"@task can only decorate async methods: {attr_name}"
                    )
                marked.append(attr)
        return marked

    async def run(self, host: str, **kwargs):
        """Connect to the broker at *host* and serve until an error escapes.

        Runs the command executor, the shutdown watcher, the announcement
        and every method marked with ``_is_task`` concurrently.  On
        ``aiomqtt.MqttError`` the connection is re-established after a fixed
        delay.  ``**kwargs`` are forwarded to ``aiomqtt.Client``.

        Raises:
            TypeError: if a method marked with ``_is_task`` is not async.
        """
        INTERVAL = 5  # seconds between reconnection attempts

        # Last will: the broker marks the device OFFLINE if the connection drops.
        will = aiomqtt.Will(
            topic=f"{self.meta_topic}/{MQTTHandler.STATUS}",
            payload=json.dumps(Status.OFFLINE.name),
            qos=1,
            retain=True,
        )

        while True:
            try:
                async with aiomqtt.Client(
                    host,
                    protocol=paho.mqtt.client.MQTTv5,
                    will=will,
                    identifier=self.handler_id,
                    **kwargs,
                ) as client:
                    self._mqtt_client = client

                    # Validate user tasks before creating any coroutine so a
                    # TypeError does not leave coroutines un-awaited.
                    user_tasks = self._collect_task_methods()

                    coroutines = [
                        self._command_executor(),
                        self._shutdown_watcher(),
                        self._announce(),
                    ]
                    coroutines.extend(method() for method in user_tasks)

                    running = [asyncio.create_task(coro) for coro in coroutines]
                    try:
                        await asyncio.gather(*running)
                    finally:
                        # gather() does not cancel siblings when one fails;
                        # without this they would keep running against a
                        # closed client and pile up on every reconnect.
                        for pending in running:
                            pending.cancel()
                        await asyncio.gather(*running, return_exceptions=True)

            except aiomqtt.MqttError as e:
                print(
                    f"[{self.handler_id}] MQTT connection error: {e}. Reconnecting in {INTERVAL}s..."
                )
                await asyncio.sleep(INTERVAL)

            finally:
                self._mqtt_client = None
|
|
|
|
|
|
def task(func):
    """Flag *func* as a task method and hand it back unchanged.

    The flag is the attribute ``_is_task`` set to ``True``;
    ``MQTTHandler.run`` looks for it to decide which methods to schedule.
    """
    setattr(func, "_is_task", True)
    return func
|