Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

* Added `retain` option to `publish()` method for transports that allow retaining messages for new subscribers.
* Added support for MQTT over secure WebSockets using `transport="websockets"` and `tls_set()`.

### Changed

### Removed
Expand Down Expand Up @@ -236,4 +239,3 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

### Removed

1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ select = ["E", "F", "I"]

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["I001"]
"tests/*" = ["I001"]
"tasks.py" = ["I001"]

[tool.pytest.ini_options]
Expand Down
10 changes: 6 additions & 4 deletions src/compas_eve/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Type
from typing import Union


DEFAULT_TRANSPORT = None


Expand Down Expand Up @@ -59,7 +58,7 @@ def id_counter(self) -> int:
self._id_counter += 1
return self._id_counter

def publish(self, topic: "Topic", message: Union["Message", dict]) -> None:
def publish(self, topic: "Topic", message: Union["Message", dict], **options: Any) -> None:
pass

def subscribe(self, topic: "Topic", callback: Callable) -> Optional[str]:
Expand Down Expand Up @@ -173,19 +172,22 @@ def message_published(self, message: Union[Message, dict]) -> None:
"""Handler called when a message has been published."""
pass

def publish(self, message: Union[Message, dict]) -> None:
def publish(self, message: Union[Message, dict], **options: Any) -> None:
"""Publish a message to the topic.

Parameters
----------
message
The message to publish.
**options
Transport-specific options passed through to the underlying transport.
For example, ``retain=True`` on MQTT and InMemory transports.
"""
# TODO: check if message type matches self.topic.message_type declared
if not self.is_advertised:
self.advertise()

self.transport.publish(self.topic, message)
self.transport.publish(self.topic, message, **options)
self.message_published(message)

def advertise(self) -> None:
Expand Down
13 changes: 12 additions & 1 deletion src/compas_eve/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@ class InMemoryTransport(Transport, EventEmitterMixin):
def __init__(self, codec: Optional[MessageCodec] = None, *args, **kwargs):
super(InMemoryTransport, self).__init__(codec=codec, *args, **kwargs)
self._local_callbacks = {}
self._retained = {}

def on_ready(self, callback: Callable):
"""In-memory transport is always ready, it will immediately trigger the callback."""
callback()

def publish(self, topic: Topic, message: Message):
def publish(self, topic: Topic, message: Message, **options):
"""Publish a message to a topic.

Parameters
Expand All @@ -38,12 +39,20 @@ def publish(self, topic: Topic, message: Message):
Instance of the topic to publish to.
message
Instance of the message to publish.
retain : bool, optional
If True, the last message on this topic is stored and delivered
immediately to any new subscriber. Defaults to False.
"""
retain = options.pop("retain", False)
if options:
raise TypeError("publish() got unexpected options for InMemoryTransport: {}".format(", ".join(options)))
event_key = "event:{}".format(topic.name)

def _callback(**kwargs):
encoded_message = self.codec.encode(message)
encoded_message_bytes = encoded_message if isinstance(encoded_message, bytes) else encoded_message.encode("utf-8")
if retain:
self._retained[topic.name] = encoded_message_bytes
self.emit(event_key, encoded_message_bytes)

self.on_ready(_callback)
Expand Down Expand Up @@ -75,6 +84,8 @@ def _local_callback(msg):

def _callback(**kwargs):
self.on(event_key, _local_callback)
if topic.name in self._retained:
_local_callback(self._retained[topic.name])

self._local_callbacks[subscribe_id] = _local_callback

Expand Down
38 changes: 33 additions & 5 deletions src/compas_eve/mqtt/mqtt_paho.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import uuid
from typing import Any
from typing import Callable
from typing import Dict
from typing import Optional

import paho.mqtt.client as mqtt
Expand Down Expand Up @@ -30,12 +32,30 @@ class MqttTransport(Transport, EventEmitterMixin):
MQTT broker port, defaults to `1883`.
client_id
Client ID for the MQTT connection. If not provided, a unique ID will be generated.
transport
Paho MQTT transport to use. Defaults to `"tcp"`. Use `"websockets"` for MQTT over WebSockets.
tls
If True, enables TLS by calling `client.tls_set()` before connecting.
tls_options
Optional keyword arguments for `client.tls_set()`, e.g. `ca_certs`, `certfile`,
`keyfile`, `cert_reqs`, `tls_version`, or `ciphers`. Providing this also enables TLS.
codec
The codec to use for encoding and decoding messages.
If not provided, defaults to [JsonMessageCodec][compas_eve.codecs.JsonMessageCodec].
"""

def __init__(self, host: str, port: int = 1883, client_id: Optional[str] = None, codec: Optional[MessageCodec] = None, *args, **kwargs):
def __init__(
self,
host: str,
port: int = 1883,
client_id: Optional[str] = None,
codec: Optional[MessageCodec] = None,
transport: str = "tcp",
tls: bool = False,
tls_options: Optional[Dict[str, Any]] = None,
*args,
**kwargs,
):
super(MqttTransport, self).__init__(codec=codec, *args, **kwargs)
self.host = host
self.port = port
Expand All @@ -45,10 +65,12 @@ def __init__(self, host: str, port: int = 1883, client_id: Optional[str] = None,
if client_id is None:
client_id = "compas_eve_{}".format(uuid.uuid4().hex[:8])
if PAHO_MQTT_V2_AVAILABLE:
self.client = mqtt.Client(client_id=client_id, callback_api_version=CallbackAPIVersion.VERSION1)
self.client = mqtt.Client(client_id=client_id, callback_api_version=CallbackAPIVersion.VERSION1, transport=transport)
else:
self.client = mqtt.Client(client_id=client_id)
self.client = mqtt.Client(client_id=client_id, transport=transport)
self.client.on_connect = self._on_connect
if tls or tls_options is not None:
self.client.tls_set(**(tls_options or {}))
self.client.connect(self.host, self.port)
self.client.loop_start()

Expand All @@ -73,7 +95,7 @@ def on_ready(self, callback: Callable):
else:
self.once("ready", callback)

def publish(self, topic: Topic, message: Message):
def publish(self, topic: Topic, message: Message, **options):
"""Publish a message to a topic.

Parameters
Expand All @@ -82,11 +104,17 @@ def publish(self, topic: Topic, message: Message):
Instance of the topic to publish to.
message
Instance of the message to publish.
retain : bool, optional
If True, the broker retains the last message on this topic and
delivers it immediately to any new subscriber. Defaults to False.
"""
retain = options.pop("retain", False)
if options:
raise TypeError("publish() got unexpected options for MqttTransport: {}".format(", ".join(options)))

def _callback(**kwargs):
encoded_message = self.codec.encode(message)
self.client.publish(topic.name, encoded_message)
self.client.publish(topic.name, encoded_message, retain=retain)

self.on_ready(_callback)

Expand Down
16 changes: 10 additions & 6 deletions src/compas_eve/zenoh/zenoh_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def on_ready(self, callback: Callable) -> None:
else:
self.once("ready", callback)

def publish(self, topic: Topic, message: Message) -> None:
def publish(self, topic: Topic, message: Message, **options: Any) -> None:
"""Publish a message to a topic.

Parameters
Expand All @@ -74,13 +74,16 @@ def publish(self, topic: Topic, message: Message) -> None:
message
Instance of the message to publish.
"""
if options:
raise TypeError("publish() got unexpected options for ZenohTransport: {}".format(", ".join(options)))

def _callback(**kwargs: Any) -> None:
if self._get_topic_name(topic) not in self._publishers:
self._publishers[self._get_topic_name(topic)] = self.session.declare_publisher(self._get_topic_name(topic))
topic_name = self._get_topic_name(topic)
if topic_name not in self._publishers:
self._publishers[topic_name] = self.session.declare_publisher(topic_name)

encoded_message = self.codec.encode(message)
self._publishers[self._get_topic_name(topic)].put(encoded_message)
self._publishers[topic_name].put(encoded_message)

self.on_ready(_callback)

Expand Down Expand Up @@ -111,8 +114,9 @@ def _zenoh_handler(sample: Any) -> None:
self.emit(event_key, message_obj)

def _subscribe_callback(**kwargs: Any) -> None:
if self._get_topic_name(topic) not in self._subscribers:
self._subscribers[self._get_topic_name(topic)] = self.session.declare_subscriber(self._get_topic_name(topic), _zenoh_handler)
topic_name = self._get_topic_name(topic)
if topic_name not in self._subscribers:
self._subscribers[topic_name] = self.session.declare_subscriber(topic_name, _zenoh_handler)

self.on(event_key, _local_callback)

Expand Down
36 changes: 36 additions & 0 deletions tests/integration/test_transports.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@
HOST = "localhost"


@pytest.fixture
def mqtt_tx():
tx = MqttTransport(HOST)
yield tx
tx.close()


@pytest.fixture(params=["mqtt", "zenoh"])
def tx(request):
if request.param == "mqtt":
Expand Down Expand Up @@ -217,3 +224,32 @@ def callback(msg):
assert received, "Message not received"
assert result["value"].name == "Jazz"
assert result["value"]["name"] == "Jazz", "Messages should be accessible as dict"


def test_mqtt_retain_delivers_to_late_subscriber(mqtt_tx):
topic = Topic("/messages_compas_eve_test/test_retain/", Message)

pub = Publisher(topic, transport=mqtt_tx)
pub.publish(Message(value=42), retain=True)
time.sleep(0.2)

result = dict(value=None, event=Event())

def callback(msg):
result["value"] = msg.value
result["event"].set()

Subscriber(topic, callback, transport=mqtt_tx).subscribe()

received = result["event"].wait(timeout=3)
assert received, "Retained message not delivered to late subscriber"
assert result["value"] == 42

# Clean up: publish empty retained message to clear broker state
pub.publish(Message(), retain=True)


def test_mqtt_unknown_option_raises(mqtt_tx):
topic = Topic("/messages_compas_eve_test/test_bad_option/", Message)
with pytest.raises(TypeError):
Publisher(topic, transport=mqtt_tx).publish(Message(value=1), unknown_flag=True)
1 change: 1 addition & 0 deletions tests/unit/test_codecs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from compas.geometry import Frame

from compas_eve import Message
from compas_eve.codecs import JsonMessageCodec
from compas_eve.codecs import ProtobufMessageCodec
Expand Down
63 changes: 63 additions & 0 deletions tests/unit/test_core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from threading import Event

import pytest

from compas_eve import InMemoryTransport
from compas_eve import Message
from compas_eve import Publisher
Expand Down Expand Up @@ -104,3 +106,64 @@ def callback(msg):
def test_message_str():
msg = Message(a=3)
assert str(msg) == "{'a': 3}"


def test_retain_delivers_to_late_subscriber():
tx = InMemoryTransport()
topic = Topic("/messages_compas_eve_test/retain/", Message)

Publisher(topic, transport=tx).publish(Message(value=42), retain=True)

result = dict(value=None, event=Event())

def callback(msg):
result["value"] = msg.value
result["event"].set()

Subscriber(topic, callback, transport=tx).subscribe()

received = result["event"].wait(timeout=1)
assert received, "Retained message not delivered to late subscriber"
assert result["value"] == 42


def test_retain_last_message_wins():
tx = InMemoryTransport()
topic = Topic("/messages_compas_eve_test/retain_last/", Message)
pub = Publisher(topic, transport=tx)

pub.publish(Message(value=1), retain=True)
pub.publish(Message(value=2), retain=True)

result = dict(value=None, event=Event())

def callback(msg):
result["value"] = msg.value
result["event"].set()

Subscriber(topic, callback, transport=tx).subscribe()

received = result["event"].wait(timeout=1)
assert received, "Retained message not delivered"
assert result["value"] == 2


def test_no_retain_does_not_deliver_to_late_subscriber():
tx = InMemoryTransport()
topic = Topic("/messages_compas_eve_test/no_retain/", Message)

Publisher(topic, transport=tx).publish(Message(value=42))

event = Event()
Subscriber(topic, lambda m: event.set(), transport=tx).subscribe()

received = event.wait(timeout=0.2)
assert not received, "Non-retained message should not be delivered to late subscriber"


def test_unknown_option_raises():
tx = InMemoryTransport()
topic = Topic("/messages_compas_eve_test/bad_option/", Message)

with pytest.raises(TypeError):
Publisher(topic, transport=tx).publish(Message(value=1), unknown_flag=True)
Loading
Loading