131 lines
4.5 KiB
Python
131 lines
4.5 KiB
Python
import asyncio
|
|
import aiomqtt
|
|
import json
|
|
from command import Command, CommandResponse
|
|
import inspect
|
|
from paho.mqtt.properties import Properties
|
|
from dataclasses import asdict
|
|
|
|
# NOTE(review): BAUD is not referenced anywhere in this file's visible code;
# presumably a serial-port baud rate used elsewhere or left over — confirm.
BAUD = 115200
|
|
# Seconds MQTTHandler.run() waits after an aiomqtt.MqttError before reconnecting.
INTERVAL = 5
|
|
|
|
|
|
class MQTTHandler:
    """Expose the ``Command`` attributes of a handler class over MQTT.

    Command metadata is published as retained messages under
    ``asset/<handler_id>/command/<name>/...``. A command is invoked by
    publishing a JSON payload to ``asset/<handler_id>/command/<name>``.
    Methods marked with ``task`` run concurrently with the command listener
    while a connection is up.
    """

    def __init__(
        self,
        mqtt_client: aiomqtt.Client,
        handler_id: str,
    ):
        """Store the client and derive this handler's topic names.

        Args:
            mqtt_client: Client used for all publishing and subscribing;
                ``run()`` enters it as an async context manager on every
                (re)connect.
            handler_id: Identifier embedded in every topic as
                ``asset/<handler_id>``.
        """
        self.handler_id = handler_id
        self.mqtt_client = mqtt_client

        self.topic_base = f"asset/{handler_id}"
        self.command_topic = f"{self.topic_base}/command"
        self.property_topic = f"{self.topic_base}/property"

    def get_available_commands(self):
        """Return the ``Command`` attributes visible on this handler's class.

        Names resolve like normal attribute lookup: the class nearest to
        ``type(self)`` in the MRO wins. A subclass that overrides a command
        therefore replaces the base-class version. Overriding it with a
        non-``Command`` value hides it.

        Returns:
            dict mapping attribute name -> ``Command`` instance.
        """
        commands = {}
        seen = set()
        for klass in type(self).__mro__:
            for name, attr in vars(klass).items():
                # A more-derived class already defined this name; it shadows
                # whatever the base class has.
                if name in seen:
                    continue
                seen.add(name)
                if isinstance(attr, Command):
                    commands[name] = attr
        return commands

    async def publish_commands(self):
        """Publish each command's schema and extra properties as retained messages.

        For every command:

        * ``<command_topic>/<command.name>/schema`` carries ``str(command.schema)``.
        * Each ``additional_properties`` entry ``k`` is published to
          ``<command_topic>/<command.name>/<k>`` as ``str(v)``.
        """
        # NOTE(review): topics use ``command.name`` while execute_command looks
        # commands up by attribute name; these presumably match — confirm.
        # NOTE(review): if ``schema`` is a dict, str() yields a Python repr
        # rather than JSON — confirm this is what consumers expect.
        for command in self.get_available_commands().values():
            prefix = f"{self.command_topic}/{command.name}"
            await self.mqtt_client.publish(
                f"{prefix}/schema",
                str(command.schema),
                qos=1,
                retain=True,
            )
            for key, value in command.additional_properties.items():
                await self.mqtt_client.publish(
                    f"{prefix}/{key}",
                    str(value),
                    qos=1,
                    retain=True,
                )

    async def execute_command(
        self, command_name: str, payload: "str | bytes", properties: Properties = None
    ):
        """Run ``command_name`` with the JSON-decoded ``payload`` and report the outcome.

        Replies are sent only when ``properties`` carries a ``ResponseTopic``.
        The reply is a JSON-encoded ``CommandResponse(success, message,
        correlation)`` published to that topic. Without a response topic the
        outcome is not reported, although unexpected failures are still
        printed.

        Args:
            command_name: Attribute name of the ``Command`` on this handler.
            payload: JSON text, either as ``str`` or as UTF-8 encoded bytes.
            properties: MQTT properties of the request message, if any.

        Raises:
            Whatever ``mqtt_client.publish`` raises while sending a reply
            (e.g. ``aiomqtt.MqttError``), so ``run()`` can reconnect.
        """

        async def respond(success: bool, message: str = None):
            response_topic = getattr(properties, "ResponseTopic", None)
            if response_topic is None:
                return
            correlation_data = getattr(properties, "CorrelationData", None)
            # Correlation data may be arbitrary bytes; replace invalid UTF-8
            # instead of letting a malformed request crash the handler.
            correlation = (
                correlation_data.decode("utf-8", errors="replace")
                if correlation_data is not None
                else None
            )
            await self.mqtt_client.publish(
                response_topic,
                json.dumps(
                    asdict(CommandResponse(success, str(message), correlation))
                ),
                qos=1,
                retain=False,
            )

        command = self.get_available_commands().get(command_name)
        if command is None:
            await respond(False, f"Unknown command: {command_name}")
            return

        # Phase 1: turn the payload into a Python value.
        try:
            if isinstance(payload, (bytes, bytearray)):
                payload = payload.decode("utf-8")
            argument = json.loads(payload)
        except UnicodeDecodeError as e:
            await respond(False, f"Failed to decode payload as UTF-8: {e}")
            return
        except (json.JSONDecodeError, TypeError) as e:
            await respond(False, f"Failed to parse payload as JSON: {e}")
            return

        # Phase 2: run the command itself.
        try:
            result = await command(self, argument)
        except ValueError as e:
            await respond(False, f"Command payload does not match expected schema: {e}")
            return
        except Exception as e:
            print(f"Failed to execute command {command_name} with unknown cause: ", e)
            await respond(False, "Unexpected error")
            return

        # Reply outside the try so a publish failure propagates to run() for
        # reconnection instead of being misreported as a command failure.
        await respond(True, result)

    async def mqtt_command_writer_task(self):
        """Dispatch incoming ``<command_topic>/<name>`` messages to ``execute_command``.

        Messages on any other topic are ignored.
        """
        prefix = f"{self.command_topic}/"
        async for message in self.mqtt_client.messages:
            topic = str(message.topic)
            if not topic.startswith(prefix):
                continue
            command_name = topic.removeprefix(prefix)
            # Pass the raw payload: execute_command decodes it and reports a
            # malformed payload to the requester instead of raising here.
            await self.execute_command(command_name, message.payload, message.properties)

    def _collect_task_methods(self):
        """Return the bound methods marked by ``task``.

        Raises:
            TypeError: if a marked method is not an ``async def``.
        """
        methods = []
        for attr_name in dir(self):
            attr = getattr(self, attr_name)
            if callable(attr) and getattr(attr, "_is_task", False):
                if not inspect.iscoroutinefunction(attr):
                    raise TypeError(
                        f"@task can only decorate async methods: {attr_name}"
                    )
                methods.append(attr)
        return methods

    @staticmethod
    async def _run_concurrently(coros):
        """Run ``coros`` as tasks until one fails, then cancel the rest.

        ``asyncio.gather`` does not cancel its siblings when one awaitable
        raises. Without this helper, each reconnect would leave the previous
        connection's task methods running and start duplicates.
        """
        tasks = [asyncio.create_task(coro) for coro in coros]
        try:
            await asyncio.gather(*tasks)
        finally:
            for t in tasks:
                t.cancel()
            # Wait for the cancellations to settle. Results are swallowed here
            # so the original exception (if any) is what propagates.
            await asyncio.gather(*tasks, return_exceptions=True)

    async def run(self):
        """Serve commands and ``task`` methods forever, reconnecting on MQTT errors.

        Each connection:

        1. Subscribes to ``<command_topic>/+``.
        2. Republishes the command metadata.
        3. Runs the command listener plus every ``task`` method concurrently.

        If any of them raises ``aiomqtt.MqttError``, the others are cancelled
        and the connection is retried after ``INTERVAL`` seconds. Other
        exceptions propagate.

        Raises:
            TypeError: if a method marked with ``task`` is not async; this is
                checked before connecting.
        """
        task_methods = self._collect_task_methods()
        while True:
            try:
                async with self.mqtt_client as client:
                    await client.subscribe(f"{self.command_topic}/+")
                    await self.publish_commands()

                    await self._run_concurrently(
                        [self.mqtt_command_writer_task()]
                        + [method() for method in task_methods]
                    )

            except aiomqtt.MqttError as e:
                print(
                    f"[{self.handler_id}] MQTT connection error: {e}. Reconnecting in {INTERVAL}s..."
                )
                await asyncio.sleep(INTERVAL)
|
|
|
|
|
|
def task(func):
    """Flag an async method so ``MQTTHandler.run()`` starts it on each connection.

    The same function object is returned; the only change is a ``_is_task``
    attribute set to ``True``. Whether ``func`` really is a coroutine
    function is checked later, when ``MQTTHandler.run()`` collects tasks.
    """
    setattr(func, "_is_task", True)
    return func
|