diff --git a/src/labthings_fastapi/actions.py b/src/labthings_fastapi/actions.py index 888acfef..77b7dc46 100644 --- a/src/labthings_fastapi/actions.py +++ b/src/labthings_fastapi/actions.py @@ -37,11 +37,12 @@ TypeVar, overload, ) -from weakref import WeakSet import weakref from fastapi import APIRouter, FastAPI, HTTPException, Request, Body, BackgroundTasks from pydantic import BaseModel, create_model +from labthings_fastapi.message_broker import Message + from .middleware.url_for import URLFor from .base_descriptor import ( @@ -68,7 +69,6 @@ ) from .thing_description import type_to_dataschema from .thing_description._model import ActionAffordance, ActionOp, Form, LinkElement -from .utilities import labthings_data if TYPE_CHECKING: @@ -247,6 +247,20 @@ def response(self) -> InvocationModel: log=self.log, ) + def _publish_status(self) -> None: + """Publish a status change event to any observers. + + This should be called after each change to ``self._status`` + """ + self.thing._thing_server_interface.publish( + Message( + thing=self.thing.name, + affordance=self.action.name, # type: ignore[attr-defined] + message_type="action", + payload=self._status.value, + ) + ) + def run(self) -> None: """Run the action and track progress. @@ -282,7 +296,7 @@ def run(self) -> None: add_thing_log_destination(self.id, self._log) with invocation_contexts.set_invocation_id(self.id): try: - action.emit_changed_event(self.thing, self._status.value) + self._publish_status() thing = self.thing kwargs = model_to_dict(self.input) @@ -298,7 +312,7 @@ def run(self) -> None: with self._status_lock: self._status = InvocationStatus.RUNNING self._start_time = datetime.datetime.now() - action.emit_changed_event(self.thing, self._status.value) + self._publish_status() # Actually run the action ret = action.func(thing, **kwargs, **self.dependencies) @@ -306,13 +320,13 @@ def run(self) -> None: with self._status_lock: self._return_value = ret self._status = InvocationStatus.COMPLETED - action.emit_changed_event(self.thing, self._status.value) + self._publish_status() except InvocationCancelledError: logger.info(f"Invocation {self.id} was cancelled.") with self._status_lock: self._status = InvocationStatus.CANCELLED - action.emit_changed_event(self.thing, self._status.value) + self._publish_status() except Exception as e: # skipcq: PYL-W0703 # First log if isinstance(e, InvocationError): @@ -332,7 +346,7 @@ def run(self) -> None: with self._status_lock: self._status = InvocationStatus.ERROR self._exception = e - action.emit_changed_event(self.thing, self._status.value) + self._publish_status() finally: with self._status_lock: self._end_time = datetime.datetime.now() @@ -810,70 +824,13 @@ def instance_get(self, obj: OwnerT) -> Callable[ActionParams, ActionReturn]: """ @wraps(self.func) - def wrapped(*args: Any, **kwargs: Any) -> Any: # noqa: DOC + def wrapped(*args: Any, **kwargs: Any) -> Any: # noqa: DOC101, DOC103, DOC201 """Acquire the lock then run `func` with supplied arguments.""" with self.context_for_func(obj): return self.func(*args, **kwargs) return partial(wrapped, obj) - def _observers_set(self, obj: Thing) -> WeakSet: - """Return a set used to notify changes. - - Note that we need to supply the `~lt.Thing` we are looking at, as in - general there may be more than one object of the same type, and - descriptor instances are shared between all instances of their class. - - :param obj: The `~lt.Thing` on which the action is being observed. - - :return: a weak set of callables to notify on changes to the action. - This is used by websocket endpoints. - """ - ld = labthings_data(obj) - if self.name not in ld.action_observers: - ld.action_observers[self.name] = WeakSet() - return ld.action_observers[self.name] - - def emit_changed_event(self, obj: Thing, status: str) -> None: - """Notify subscribers that the action status has changed. - - This function is run from within the `.Invocation` thread that - is created when an action is called. It must be run from a thread - as it is communicating with the event loop via an `asyncio` blocking - portal. Async code must not use the blocking portal as it can deadlock - the event loop. - - :param obj: The `~lt.Thing` on which the action is being observed. - :param status: The status of the action, to be sent to observers. - """ - obj._thing_server_interface.start_async_task_soon( - self.emit_changed_event_async, - obj, - status, - ) - - async def emit_changed_event_async(self, obj: Thing, value: Any) -> None: - """Notify subscribers that the action status has changed. - - This is an async function that must be run in the `anyio` event loop. - It will send messages to each observer to notify them that something - has changed. - - :param obj: The `~lt.Thing` on which the action is defined. - `.ActionDescriptor` objects are unique to the class, but there may - be more than one `~lt.Thing` attached to a server with the same class. - We use ``obj`` to look up the observers of the current `~lt.Thing`. - :param value: The action status to communicate to the observers. - """ - action_name = self.name - for observer in self._observers_set(obj): - await observer.send( - { - "messageType": "actionStatus", - "data": {"action name": action_name, "status": value}, - } - ) - def add_to_fastapi(self, app: FastAPI, thing: Thing) -> None: """Add this action to a FastAPI app, bound to a particular Thing. diff --git a/src/labthings_fastapi/exceptions.py b/src/labthings_fastapi/exceptions.py index 472d5269..796af69a 100644 --- a/src/labthings_fastapi/exceptions.py +++ b/src/labthings_fastapi/exceptions.py @@ -41,7 +41,7 @@ class ReadOnlyPropertyError(AttributeError): class PropertyNotObservableError(RuntimeError): """The property is not observable. - This exception is raised when `~lt.Thing.observe_property` is called with a + This exception is raised when trying to observe property that is not observable. Currently, only data properties are observable: functional properties (using a getter/setter) may not be observed. diff --git a/src/labthings_fastapi/message_broker.py b/src/labthings_fastapi/message_broker.py new file mode 100644 index 00000000..46f427ae --- /dev/null +++ b/src/labthings_fastapi/message_broker.py @@ -0,0 +1,122 @@ +"""Handle pub-sub style events. + +Both properties and actions can emit events that may be observed. This module handles +all the pub-sub messaging in LabThings. +""" + +import anyio +from pydantic.dataclasses import dataclass +from typing import Any, Literal +from weakref import WeakSet + +from anyio.abc import ObjectSendStream + + +@dataclass +class Message: + """A pub-sub event message. + + This is the message that is sent when a property or action generates + an event. + + This is a pydantic dataclass, so we validate the message. This might + change in the future for performance reasons. + + :param thing: The name of the Thing generating the event. + :param affordance: The name of the affordance generating the event. + :param message: The message to send. + """ + + thing: str + affordance: str + message_type: Literal["property", "action", "event"] + payload: Any + + +class MessageBroker: + r"""A class that relays pub/sub messages. + + This class takes care of relaying messages to streams that have subscribed to them. + It does not format messages or handle any details of e.g. websocket protocol. + + Subscriptions require an `ObjectSendStream[Message]` and each time a `Message` + matching the subscription parameters (``thing`` and ``affordance``) is published, + it will be sent on that stream. + + The broker does not validate thing or affordance names: that's up to the code + calling `MessageBroker.subscribe`\ . + """ + + def __init__(self) -> None: + """Initialise the message broker.""" + # Note that we use a weak set below, so that when a websocket disconnects, + # its stream is removed automatically. + self._subscriptions: dict[ + str, dict[str, WeakSet[ObjectSendStream[Message]]] + ] = {} + + def subscribe( + self, thing: str, affordance: str, stream: ObjectSendStream[Message] + ) -> None: + """Subscribe to messages from a particular affordance. + + Note that this method is not async - it just registers the stream and so + can be run from any thread. + + :param thing: The name of the `.Thing` being subscribed to. + :param affordance: The name of the affordance being subscribed to. + :param stream: A stream to send the messages to. + :raises TypeError: if the `thing` argument is not a string. + """ + if not isinstance(thing, str): + raise TypeError(f"The `thing` argument should be a string, not {thing}.") + if thing not in self._subscriptions: + self._subscriptions[thing] = {} + if affordance not in self._subscriptions[thing]: + self._subscriptions[thing][affordance] = WeakSet() + self._subscriptions[thing][affordance].add(stream) + + def unsubscribe( + self, thing: str, affordance: str, stream: ObjectSendStream[Message] + ) -> None: + """Unsubscribe a stream from messages from a particular affordance. + + :param thing: The name of the `.Thing` being unsubscribed from. + :param affordance: The name of the affordance being unsubscribed from. + :param stream: The stream to unsubscribe. + :raises KeyError: if there is no such subscription. + :raises TypeError: if the `thing` argument is not a string. + """ + if not isinstance(thing, str): + raise TypeError(f"The `thing` argument should be a string, not {thing}.") + try: + self._subscriptions[thing][affordance].discard(stream) + except KeyError as e: + raise e + + async def publish(self, message: Message) -> None: + """Publish a message. + + This async method will relay the message to any subscriber streams. + + :param message: the message to send. + """ + try: + subscriptions = self._subscriptions[message.thing][message.affordance] + except KeyError: + return # No subscribers for this thing. + for stream in subscriptions: + await stream.send(message) + + async def close_streams(self) -> None: + """Close all streams that are subscribed to receive messages. + + This should be called when the server shuts down. + """ + # We use a task group so we shut down all streams concurrently, rather + # than waiting for each one to close. + async with anyio.create_task_group() as tg: + for thing_subs in self._subscriptions.values(): + for subs in thing_subs.values(): + for stream in subs: + tg.start_soon(stream.aclose) diff --git a/src/labthings_fastapi/properties.py b/src/labthings_fastapi/properties.py index a81f157b..dd946391 100644 --- a/src/labthings_fastapi/properties.py +++ b/src/labthings_fastapi/properties.py @@ -60,7 +60,6 @@ class attribute. Documentation is in strings immediately following the TYPE_CHECKING, ) from typing_extensions import Self, TypedDict -from weakref import WeakSet from fastapi import Body, FastAPI from pydantic import ( @@ -73,6 +72,8 @@ class attribute. Documentation is in strings immediately following the with_config, ) +from labthings_fastapi.message_broker import Message + from .thing_description import type_to_dataschema from .thing_description._model import ( DataSchema, @@ -82,7 +83,6 @@ class attribute. Documentation is in strings immediately following the ) from .utilities import ( LabThingsRootModelWrapper, - labthings_data, wrap_plain_types_in_rootmodel, ) from .utilities.introspection import return_type @@ -403,6 +403,14 @@ def __init__( except UnsupportedConstraintError: raise + observable: bool = False + """Whether or not the property may be observed. + + If `observable` is `True` then a websocket connection can register to be notified + when the property changes. By default this is `True` for data properties and + `False` for functional properties. + """ + @staticmethod def _validate_constraints(constraints: Mapping[str, Any]) -> FieldConstraints: """Validate an untyped dictionary of constraints. @@ -755,9 +763,7 @@ def instance_get(self, obj: Owner) -> Value: obj.__dict__[self.name] = self._default_factory() return obj.__dict__[self.name] - def __set__( - self, obj: Owner, value: Value, emit_changed_event: bool = True - ) -> None: + def __set__(self, obj: Owner, value: Value) -> None: """Set the property's value. This sets the property's value, and notifies any observers. @@ -768,7 +774,6 @@ def __set__( :param obj: the `~lt.Thing` to which we are attached. :param value: the new value for the property. - :param emit_changed_event: whether to emit a changed event. """ with obj._thing_server_interface._optionally_hold_global_lock( self.use_global_lock @@ -778,8 +783,16 @@ def __set__( obj.__dict__[self.name] = property_info.validate(value) else: obj.__dict__[self.name] = value - if emit_changed_event: - self.emit_changed_event(obj, value) + obj._thing_server_interface.publish( + Message(obj.name, self.name, "property", value) + ) + + observable: bool = True + """Whether or not the property may be observed. + + If `observable` is `True` then a websocket connection can register to be notified + when the property changes. By default this is `True` for data properties. + """ def get_default(self, obj: Owner | None) -> Value: """Return the default value of this property. @@ -802,56 +815,6 @@ def reset(self, obj: Owner) -> None: """ self.__set__(obj, self.get_default(obj)) - def _observers_set(self, obj: Thing) -> WeakSet: - """Return the observers of this property. - - Each observer in this set will be notified when the property is changed. - See ``.DataProperty.emit_changed_event`` - - :param obj: the `~lt.Thing` to which we are attached. - - :return: the set of observers corresponding to ``obj``. - """ - ld = labthings_data(obj) - if self.name not in ld.property_observers: - ld.property_observers[self.name] = WeakSet() - return ld.property_observers[self.name] - - def emit_changed_event(self, obj: Thing, value: Value) -> None: - """Notify subscribers that the property has changed. - - This function is run when properties are updated. It must be run from - within a thread. This could be the `Invocation` thread of a running action, or - the property should be updated over via a client/http. It must be run from a - thread as it is communicating with the event loop via an `asyncio` blocking - portal and can cause deadlock if run in the event loop. - - This method will raise a `.ServerNotRunningError` if the event loop is not - running, and should only be called after the server has started. - - :param obj: the `~lt.Thing` to which we are attached. - :param value: the new property value, to be sent to observers. - """ - obj._thing_server_interface.start_async_task_soon( - self.emit_changed_event_async, - obj, - value, - ) - - async def emit_changed_event_async(self, obj: Thing, value: Value) -> None: - """Notify subscribers that the property has changed. - - This function may only be run in the `anyio` event loop. See - `.DataProperty.emit_changed_event`. - - :param obj: the `~lt.Thing` to which we are attached. - :param value: the new property value, to be sent to observers. - """ - for observer in self._observers_set(obj): - await observer.send( - {"messageType": "propertyStatus", "data": {self._name: value}} - ) - class FunctionalProperty(BaseProperty[Owner, Value], Generic[Owner, Value]): """A property that uses a getter and a setter. @@ -1223,6 +1186,11 @@ def default(self) -> Value: # noqa: DOC201 """ return self.get_descriptor().get_default(self.owning_object) + @builtins.property + def is_observable(self) -> bool: # noqa: DOC201 + """Whether the property may be observed.""" + return self.get_descriptor().observable + @builtins.property def is_resettable(self) -> bool: # noqa: DOC201 """Whether the property may be reset using the ``reset()`` method.""" @@ -1421,20 +1389,6 @@ class BaseSetting(BaseProperty[Owner, Value], Generic[Owner, Value]): two concrete implementations: `.DataSetting` and `.FunctionalSetting`\ . """ - def set_without_emit(self, obj: Owner, value: Value) -> None: - """Set the setting's value without emitting an event. - - This is used to set the setting's value without notifying observers. - It is used during initialisation to set the value from disk before - the server is fully started. - - :param obj: the `~lt.Thing` to which we are attached. - :param value: the new value of the setting. - - :raises NotImplementedError: this method should be implemented in subclasses. - """ - raise NotImplementedError("This method should be implemented in subclasses.") - def descriptor_info(self, owner: Owner | None = None) -> SettingInfo[Owner, Value]: r"""Return an object that allows access to this descriptor's metadata. @@ -1464,32 +1418,17 @@ class DataSetting( The setting otherwise acts just like a normal variable. """ - def __set__( - self, obj: Owner, value: Value, emit_changed_event: bool = True - ) -> None: + def __set__(self, obj: Owner, value: Value) -> None: """Set the setting's value. This will cause the settings to be saved to disk. :param obj: the `~lt.Thing` to which we are attached. :param value: the new value of the setting. - :param emit_changed_event: whether to emit a changed event. """ - super().__set__(obj, value, emit_changed_event) + super().__set__(obj, value) obj.save_settings() - def set_without_emit(self, obj: Owner, value: Value) -> None: - """Set the property's value, but do not emit event to notify the server. - - This function is not expected to be used externally. It is called during - initial setup so that the setting can be set from disk before the server - is fully started. - - :param obj: the `~lt.Thing` to which we are attached. - :param value: the new value of the setting. - """ - super().__set__(obj, value, emit_changed_event=False) - class FunctionalSetting( FunctionalProperty[Owner, Value], BaseSetting[Owner, Value], Generic[Owner, Value] @@ -1522,34 +1461,12 @@ def __set__(self, obj: Owner, value: Value) -> None: super().__set__(obj, value) obj.save_settings() - def set_without_emit(self, obj: Owner, value: Value) -> None: - """Set the property's value, but do not emit event to notify the server. - - This function is not expected to be used externally. It is called during - initial setup so that the setting can be set from disk before the server - is fully started. - - :param obj: the `~lt.Thing` to which we are attached. - :param value: the new value of the setting. - """ - # FunctionalProperty does not emit changed events, so no special - # behaviour is needed. - super().__set__(obj, value) - class SettingInfo( PropertyInfo[BaseSetting[Owner, Value], Owner, Value], Generic[Owner, Value] ): """Access to the metadata of a setting.""" - def set_without_emit(self, value: Value) -> None: - """Set the value of the setting, but don't emit a notification. - - :param value: the new value for the setting. - """ - obj = self.owning_object_or_error() - self.get_descriptor().set_without_emit(obj, value) - class SettingCollection(DescriptorInfoCollection[Owner, SettingInfo], Generic[Owner]): """Access to metadata on all the properties of a `~lt.Thing` instance or subclass. diff --git a/src/labthings_fastapi/server/__init__.py b/src/labthings_fastapi/server/__init__.py index 35fa41e4..66ab159d 100644 --- a/src/labthings_fastapi/server/__init__.py +++ b/src/labthings_fastapi/server/__init__.py @@ -24,6 +24,7 @@ import uvicorn from labthings_fastapi.exceptions import GlobalLockBusyError +from labthings_fastapi.message_broker import MessageBroker from ..middleware.url_for import url_for_middleware from ..thing_slots import ThingSlot @@ -146,6 +147,7 @@ def __init__( self._set_url_for_middleware() self._add_exception_handlers() self.action_manager = ActionManager() + self.message_broker = MessageBroker() self.app.include_router(self.action_manager.router(), prefix=self.api_prefix) self.app.include_router(blob.router, prefix=self.api_prefix) self.app.include_router(self._things_view_router(), prefix=self.api_prefix) diff --git a/src/labthings_fastapi/testing.py b/src/labthings_fastapi/testing.py index 029f43c7..52bc86f4 100644 --- a/src/labthings_fastapi/testing.py +++ b/src/labthings_fastapi/testing.py @@ -18,6 +18,7 @@ from unittest.mock import Mock from labthings_fastapi.global_lock import GlobalLock +from labthings_fastapi.message_broker import Message from .utilities import class_attributes from .thing_slots import ThingSlot @@ -92,6 +93,12 @@ def start_async_task_soon( f.cancel() return f + def publish(self, message: Message) -> None: + """Silently ignore published events. + + :param message: a message to publish. + """ + @property def settings_folder(self) -> str: """The path to a folder where persistent files may be saved. @@ -214,6 +221,7 @@ def mock_thing_instance(spec: type[ThingSubclass]) -> ThingSubclass: mock.__name__ = "Mock{spec.__name__}" mock.__module__ = "mock_module" mock._thing_server_interface = MockThingServerInterface(mock.__name__) + mock.name = mock._thing_server_interface.name return mock diff --git a/src/labthings_fastapi/thing.py b/src/labthings_fastapi/thing.py index a15c3831..fb945505 100644 --- a/src/labthings_fastapi/thing.py +++ b/src/labthings_fastapi/thing.py @@ -16,24 +16,20 @@ from json.decoder import JSONDecodeError from fastapi.encoders import jsonable_encoder from fastapi import Request, WebSocket -from anyio.abc import ObjectSendStream from anyio.to_thread import run_sync from .logs import THING_LOGGER from .properties import ( - BaseProperty, - DataProperty, PropertyCollection, SettingCollection, ) -from .actions import ActionCollection, ActionDescriptor +from .actions import ActionCollection from .base_descriptor import OptionallyBoundDescriptor from .thing_description._model import ThingDescription, NoSecurityScheme from .utilities import class_attributes from .thing_description import validation from .utilities.introspection import get_summary, get_docstring from .websockets import websocket_endpoint -from .exceptions import PropertyNotObservableError from .thing_server_interface import ThingServerInterface from .invocation_contexts import get_invocation_id from .thing_class_settings import ThingClassSettings, validate_thing_class_settings @@ -210,7 +206,7 @@ def thing_description(request: Request) -> ThingDescription: @server.app.websocket(self.path + "ws") async def websocket(ws: WebSocket) -> None: - await websocket_endpoint(self, ws) + await websocket_endpoint(self, ws, server.message_broker) def _read_settings_file(self) -> Mapping[str, Any] | None: """Read the settings file and return a mapping of saved settings or None. @@ -281,7 +277,7 @@ def load_settings(self) -> None: try: setting = self.settings[name] # Load the key from the JSON file using the setting's model - setting.set_without_emit(setting.validate(value)) + setting.set(setting.validate(value)) except ValidationError: self.logger.warning( f"Could not load setting {name} from settings file " @@ -430,36 +426,6 @@ def thing_description_dict( td_dict: dict = td.model_dump(exclude_none=True, by_alias=True) return jsonable_encoder(td_dict) - def observe_property(self, property_name: str, stream: ObjectSendStream) -> None: - """Register a stream to receive property change notifications. - - :param property_name: the property to register for. - :param stream: the stream used to send events. - - :raise KeyError: if the requested name is not defined on this Thing. - :raise PropertyNotObservableError: if the property is not observable. - """ - prop = getattr(self.__class__, property_name, None) - if not isinstance(prop, BaseProperty): - raise KeyError(f"{property_name} is not a LabThings Property") - if not isinstance(prop, DataProperty): - raise PropertyNotObservableError(f"{property_name} is not observable.") - prop._observers_set(self).add(stream) - - def observe_action(self, action_name: str, stream: ObjectSendStream) -> None: - """Register a stream to receive action status change notifications. - - :param action_name: the action to register for. - :param stream: the stream used to send events. - - :raise KeyError: if the requested name is not defined on this Thing. - """ - action = getattr(self.__class__, action_name, None) - if not isinstance(action, ActionDescriptor): - raise KeyError(f"{action_name} is not an LabThings Action") - observers = action._observers_set(self) - observers.add(stream) - def get_current_invocation_logs(self) -> list[logging.LogRecord]: """Get the log records for an on going action. diff --git a/src/labthings_fastapi/thing_server_interface.py b/src/labthings_fastapi/thing_server_interface.py index 6dedea5f..24fe09e0 100644 --- a/src/labthings_fastapi/thing_server_interface.py +++ b/src/labthings_fastapi/thing_server_interface.py @@ -18,6 +18,7 @@ from weakref import ref, ReferenceType from labthings_fastapi.global_lock import GlobalLock +from labthings_fastapi.message_broker import Message from .exceptions import FeatureNotEnabledError, ServerNotRunningError @@ -136,6 +137,22 @@ def call_async_task( raise ServerNotRunningError("Can't run async code without an event loop.") return portal.call(async_function, *args) + def publish(self, message: Message) -> None: + """Publish an event. + + Use the async event loop to notify subscribers that something has + happened. The message should contain the name of the `~lt.Thing` and affordance. + + Note that this function will do nothing if the event loop is not yet running. + + :param message: the message being published. + """ + try: + broker = self._get_server().message_broker + self.start_async_task_soon(broker.publish, message) + except ServerNotRunningError: + pass # If the server isn't running yet, we can't publish events. + @property def settings_folder(self) -> str: """The path to a folder where persistent files may be saved.""" diff --git a/src/labthings_fastapi/utilities/__init__.py b/src/labthings_fastapi/utilities/__init__.py index 51515601..75be8203 100644 --- a/src/labthings_fastapi/utilities/__init__.py +++ b/src/labthings_fastapi/utilities/__init__.py @@ -2,23 +2,15 @@ from __future__ import annotations from collections.abc import Mapping -from typing import Any, Dict, Generic, Iterable, TYPE_CHECKING, Optional, TypeVar -from weakref import WeakSet -from pydantic import BaseModel, ConfigDict, Field, RootModel, create_model -from pydantic.dataclasses import dataclass +from typing import Any, Dict, Generic, Iterable, Optional, TypeVar +from pydantic import BaseModel, Field, RootModel, create_model from labthings_fastapi.exceptions import UnsupportedConstraintError from .introspection import EmptyObject -if TYPE_CHECKING: - from ..thing import Thing - - __all__ = [ "class_attributes", "attributes", - "LabThingsObjectData", - "labthings_data", "wrap_plain_types_in_rootmodel", "model_to_dict", ] @@ -54,46 +46,6 @@ def attributes(cls: Any) -> Iterable[tuple[str, Any]]: yield name, getattr(cls, name) -LABTHINGS_DICT_KEY = "__labthings" - - -@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) -class LabThingsObjectData: - r"""Data used by LabThings, stored on each `~lt.Thing`. - - This `pydantic.dataclass` groups together some properties used - by LabThings descriptors, to avoid cluttering the namespace of the - `~lt.Thing` subclass on which they are defined. - """ - - property_observers: Dict[str, WeakSet] = Field(default_factory=dict) - r"""The observers added to each property. - - Keys are property names, values are weak sets used by `~lt.DataProperty`\ . - """ - action_observers: Dict[str, WeakSet] = Field(default_factory=dict) - r"""The observers added to each action. - - Keys are action names, values are weak sets used by - `.ActionDescriptor`\ . - """ - - -def labthings_data(obj: Thing) -> LabThingsObjectData: - """Get (or create) a dictionary for LabThings properties. - - Ensure there is a `.LabThingsObjectData` dataclass attached to - a particular `~lt.Thing`, and return it. - - :param obj: The `~lt.Thing` we are looking for the dataclass on. - - :return: a `.LabThingsObjectData` instance attached to ``obj``. - """ - if LABTHINGS_DICT_KEY not in obj.__dict__: - obj.__dict__[LABTHINGS_DICT_KEY] = LabThingsObjectData() - return obj.__dict__[LABTHINGS_DICT_KEY] - - WrappedT = TypeVar("WrappedT") diff --git a/src/labthings_fastapi/websockets.py b/src/labthings_fastapi/websockets.py index 2d3136f2..9d5c00df 100644 --- a/src/labthings_fastapi/websockets.py +++ b/src/labthings_fastapi/websockets.py @@ -25,11 +25,14 @@ from fastapi import WebSocket, WebSocketDisconnect from fastapi.encoders import jsonable_encoder from typing import TYPE_CHECKING, Literal + +from labthings_fastapi.message_broker import Message, MessageBroker from .exceptions import PropertyNotObservableError if TYPE_CHECKING: from .thing import Thing +LOGGER = logging.getLogger(__file__) WEBTHING_ERROR_URL = "https://w3c.github.io/web-thing-protocol/errors" @@ -76,7 +79,7 @@ def observation_error_response( async def relay_notifications_to_websocket( - websocket: WebSocket, receive_stream: ObjectReceiveStream + websocket: WebSocket, receive_stream: ObjectReceiveStream[Message] ) -> None: """Relay objects from a stream to a websocket as JSON. @@ -90,11 +93,45 @@ async def relay_notifications_to_websocket( """ async with receive_stream: async for item in receive_stream: - await websocket.send_json(jsonable_encoder(item)) + if item.message_type == "action": + msg = { + "messageType": "actionStatus", + "data": {"action name": item.affordance, "status": item.payload}, + } + elif item.message_type == "property": + msg = { + "messageType": "propertyStatus", + "data": {item.affordance: jsonable_encoder(item.payload)}, + } + else: + LOGGER.error(f"Could not relay '{item}' to websocket - bad type.") + await websocket.send_json(msg) + + +def assert_property_is_observable(thing: Thing, property: str) -> bool: + """Check that a Thing has a particular property and it is observable. + + :param thing: the `~lt.Thing` instance being observed. + :param property: the name of the property. + :raises KeyError: if the property does not exist. + :raises PropertyNotObservableError: if the property isn't observable. + :returns: `True` if an exception wasn't raised. + """ + try: + prop = thing.properties[property] # raises KeyError if it doesn't exist + except KeyError: + raise + if not prop.is_observable: + msg = f"'{thing.name}.{property}' is not observable." + raise PropertyNotObservableError(msg) + return True async def process_messages_from_websocket( - websocket: WebSocket, send_stream: ObjectSendStream, thing: Thing + websocket: WebSocket, + send_stream: ObjectSendStream[Message], + broker: MessageBroker, + thing: Thing, ) -> None: r"""Process messages received from a websocket. @@ -105,6 +142,7 @@ async def process_messages_from_websocket( :param send_stream: an `anyio.abc.ObjectSendStream` that we use to register for events, i.e. data sent to that stream will be sent through this websocket, by `.relay_notifications_to_websocket`\ . + :param broker: the message broker to use for subscriptions. :param thing: the `~lt.Thing` we are attached to. The websocket is specific to one `~lt.Thing`, and this is it. """ @@ -117,20 +155,24 @@ async def process_messages_from_websocket( if data["messageType"] == "addPropertyObservation": try: for k in data["data"].keys(): - thing.observe_property(k, send_stream) + assert_property_is_observable(thing, k) + broker.subscribe(thing.name, k, send_stream) except (KeyError, PropertyNotObservableError) as e: logging.error(f"Got a bad websocket message: {data}, caused {e!r}.") - await send_stream.send(observation_error_response(k, "property", e)) + await websocket.send_json(observation_error_response(k, "property", e)) if data["messageType"] == "addActionObservation": try: for k in data["data"].keys(): - thing.observe_action(k, send_stream) + _ = thing.actions[k] # raise a KeyError if the action doesn't exist + broker.subscribe(thing.name, k, send_stream) except KeyError as e: logging.error(f"Got a bad websocket message: {data}, caused {e!r}.") - await send_stream.send(observation_error_response(k, "action", e)) + await websocket.send_json(observation_error_response(k, "action", e)) -async def websocket_endpoint(thing: Thing, websocket: WebSocket) -> None: +async def websocket_endpoint( + thing: Thing, websocket: WebSocket, broker: MessageBroker +) -> None: r"""Handle communication to a client via websocket. This function handles a websocket connection to a `~lt.Thing`\ 's websocket @@ -139,9 +181,12 @@ async def websocket_endpoint(thing: Thing, websocket: WebSocket) -> None: :param thing: the `~lt.Thing` the websocket is attached to. :param websocket: the web socket that has been created. + :param broker: the message broker to use for subscriptions. """ await websocket.accept() - send_stream, receive_stream = create_memory_object_stream[dict]() + send_stream, receive_stream = create_memory_object_stream[Message]() async with create_task_group() as tg: tg.start_soon(relay_notifications_to_websocket, websocket, receive_stream) - tg.start_soon(process_messages_from_websocket, websocket, send_stream, thing) + tg.start_soon( + process_messages_from_websocket, websocket, send_stream, broker, thing + ) diff --git a/tests/test_message_broker.py b/tests/test_message_broker.py new file mode 100644 index 00000000..61495e1d --- /dev/null +++ b/tests/test_message_broker.py @@ -0,0 +1,151 @@ +"""Test the message broker.""" + +import anyio +from anyio.abc import ObjectReceiveStream +import pytest + +from pydantic import ValidationError + +from labthings_fastapi.message_broker import Message, MessageBroker + + +class Unjsonable: + """A class that won't serialise.""" + + +@pytest.mark.parametrize( + "message", + [ + ("test_thing", "prop", "property", 42), + ("test_thing", "prop", "property", Unjsonable()), + ("test_thing", "do_it", "action", None), + ("test_thing", "notify", "event", {"key": "value"}), + ], +) +def test_message_valid(message): + """Check that Messages can be constructed.""" + amodel = Message(*message) + ARGS = ["thing", "affordance", "message_type", "payload"] + kwargs = dict(zip(ARGS, message, strict=True)) + kmodel = Message(**kwargs) + assert amodel == kmodel + assert amodel.__dict__ == kwargs + + +@pytest.mark.parametrize( + "message", + [ + ("test_thing", "prop", "custom", None), + (Unjsonable(), "prop", "property", None), + ("thing", Unjsonable(), "property", None), + ], +) +def test_message_invalid(message): + """Check that invalid Things or message types fail validation.""" + with pytest.raises(ValidationError): + _ = Message(*message) + + +def test_subscribe_unsubscribe(): + """Test that we can subscribe to affordances, and unsubscribe.""" + broker = MessageBroker() + assert broker._subscriptions == {} + + send_stream, receive_stream = anyio.create_memory_object_stream[Message]() + broker.subscribe("thing_name", "prop", send_stream) + + assert send_stream in broker._subscriptions["thing_name"]["prop"] + + broker.unsubscribe("thing_name", "prop", send_stream) + assert send_stream not in broker._subscriptions["thing_name"]["prop"] + + # There's deliberately no validation when subscribing - that must come + # from elsewhere. We do raise key errors for unsubscriptions though, if + # there's no subscription to cancel. + with pytest.raises(KeyError): + broker.unsubscribe("other_thing", "prop", send_stream) + with pytest.raises(KeyError): + broker.unsubscribe("thing_name", "other_prop", send_stream) + # There is currently no check that a subscription is current, so we don't + # yet test if the stream is currently subscribed before deleting it from the + # list of subscriptions. That means the following should work, even though + # we're not currently subscribed: + assert len(broker._subscriptions["thing_name"]["prop"]) == 0 + broker.unsubscribe("thing_name", "prop", send_stream) + assert len(broker._subscriptions["thing_name"]["prop"]) == 0 + + # We do check that the "thing" key is a string, not a `Thing` instance + # (because that's a likely mistake). + with pytest.raises(TypeError): + broker.subscribe(Unjsonable(), "whatever", send_stream) # type: ignore + with pytest.raises(TypeError): + broker.unsubscribe(Unjsonable(), "whatever", send_stream) # type: ignore + + +async def append_messages( + stream: ObjectReceiveStream[Message], + dest: list[Message], +): + """Append messages from a stream to a list.""" + async with stream: + async for item in stream: + dest.append(item) + + +def test_message_passing(): + """Check messages propagate in an event loop. + + We test messages with 0, 1, and 2 subscribers. + """ + message_a = Message("thing_a", "prop", "property", "a") + message_b = Message("thing_b", "prop", "property", "b") + message_a2 = Message("thing_a", "prop2", "property", "a2") + + broker = MessageBroker() + + async def publish_messages_and_shutdown(): + """Publish several messages.""" + await broker.publish(message_a) + await broker.publish(message_b) # not received - but no error either + await broker.publish(message_a2) + await broker.publish(message_a2) + await broker.publish(message_a2) + # It's important to close streams or the test hangs. + await broker.close_streams() + + # We make four subscriptions, defined below. + # Each has a thing name and property name. Any messages received will be + # appended to the list. + message_lists = { + "a_prop": ("thing_a", "prop", []), # message_a + "c_prop": ("thing_c", "prop", []), # no message + "a_prop2": ("thing_a", "prop2", []), # message_a3 x3 + "a_prop2_dup": ("thing_a", "prop2", []), # as above + } + + # Define the async code that runs in an event loop + async def main(): + async with anyio.create_task_group() as tg: + retain_send_streams = [] + for thing, prop, dest in message_lists.values(): + # Subscribe to messages, and handle them by + # appending to a list. + send, recv = anyio.create_memory_object_stream[Message]() + broker.subscribe(thing, prop, send) + tg.start_soon(append_messages, recv, dest) + # The line below stops the send stream getting garbage collected. + retain_send_streams.append(send) + tg.start_soon(publish_messages_and_shutdown) + + # Run the function in an event loop + anyio.run(main) + + # Check that the messages were received by the expected streams + assert message_lists["a_prop"][2] == [message_a] + assert message_lists["c_prop"][2] == [] + assert message_lists["a_prop2"][2] == [message_a2] * 3 + assert message_lists["a_prop2_dup"][2] == [message_a2] * 3 + + +if __name__ == "__main__": + test_message_passing() diff --git a/tests/test_properties.py b/tests/test_properties.py index a1806b12..5c61cdf1 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -10,7 +10,6 @@ import labthings_fastapi as lt from labthings_fastapi.exceptions import ( NotBoundToInstanceError, - ServerNotRunningError, UnsupportedConstraintError, ) from labthings_fastapi.properties import BaseProperty, PropertyInfo @@ -386,18 +385,6 @@ def test_setting_from_thread(server): assert r.json() is True -def test_setting_without_event_loop(): - """Test DataProperty raises an error if set without an event loop.""" - # This test may need to change, if we change the intended behaviour - # Currently it should never be necessary to change properties from the - # main thread, so we raise an error if you try to do so - server = lt.ThingServer.from_things({"thing": PropertyTestThing}) - thing = server.things["thing"] - assert isinstance(thing, PropertyTestThing) - with pytest.raises(ServerNotRunningError): - thing.boolprop = False # Can't call it until the event loop's running - - @pytest.mark.parametrize("prop_info", CONSTRAINED_PROPS) def test_constrained_properties(prop_info, mocker): """Test that constraints on property values generate correct models. diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 7446a079..9eff2054 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -5,6 +5,7 @@ InvocationCancelledError, ) from labthings_fastapi.testing import create_thing_without_server +from labthings_fastapi.websockets import assert_property_is_observable class ThingWithProperties(lt.Thing): @@ -85,19 +86,6 @@ def thing(): return create_thing_without_server(ThingWithProperties) -def test_observing_dataprop(thing, mocker): - """Check `observe_property` is OK on a data property. - - This checks that something is added to the set of observers. - We don't check for events, as there's no event loop: this is - tested in `test_observing_dataprop_with_ws` below. - """ - observers_set = ThingWithProperties.dataprop._observers_set(thing) - fake_observer = mocker.Mock() - thing.observe_property("dataprop", fake_observer) - assert fake_observer in observers_set - - @pytest.mark.parametrize( argnames=["name", "exception"], argvalues=[ @@ -109,10 +97,10 @@ def test_observing_dataprop(thing, mocker): ("missing", KeyError), ], ) -def test_observing_errors(thing, mocker, name, exception): +def test_observing_errors(thing, name, exception): """Check errors are raised if we observe an unsuitable property.""" with pytest.raises(exception): - thing.observe_property(name, mocker.Mock()) + assert assert_property_is_observable(thing, name) def test_observing_dataprop_with_ws(client, ws): @@ -171,27 +159,6 @@ def test_observing_dataprop_error_with_ws(ws, name, title, status): assert message["error"]["status"] == status -def test_observing_action(thing, mocker): - """Check observing an action is successful. - - This verifies we've added an observer to the set, but doesn't test for - notifications: that would require an event loop. - """ - observers_set = ThingWithProperties.increment_dataprop._observers_set(thing) - fake_observer = mocker.Mock() - thing.observe_action("increment_dataprop", fake_observer) - assert fake_observer in observers_set - - -@pytest.mark.parametrize( - "name", ["non_property", "python_property", "undecorated", "dataprop"] -) -def test_observing_action_error(thing, mocker, name): - """Check observing an attribute that's not an action raises an error.""" - with pytest.raises(KeyError): - thing.observe_action(name, mocker.Mock()) - - @pytest.mark.parametrize( argnames=["name", "final_status"], argvalues=[