commands and responses now all using decorators to look really nice
This commit is contained in:
parent
bec35f396f
commit
b7bd9af807
145
assethandler.py
145
assethandler.py
@ -1,145 +0,0 @@
|
||||
#! /usr/bin/env python3
|
||||
import asyncio
|
||||
import uuid
|
||||
import aioserial
|
||||
import aiomqtt
|
||||
import pyubx2
|
||||
import json
|
||||
from paho.mqtt.properties import Properties
|
||||
from paho.mqtt.packettypes import PacketTypes
|
||||
import paho.mqtt.client as mqtt
|
||||
|
||||
from command import Command
|
||||
|
||||
BAUD = 115200
|
||||
|
||||
|
||||
class SerialMQTTHandler:
|
||||
def __init__(
|
||||
self,
|
||||
handler_id: str,
|
||||
serial_port: str,
|
||||
mqtt_host: str = "127.0.0.1",
|
||||
mqtt_port: int = 1883,
|
||||
):
|
||||
self.handler_id = handler_id
|
||||
self.serial_port = serial_port
|
||||
self.mqtt_host = mqtt_host
|
||||
self.mqtt_port = mqtt_port
|
||||
|
||||
# Serial
|
||||
self.ser = aioserial.AioSerial(
|
||||
port=serial_port,
|
||||
baudrate=BAUD,
|
||||
timeout=0.05, # 50 ms
|
||||
)
|
||||
|
||||
# MQTT client
|
||||
self.client_id = f"{handler_id}-{uuid.uuid4()}"
|
||||
self.mqtt_client = aiomqtt.Client(
|
||||
mqtt_host, port=mqtt_port, identifier=self.client_id, protocol=mqtt.MQTTv5
|
||||
)
|
||||
|
||||
# Topic base
|
||||
self.topic_base = f"asset/{handler_id}"
|
||||
self.command_topic = f"{self.topic_base}/command"
|
||||
|
||||
# Add an arbitrary command
|
||||
self.commands = {}
|
||||
|
||||
example_command = Command(
|
||||
"example-cmd", {"type": "number"}, description="An example command"
|
||||
)
|
||||
self.commands[example_command.name] = example_command
|
||||
|
||||
async def parse_serial(self):
|
||||
buffer = bytearray()
|
||||
|
||||
class StreamWrapper:
|
||||
def read(inner_self, n=1):
|
||||
if not buffer:
|
||||
raise BlockingIOError
|
||||
out = buffer[:n]
|
||||
del buffer[:n]
|
||||
return bytes(out)
|
||||
|
||||
ubr = pyubx2.UBXReader(StreamWrapper(), parsing=True)
|
||||
|
||||
while True:
|
||||
chunk = await self.ser.read_async(200)
|
||||
if chunk:
|
||||
buffer.extend(chunk)
|
||||
|
||||
try:
|
||||
while True:
|
||||
raw, parsed = ubr.read()
|
||||
if raw is None:
|
||||
break
|
||||
yield raw, parsed
|
||||
except (pyubx2.UBXStreamError, BlockingIOError, Exception):
|
||||
pass
|
||||
else:
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def serial_reader_task(self):
|
||||
async for raw, parsed in self.parse_serial():
|
||||
if isinstance(parsed, pyubx2.UBXMessage):
|
||||
for name, value in vars(parsed).items():
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
topic = f"{self.topic_base}/{parsed.identity}/{name}"
|
||||
await self.mqtt_client.publish(topic, value, qos=1, retain=True)
|
||||
|
||||
async def mqtt_command_writer_task(self):
|
||||
async for message in self.mqtt_client.messages:
|
||||
topic = str(message.topic)
|
||||
payload = message.payload.decode("utf-8")
|
||||
payload = json.loads(payload)
|
||||
if topic.startswith(self.command_topic):
|
||||
#await self.ser.write_async(message.payload)
|
||||
command_name = topic.removeprefix(f"{self.command_topic}/")
|
||||
command = self.commands.get(command_name)
|
||||
if command is not None:
|
||||
print(topic, payload, "valid:", command.validate(payload), message.properties)
|
||||
|
||||
if message.properties is not None and message.properties.ResponseTopic is not None:
|
||||
await self.mqtt_client.publish(message.properties.ResponseTopic, message.payload, qos=1)
|
||||
else:
|
||||
print("Unknown command:", topic, message.payload)
|
||||
|
||||
async def run(self):
|
||||
interval = 5
|
||||
|
||||
while True:
|
||||
try:
|
||||
async with self.mqtt_client as client:
|
||||
await client.subscribe(f"{self.command_topic}/+")
|
||||
|
||||
for command in self.commands.values():
|
||||
props = Properties(PacketTypes.PUBLISH)
|
||||
props.ResponseTopic = "asset/client/response"
|
||||
props.CorrelationData = b"req-42"
|
||||
|
||||
await self.mqtt_client.publish(f"{self.command_topic}/{command.name}/schema", str(command.schema), qos=1, retain=True, properties=props)
|
||||
for k,v in command.additional_properties.items():
|
||||
await self.mqtt_client.publish(f"{self.command_topic}/{command.name}/{k}", str(v), qos=1, retain=True)
|
||||
|
||||
await asyncio.gather(
|
||||
self.serial_reader_task(),
|
||||
self.mqtt_command_writer_task(),
|
||||
)
|
||||
|
||||
except aiomqtt.MqttError as e:
|
||||
print(
|
||||
f"[{self.handler_id}] MQTT connection error: {e}. Reconnecting in {interval}s..."
|
||||
)
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
async def main():
|
||||
handler = SerialMQTTHandler("example-gps", "/tmp/ttyV0")
|
||||
await handler.run()
|
||||
|
||||
asyncio.run(main())
|
||||
25
command.py
25
command.py
@ -1,10 +1,12 @@
|
||||
import jsonschema
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class Command:
|
||||
def __init__(self, name: str, schema, **kwargs):
|
||||
def __init__(self, name: str, schema, handler=None, **kwargs):
|
||||
self.name = name
|
||||
self.schema = schema
|
||||
self.handler = handler
|
||||
self.additional_properties = kwargs
|
||||
|
||||
self._validator = jsonschema.validators.validator_for(
|
||||
@ -13,3 +15,24 @@ class Command:
|
||||
|
||||
def validate(self, o) -> bool:
|
||||
return self._validator.is_valid(o)
|
||||
|
||||
def __call__(self, args):
|
||||
if not self.validate(args):
|
||||
raise ValueError(f"Invalid arguments for command '{self.name}'")
|
||||
if self.handler is None:
|
||||
raise RuntimeError(f"No handler bound for command '{self.name}'")
|
||||
return self.handler(args)
|
||||
|
||||
|
||||
def command(schema, **kwargs):
|
||||
def decorator(func):
|
||||
return Command(func.__name__, schema, handler=func, **kwargs)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommandResponse:
|
||||
success: bool
|
||||
message: str = None
|
||||
correlation: str = None
|
||||
|
||||
130
handler.py
Normal file
130
handler.py
Normal file
@ -0,0 +1,130 @@
|
||||
import asyncio
|
||||
import aiomqtt
|
||||
import json
|
||||
from command import Command, CommandResponse
|
||||
import inspect
|
||||
from paho.mqtt.properties import Properties
|
||||
from dataclasses import asdict
|
||||
|
||||
BAUD = 115200
|
||||
INTERVAL = 5
|
||||
|
||||
|
||||
class MQTTHandler:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mqtt_client: aiomqtt.Client,
|
||||
handler_id: str,
|
||||
):
|
||||
self.handler_id = handler_id
|
||||
self.mqtt_client = mqtt_client
|
||||
|
||||
self.topic_base = f"asset/{handler_id}"
|
||||
self.command_topic = f"{self.topic_base}/command"
|
||||
self.property_topic = f"{self.topic_base}/property"
|
||||
|
||||
def get_available_commands(self):
|
||||
commands = {}
|
||||
for base in self.__class__.__mro__:
|
||||
for name, attr in vars(base).items():
|
||||
if isinstance(attr, Command):
|
||||
commands[name] = attr
|
||||
|
||||
return commands
|
||||
|
||||
async def publish_commands(self):
|
||||
for name, command in self.get_available_commands().items():
|
||||
await self.mqtt_client.publish(
|
||||
f"{self.command_topic}/{command.name}/schema",
|
||||
str(command.schema),
|
||||
qos=1,
|
||||
retain=True,
|
||||
)
|
||||
for k, v in command.additional_properties.items():
|
||||
await self.mqtt_client.publish(
|
||||
f"{self.command_topic}/{command.name}/{k}",
|
||||
str(v),
|
||||
qos=1,
|
||||
retain=True,
|
||||
)
|
||||
|
||||
async def execute_command(
|
||||
self, command_name: str, payload: str, properties: Properties = None
|
||||
):
|
||||
|
||||
async def respond(success: bool, message: str = None):
|
||||
if (
|
||||
properties is not None
|
||||
and hasattr(properties, "ResponseTopic")
|
||||
and properties.ResponseTopic is not None
|
||||
):
|
||||
correlation = (
|
||||
properties.CorrelationData.decode("utf-8")
|
||||
if hasattr(properties, "CorrelationData")
|
||||
else None
|
||||
)
|
||||
await self.mqtt_client.publish(
|
||||
properties.ResponseTopic,
|
||||
json.dumps(
|
||||
asdict(CommandResponse(success, str(message), correlation))
|
||||
),
|
||||
qos=1,
|
||||
retain=False,
|
||||
)
|
||||
|
||||
try:
|
||||
command = self.get_available_commands()[command_name]
|
||||
argument = json.loads(payload)
|
||||
result = await command(argument)
|
||||
await respond(True, result)
|
||||
|
||||
except json.decoder.JSONDecodeError as e:
|
||||
await respond(False, f"Failed to parse payload as JSON: {e}")
|
||||
except ValueError as e:
|
||||
await respond(False, f"Command payload does not match expected schema: {e}")
|
||||
except Exception as e:
|
||||
print(f"Failed to execute command {command_name} with unknown cause: ", e)
|
||||
await respond(False, "Unexpected error")
|
||||
|
||||
async def mqtt_command_writer_task(self):
|
||||
async for message in self.mqtt_client.messages:
|
||||
topic = str(message.topic)
|
||||
payload = message.payload.decode("utf-8")
|
||||
|
||||
if topic.startswith(self.command_topic):
|
||||
command_name = topic.removeprefix(f"{self.command_topic}/")
|
||||
await self.execute_command(command_name, payload, message.properties)
|
||||
|
||||
async def run(self):
|
||||
while True:
|
||||
try:
|
||||
async with self.mqtt_client as client:
|
||||
await client.subscribe(f"{self.command_topic}/+")
|
||||
await self.publish_commands()
|
||||
|
||||
tasks = [self.mqtt_command_writer_task()]
|
||||
|
||||
# Inspect instance methods
|
||||
for attr_name in dir(self):
|
||||
attr = getattr(self, attr_name)
|
||||
if callable(attr) and getattr(attr, "_is_task", False):
|
||||
if not inspect.iscoroutinefunction(attr):
|
||||
raise TypeError(
|
||||
f"@task can only decorate async methods: {attr_name}"
|
||||
)
|
||||
tasks.append(attr()) # call it, returns coroutine
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
except aiomqtt.MqttError as e:
|
||||
print(
|
||||
f"[{self.handler_id}] MQTT connection error: {e}. Reconnecting in {INTERVAL}s..."
|
||||
)
|
||||
await asyncio.sleep(INTERVAL)
|
||||
|
||||
|
||||
def task(func):
|
||||
"""Decorator to mark async methods for automatic gathering."""
|
||||
func._is_task = True
|
||||
return func
|
||||
37
main.py
Executable file
37
main.py
Executable file
@ -0,0 +1,37 @@
|
||||
#! /usr/bin/env python3
|
||||
|
||||
import aiomqtt
|
||||
import aioserial
|
||||
import asyncio
|
||||
import paho
|
||||
import uuid
|
||||
|
||||
from ubxhandler import UBXHandler
|
||||
|
||||
BAUD = 115200
|
||||
|
||||
async def main():
|
||||
handler_id = "example-gps"
|
||||
mqtt_host = "127.0.0.1"
|
||||
mqtt_port = 1883
|
||||
|
||||
client_id = f"{handler_id}-{uuid.uuid4()}"
|
||||
mqtt_client = aiomqtt.Client(
|
||||
mqtt_host,
|
||||
port=mqtt_port,
|
||||
identifier=client_id,
|
||||
protocol=paho.mqtt.client.MQTTv5,
|
||||
)
|
||||
|
||||
serial_port = aioserial.AioSerial(
|
||||
port="/tmp/ttyV0",
|
||||
baudrate=BAUD,
|
||||
timeout=0.05, # 50 ms
|
||||
)
|
||||
|
||||
handler = UBXHandler(mqtt_client, "example-gps", serial_port)
|
||||
await handler.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
62
ubxhandler.py
Normal file
62
ubxhandler.py
Normal file
@ -0,0 +1,62 @@
|
||||
import asyncio
|
||||
import aioserial
|
||||
import aiomqtt
|
||||
import pyubx2
|
||||
|
||||
from command import command
|
||||
from handler import MQTTHandler, task
|
||||
|
||||
|
||||
class UBXHandler(MQTTHandler):
|
||||
def __init__(
|
||||
self,
|
||||
mqtt_client: aiomqtt.Client,
|
||||
handler_id: str,
|
||||
serial_port: aioserial.AioSerial
|
||||
):
|
||||
super().__init__(mqtt_client, handler_id)
|
||||
|
||||
self.serial_port = serial_port
|
||||
|
||||
@command({"type": "number"}, description="An example command")
|
||||
async def example_cmd(args):
|
||||
print(f"Executing command with args {args}")
|
||||
|
||||
async def parse_serial(self):
|
||||
buffer = bytearray()
|
||||
|
||||
class StreamWrapper:
|
||||
def read(inner_self, n=1):
|
||||
if not buffer:
|
||||
raise BlockingIOError
|
||||
out = buffer[:n]
|
||||
del buffer[:n]
|
||||
return bytes(out)
|
||||
|
||||
ubr = pyubx2.UBXReader(StreamWrapper(), parsing=True)
|
||||
|
||||
while True:
|
||||
chunk = await self.serial_port.read_async(200)
|
||||
if chunk:
|
||||
buffer.extend(chunk)
|
||||
|
||||
try:
|
||||
while True:
|
||||
raw, parsed = ubr.read()
|
||||
if raw is None:
|
||||
break
|
||||
yield raw, parsed
|
||||
except (pyubx2.UBXStreamError, BlockingIOError, Exception):
|
||||
pass
|
||||
else:
|
||||
await asyncio.sleep(0)
|
||||
|
||||
@task
|
||||
async def serial_reader_task(self):
|
||||
async for raw, parsed in self.parse_serial():
|
||||
if isinstance(parsed, pyubx2.UBXMessage):
|
||||
for name, value in vars(parsed).items():
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
topic = f"{self.topic_base}/{parsed.identity}/{name}"
|
||||
await self.mqtt_client.publish(topic, value, qos=1, retain=True)
|
||||
Loading…
Reference in New Issue
Block a user