import asyncio
import inspect
import json
import signal
from dataclasses import dataclass

import aiomqtt
import paho
import paho.mqtt.client
import paho.mqtt.properties

from command import (
    Command,
    CommandResponse,
    CommandArgumentError,
    CommandExecutionError,
)


@dataclass
class MQTTConfig:
    """Connection settings for the MQTT broker used by :class:`MQTTHandler`."""

    # Broker hostname or IP address.
    host: str
    # Broker TCP port.
    port: int = 1883
    # Optional credentials; ``None`` means no value is sent.
    username: str | None = None
    password: str | None = None
    # MQTT keepalive interval, in seconds.
    keepalive: int = 60


class MQTTHandler:
    """Expose ``Command`` attributes of a class over MQTT.

    Topics are derived from ``handler_id``:

    * ``device/<id>/command/<name>``        -- incoming command requests
    * ``device/<id>/command/<name>/schema`` -- retained command schema
    * ``device/<id>/property/status``       -- retained ``ONLINE``/``OFFLINE``

    Subclasses add commands as class attributes that are ``Command``
    instances and background jobs as ``async`` methods decorated with
    :func:`task`.
    """

    # Seconds to wait before reconnecting after an MQTT error. Kept as a class
    # attribute so subclasses can tune it.
    RECONNECT_INTERVAL = 5

    def __init__(
        self,
        mqtt_config: MQTTConfig,
        handler_id: str,
    ):
        """Build the topic names and the (not yet connected) MQTT client.

        Args:
            mqtt_config: Broker connection settings.
            handler_id: Used both as the MQTT client identifier and as the
                ``device/<handler_id>`` topic prefix.
        """
        self.handler_id = handler_id
        self.mqtt_config = mqtt_config

        self.topic_base = f"device/{handler_id}"
        self.command_topic = f"{self.topic_base}/command"
        self.property_topic = f"{self.topic_base}/property"
        self.status_topic = f"{self.property_topic}/status"

        self._shutdown_event = asyncio.Event()

        # Last-will message: the broker publishes OFFLINE on our behalf if the
        # connection is lost without a clean disconnect.
        will = aiomqtt.Will(
            topic=self.status_topic, payload="OFFLINE", qos=1, retain=True
        )

        self.mqtt_client = aiomqtt.Client(
            self.mqtt_config.host,
            port=self.mqtt_config.port,
            # Forward the remaining config fields; previously they were
            # accepted by MQTTConfig but never reached the client.
            username=self.mqtt_config.username,
            password=self.mqtt_config.password,
            keepalive=self.mqtt_config.keepalive,
            identifier=handler_id,
            protocol=paho.mqtt.client.MQTTv5,
            will=will,
        )

    def get_available_commands(self):
        """Return ``{attribute_name: Command}`` for this class and its bases.

        The MRO is walked from the most generic base to the most derived
        class so that a subclass overriding a command replaces the inherited
        one, matching normal attribute resolution.
        """
        commands = {}
        for base in reversed(type(self).__mro__):
            for name, attr in vars(base).items():
                if isinstance(attr, Command):
                    commands[name] = attr
        return commands

    async def publish_commands(self):
        """Publish every command's schema and additional properties (retained)."""
        for command in self.get_available_commands().values():
            command_base = f"{self.command_topic}/{command.name}"
            await self.mqtt_client.publish(
                f"{command_base}/schema",
                json.dumps(command.schema),
                qos=1,
                retain=True,
            )
            for key, value in command.additional_properties.items():
                await self.mqtt_client.publish(
                    f"{command_base}/{key}",
                    str(value),
                    qos=1,
                    retain=True,
                )

    async def execute_command(
        self,
        command_name: str,
        payload: str,
        properties: paho.mqtt.properties.Properties | None = None,
    ):
        """Run one command and, if requested, publish a response.

        Args:
            command_name: Attribute name of the command to run, as returned
                by :meth:`get_available_commands`.
            payload: Decoded request payload handed to the command.
            properties: MQTT v5 properties of the request. A response is only
                published when they carry a ``ResponseTopic``.

        Command failures are reported through the response and never raised;
        errors from publishing the response itself propagate to the caller.
        """

        async def respond(success: bool, message: str | None = None):
            # paho only sets the attributes that were present in the packet,
            # so both lookups must tolerate a missing attribute.
            response_topic = getattr(properties, "ResponseTopic", None)
            if response_topic is None:
                return

            correlation_data = getattr(properties, "CorrelationData", None)
            # Correlation data is arbitrary bytes chosen by the requester;
            # never let an undecodable value crash the handler.
            correlation = (
                correlation_data.decode("utf-8", errors="replace")
                if correlation_data is not None
                else None
            )

            await self.mqtt_client.publish(
                response_topic,
                str(CommandResponse(success, str(message), correlation)),
                qos=1,
                retain=False,
            )

        command = self.get_available_commands().get(command_name)
        if command is None:
            print(f"Received unknown command: {command_name}")
            await respond(False, f"Unknown command: {command_name}")
            return

        try:
            result = await command(self, payload)
        except (CommandArgumentError, CommandExecutionError) as e:
            await respond(False, f"{e}")
        except Exception as e:
            print(f"Failed to execute command {command_name} with unknown cause: ", e)
            await respond(False, "Unexpected error")
        else:
            # Kept outside the try body so that a failure to publish the
            # response is not misreported as a command failure.
            await respond(True, result)

    async def mqtt_command_writer_task(self):
        """Dispatch incoming messages below the command topic to commands."""
        prefix = f"{self.command_topic}/"
        async for message in self.mqtt_client.messages:
            topic = str(message.topic)
            if not topic.startswith(prefix):
                continue

            try:
                payload = message.payload.decode("utf-8")
            except UnicodeDecodeError as e:
                # Payloads come from remote clients; skip malformed ones
                # instead of letting them terminate the handler.
                print(f"Ignoring command on {topic} with undecodable payload: ", e)
                continue

            await self.execute_command(
                topic.removeprefix(prefix), payload, message.properties
            )

    async def shutdown_watcher(self):
        """Wait for :meth:`stop`, then publish the retained OFFLINE status."""
        await self._shutdown_event.wait()
        await self.mqtt_client.publish(self.status_topic, "OFFLINE", qos=1, retain=True)

    def stop(self):
        """Request shutdown of :meth:`run`.

        Also restores the default SIGINT handler, so a subsequent Ctrl-C is
        handled by the default handler rather than any custom one.
        """
        self._shutdown_event.set()
        signal.signal(signal.SIGINT, signal.SIG_DFL)

    def _collect_task_methods(self):
        """Return bound methods marked with :func:`task`, validating each one."""
        methods = []
        for attr_name in dir(self):
            attr = getattr(self, attr_name)
            if callable(attr) and getattr(attr, "_is_task", False):
                if not inspect.iscoroutinefunction(attr):
                    raise TypeError(
                        f"@task can only decorate async methods: {attr_name}"
                    )
                methods.append(attr)
        return methods

    async def _run_session_tasks(self, task_methods):
        """Run all per-connection tasks until shutdown or the first failure.

        Every task is cancelled and awaited before returning, so nothing
        from this session keeps running once the connection is torn down.
        """
        writer = asyncio.create_task(self.mqtt_command_writer_task())
        watcher = asyncio.create_task(self.shutdown_watcher())
        session_tasks = [writer, watcher]
        session_tasks.extend(asyncio.create_task(method()) for method in task_methods)

        try:
            pending = set(session_tasks)
            while pending and not watcher.done():
                done, pending = await asyncio.wait(
                    pending, return_when=asyncio.FIRST_COMPLETED
                )
                for finished in done:
                    # Re-raises the task's exception, if it failed.
                    finished.result()
        finally:
            for session_task in session_tasks:
                session_task.cancel()  # no-op for tasks that already finished
            # Collect results so no exception is left unretrieved.
            await asyncio.gather(*session_tasks, return_exceptions=True)

    async def run(self):
        """Connect, serve commands and reconnect on MQTT errors until stopped.

        Returns after :meth:`stop` has been called and the OFFLINE status has
        been published (or publishing it failed).

        Raises:
            TypeError: If :func:`task` decorates a non-async method. Raised
                before connecting to the broker.
        """
        # Validate once, before connecting, so a programming error does not
        # leave a stale retained ONLINE status behind.
        task_methods = self._collect_task_methods()
        INTERVAL = self.RECONNECT_INTERVAL

        while not self._shutdown_event.is_set():
            try:
                async with self.mqtt_client as client:
                    await client.subscribe(f"{self.command_topic}/+")
                    await self.publish_commands()

                    # announce that we are online
                    await client.publish(
                        self.status_topic, "ONLINE", qos=1, retain=True
                    )

                    await self._run_session_tasks(task_methods)
            except aiomqtt.MqttError as e:
                if self._shutdown_event.is_set():
                    print(
                        f"[{self.handler_id}] MQTT connection error during shutdown: {e}"
                    )
                    break
                print(
                    f"[{self.handler_id}] MQTT connection error: {e}. Reconnecting in {INTERVAL}s..."
                )
                # Sleep for the reconnect interval, but wake up early if
                # stop() is called in the meantime.
                try:
                    await asyncio.wait_for(
                        self._shutdown_event.wait(), timeout=INTERVAL
                    )
                except asyncio.TimeoutError:
                    pass


def task(func):
    """Decorator to mark async methods for automatic gathering."""
    func._is_task = True
    return func