commands and responses now all using decorators to look really nice

This commit is contained in:
Jono Targett 2026-03-15 17:46:05 +10:30
parent bec35f396f
commit b7bd9af807
6 changed files with 253 additions and 146 deletions

View File

@ -1,145 +0,0 @@
#! /usr/bin/env python3
import asyncio
import uuid
import aioserial
import aiomqtt
import pyubx2
import json
from paho.mqtt.properties import Properties
from paho.mqtt.packettypes import PacketTypes
import paho.mqtt.client as mqtt
from command import Command
BAUD = 115200
class SerialMQTTHandler:
def __init__(
self,
handler_id: str,
serial_port: str,
mqtt_host: str = "127.0.0.1",
mqtt_port: int = 1883,
):
self.handler_id = handler_id
self.serial_port = serial_port
self.mqtt_host = mqtt_host
self.mqtt_port = mqtt_port
# Serial
self.ser = aioserial.AioSerial(
port=serial_port,
baudrate=BAUD,
timeout=0.05, # 50 ms
)
# MQTT client
self.client_id = f"{handler_id}-{uuid.uuid4()}"
self.mqtt_client = aiomqtt.Client(
mqtt_host, port=mqtt_port, identifier=self.client_id, protocol=mqtt.MQTTv5
)
# Topic base
self.topic_base = f"asset/{handler_id}"
self.command_topic = f"{self.topic_base}/command"
# Add an arbitrary command
self.commands = {}
example_command = Command(
"example-cmd", {"type": "number"}, description="An example command"
)
self.commands[example_command.name] = example_command
async def parse_serial(self):
buffer = bytearray()
class StreamWrapper:
def read(inner_self, n=1):
if not buffer:
raise BlockingIOError
out = buffer[:n]
del buffer[:n]
return bytes(out)
ubr = pyubx2.UBXReader(StreamWrapper(), parsing=True)
while True:
chunk = await self.ser.read_async(200)
if chunk:
buffer.extend(chunk)
try:
while True:
raw, parsed = ubr.read()
if raw is None:
break
yield raw, parsed
except (pyubx2.UBXStreamError, BlockingIOError, Exception):
pass
else:
await asyncio.sleep(0)
async def serial_reader_task(self):
async for raw, parsed in self.parse_serial():
if isinstance(parsed, pyubx2.UBXMessage):
for name, value in vars(parsed).items():
if name.startswith("_"):
continue
topic = f"{self.topic_base}/{parsed.identity}/{name}"
await self.mqtt_client.publish(topic, value, qos=1, retain=True)
async def mqtt_command_writer_task(self):
async for message in self.mqtt_client.messages:
topic = str(message.topic)
payload = message.payload.decode("utf-8")
payload = json.loads(payload)
if topic.startswith(self.command_topic):
#await self.ser.write_async(message.payload)
command_name = topic.removeprefix(f"{self.command_topic}/")
command = self.commands.get(command_name)
if command is not None:
print(topic, payload, "valid:", command.validate(payload), message.properties)
if message.properties is not None and message.properties.ResponseTopic is not None:
await self.mqtt_client.publish(message.properties.ResponseTopic, message.payload, qos=1)
else:
print("Unknown command:", topic, message.payload)
async def run(self):
interval = 5
while True:
try:
async with self.mqtt_client as client:
await client.subscribe(f"{self.command_topic}/+")
for command in self.commands.values():
props = Properties(PacketTypes.PUBLISH)
props.ResponseTopic = "asset/client/response"
props.CorrelationData = b"req-42"
await self.mqtt_client.publish(f"{self.command_topic}/{command.name}/schema", str(command.schema), qos=1, retain=True, properties=props)
for k,v in command.additional_properties.items():
await self.mqtt_client.publish(f"{self.command_topic}/{command.name}/{k}", str(v), qos=1, retain=True)
await asyncio.gather(
self.serial_reader_task(),
self.mqtt_command_writer_task(),
)
except aiomqtt.MqttError as e:
print(
f"[{self.handler_id}] MQTT connection error: {e}. Reconnecting in {interval}s..."
)
await asyncio.sleep(interval)
if __name__ == "__main__":
async def main():
handler = SerialMQTTHandler("example-gps", "/tmp/ttyV0")
await handler.run()
asyncio.run(main())

View File

@ -1,10 +1,12 @@
import jsonschema import jsonschema
from dataclasses import dataclass
class Command: class Command:
def __init__(self, name: str, schema, **kwargs): def __init__(self, name: str, schema, handler=None, **kwargs):
self.name = name self.name = name
self.schema = schema self.schema = schema
self.handler = handler
self.additional_properties = kwargs self.additional_properties = kwargs
self._validator = jsonschema.validators.validator_for( self._validator = jsonschema.validators.validator_for(
@ -13,3 +15,24 @@ class Command:
def validate(self, o) -> bool: def validate(self, o) -> bool:
return self._validator.is_valid(o) return self._validator.is_valid(o)
def __call__(self, args):
if not self.validate(args):
raise ValueError(f"Invalid arguments for command '{self.name}'")
if self.handler is None:
raise RuntimeError(f"No handler bound for command '{self.name}'")
return self.handler(args)
def command(schema, **kwargs):
def decorator(func):
return Command(func.__name__, schema, handler=func, **kwargs)
return decorator
@dataclass
class CommandResponse:
success: bool
message: str = None
correlation: str = None

130
handler.py Normal file
View File

@ -0,0 +1,130 @@
import asyncio
import aiomqtt
import json
from command import Command, CommandResponse
import inspect
from paho.mqtt.properties import Properties
from dataclasses import asdict
BAUD = 115200
INTERVAL = 5
class MQTTHandler:
def __init__(
self,
mqtt_client: aiomqtt.Client,
handler_id: str,
):
self.handler_id = handler_id
self.mqtt_client = mqtt_client
self.topic_base = f"asset/{handler_id}"
self.command_topic = f"{self.topic_base}/command"
self.property_topic = f"{self.topic_base}/property"
def get_available_commands(self):
commands = {}
for base in self.__class__.__mro__:
for name, attr in vars(base).items():
if isinstance(attr, Command):
commands[name] = attr
return commands
async def publish_commands(self):
for name, command in self.get_available_commands().items():
await self.mqtt_client.publish(
f"{self.command_topic}/{command.name}/schema",
str(command.schema),
qos=1,
retain=True,
)
for k, v in command.additional_properties.items():
await self.mqtt_client.publish(
f"{self.command_topic}/{command.name}/{k}",
str(v),
qos=1,
retain=True,
)
async def execute_command(
self, command_name: str, payload: str, properties: Properties = None
):
async def respond(success: bool, message: str = None):
if (
properties is not None
and hasattr(properties, "ResponseTopic")
and properties.ResponseTopic is not None
):
correlation = (
properties.CorrelationData.decode("utf-8")
if hasattr(properties, "CorrelationData")
else None
)
await self.mqtt_client.publish(
properties.ResponseTopic,
json.dumps(
asdict(CommandResponse(success, str(message), correlation))
),
qos=1,
retain=False,
)
try:
command = self.get_available_commands()[command_name]
argument = json.loads(payload)
result = await command(argument)
await respond(True, result)
except json.decoder.JSONDecodeError as e:
await respond(False, f"Failed to parse payload as JSON: {e}")
except ValueError as e:
await respond(False, f"Command payload does not match expected schema: {e}")
except Exception as e:
print(f"Failed to execute command {command_name} with unknown cause: ", e)
await respond(False, "Unexpected error")
async def mqtt_command_writer_task(self):
async for message in self.mqtt_client.messages:
topic = str(message.topic)
payload = message.payload.decode("utf-8")
if topic.startswith(self.command_topic):
command_name = topic.removeprefix(f"{self.command_topic}/")
await self.execute_command(command_name, payload, message.properties)
async def run(self):
while True:
try:
async with self.mqtt_client as client:
await client.subscribe(f"{self.command_topic}/+")
await self.publish_commands()
tasks = [self.mqtt_command_writer_task()]
# Inspect instance methods
for attr_name in dir(self):
attr = getattr(self, attr_name)
if callable(attr) and getattr(attr, "_is_task", False):
if not inspect.iscoroutinefunction(attr):
raise TypeError(
f"@task can only decorate async methods: {attr_name}"
)
tasks.append(attr()) # call it, returns coroutine
await asyncio.gather(*tasks)
except aiomqtt.MqttError as e:
print(
f"[{self.handler_id}] MQTT connection error: {e}. Reconnecting in {INTERVAL}s..."
)
await asyncio.sleep(INTERVAL)
def task(func):
"""Decorator to mark async methods for automatic gathering."""
func._is_task = True
return func

37
main.py Executable file
View File

@ -0,0 +1,37 @@
#! /usr/bin/env python3
import aiomqtt
import aioserial
import asyncio
import paho
import uuid
from ubxhandler import UBXHandler
BAUD = 115200
async def main():
handler_id = "example-gps"
mqtt_host = "127.0.0.1"
mqtt_port = 1883
client_id = f"{handler_id}-{uuid.uuid4()}"
mqtt_client = aiomqtt.Client(
mqtt_host,
port=mqtt_port,
identifier=client_id,
protocol=paho.mqtt.client.MQTTv5,
)
serial_port = aioserial.AioSerial(
port="/tmp/ttyV0",
baudrate=BAUD,
timeout=0.05, # 50 ms
)
handler = UBXHandler(mqtt_client, "example-gps", serial_port)
await handler.run()
if __name__ == "__main__":
asyncio.run(main())

62
ubxhandler.py Normal file
View File

@ -0,0 +1,62 @@
import asyncio
import aioserial
import aiomqtt
import pyubx2
from command import command
from handler import MQTTHandler, task
class UBXHandler(MQTTHandler):
def __init__(
self,
mqtt_client: aiomqtt.Client,
handler_id: str,
serial_port: aioserial.AioSerial
):
super().__init__(mqtt_client, handler_id)
self.serial_port = serial_port
@command({"type": "number"}, description="An example command")
async def example_cmd(args):
print(f"Executing command with args {args}")
async def parse_serial(self):
buffer = bytearray()
class StreamWrapper:
def read(inner_self, n=1):
if not buffer:
raise BlockingIOError
out = buffer[:n]
del buffer[:n]
return bytes(out)
ubr = pyubx2.UBXReader(StreamWrapper(), parsing=True)
while True:
chunk = await self.serial_port.read_async(200)
if chunk:
buffer.extend(chunk)
try:
while True:
raw, parsed = ubr.read()
if raw is None:
break
yield raw, parsed
except (pyubx2.UBXStreamError, BlockingIOError, Exception):
pass
else:
await asyncio.sleep(0)
@task
async def serial_reader_task(self):
async for raw, parsed in self.parse_serial():
if isinstance(parsed, pyubx2.UBXMessage):
for name, value in vars(parsed).items():
if name.startswith("_"):
continue
topic = f"{self.topic_base}/{parsed.identity}/{name}"
await self.mqtt_client.publish(topic, value, qos=1, retain=True)