Skip to content

Commit

Permalink
Reduce boilerplate when creating Devices (#240)
Browse files Browse the repository at this point in the history
* Reduce boilerplate in StandardReadable

This reduces the amount of duplication and repetition when adding
signals to a StandardReadable.

As part of this, classes defining the types of signal have been created,
which control the behaviour of the Signal being registered

Signals must be registered either using the "add_children_as_readables"
contextmanager, or the "add_readables" function.

set_readable_signals() will now issue a DeprecationWarning - it is
required to use the new function/context manager instead

---------

Co-authored-by: Tom C (DLS) <[email protected]>
  • Loading branch information
AlexanderWells-diamond and coretl authored Apr 25, 2024
1 parent 4c8ce61 commit 82b8a5b
Show file tree
Hide file tree
Showing 12 changed files with 515 additions and 78 deletions.
4 changes: 3 additions & 1 deletion src/ophyd_async/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
)
from .signal_backend import SignalBackend
from .sim_signal_backend import SimSignalBackend
from .standard_readable import StandardReadable
from .standard_readable import ConfigSignal, HintedSignal, StandardReadable
from .utils import (
DEFAULT_TIMEOUT,
Callback,
Expand Down Expand Up @@ -84,6 +84,8 @@
"ShapeProvider",
"StaticDirectoryProvider",
"StandardReadable",
"ConfigSignal",
"HintedSignal",
"TriggerInfo",
"TriggerLogic",
"HardwareTriggeredFlyable",
Expand Down
5 changes: 2 additions & 3 deletions src/ophyd_async/core/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@

from .async_status import AsyncStatus
from .device import Device
from .signal import SignalR
from .utils import DEFAULT_TIMEOUT, merge_gathered_dicts

T = TypeVar("T")
Expand Down Expand Up @@ -161,7 +160,7 @@ def __init__(
self,
controller: DetectorControl,
writer: DetectorWriter,
config_sigs: Sequence[SignalR] = (),
config_sigs: Sequence[AsyncReadable] = (),
name: str = "",
writer_timeout: float = DEFAULT_TIMEOUT,
) -> None:
Expand Down Expand Up @@ -214,7 +213,7 @@ async def stage(self) -> None:
async def _check_config_sigs(self):
"""Checks configuration signals are named and connected."""
for signal in self._config_sigs:
if signal._name == "":
if signal.name == "":
raise Exception(
"config signal must be named before it is passed to the detector"
)
Expand Down
5 changes: 2 additions & 3 deletions src/ophyd_async/core/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
Location,
Movable,
Reading,
Stageable,
Subscribable,
)

from ophyd_async.protocols import AsyncReadable
from ophyd_async.protocols import AsyncReadable, AsyncStageable

from .async_status import AsyncStatus
from .device import Device
Expand Down Expand Up @@ -128,7 +127,7 @@ def set_staged(self, staged: bool):
return self._staged or bool(self._listeners)


class SignalR(Signal[T], AsyncReadable, Stageable, Subscribable):
class SignalR(Signal[T], AsyncReadable, AsyncStageable, Subscribable):
"""Signal that can be read from and monitored"""

_cache: Optional[_SignalCache] = None
Expand Down
231 changes: 209 additions & 22 deletions src/ophyd_async/core/standard_readable.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,47 @@
from typing import Dict, Sequence, Tuple
import warnings
from contextlib import contextmanager
from typing import (
Callable,
Dict,
Generator,
Optional,
Sequence,
Tuple,
Type,
Union,
)

from bluesky.protocols import Descriptor, Reading, Stageable
from bluesky.protocols import Descriptor, HasHints, Hints, Reading

from ophyd_async.protocols import AsyncConfigurable, AsyncReadable
from ophyd_async.protocols import AsyncConfigurable, AsyncReadable, AsyncStageable

from .async_status import AsyncStatus
from .device import Device
from .device import Device, DeviceVector
from .signal import SignalR
from .utils import merge_gathered_dicts

ReadableChild = Union[AsyncReadable, AsyncConfigurable, AsyncStageable, HasHints]
ReadableChildWrapper = Union[
Callable[[ReadableChild], ReadableChild], Type["ConfigSignal"], Type["HintedSignal"]
]

class StandardReadable(Device, AsyncReadable, AsyncConfigurable, Stageable):

class StandardReadable(
Device, AsyncReadable, AsyncConfigurable, AsyncStageable, HasHints
):
"""Device that owns its children and provides useful default behavior.
- When its name is set it renames child Devices
- Signals can be registered for read() and read_configuration()
- These signals will be subscribed for read() between stage() and unstage()
"""

_read_signals: Tuple[SignalR, ...] = ()
_configuration_signals: Tuple[SignalR, ...] = ()
_read_uncached_signals: Tuple[SignalR, ...] = ()
# These must be immutable types to avoid accidental sharing between
# different instances of the class
_readables: Tuple[AsyncReadable, ...] = ()
_configurables: Tuple[AsyncConfigurable, ...] = ()
_stageables: Tuple[AsyncStageable, ...] = ()
_has_hints: Tuple[HasHints, ...] = ()

def set_readable_signals(
self,
Expand All @@ -38,37 +59,203 @@ def set_readable_signals(
read_uncached:
Signals to make up :meth:`~StandardReadable.read` that won't be cached
"""
self._read_signals = tuple(read)
self._configuration_signals = tuple(config)
self._read_uncached_signals = tuple(read_uncached)
warnings.warn(
DeprecationWarning(
"Migrate to `add_children_as_readables` context manager or "
"`add_readables` method"
)
)
self.add_readables(read, wrapper=HintedSignal)
self.add_readables(config, wrapper=ConfigSignal)
self.add_readables(read_uncached, wrapper=HintedSignal.uncached)

@AsyncStatus.wrap
async def stage(self) -> None:
for sig in self._read_signals + self._configuration_signals:
for sig in self._stageables:
await sig.stage().task

@AsyncStatus.wrap
async def unstage(self) -> None:
for sig in self._read_signals + self._configuration_signals:
for sig in self._stageables:
await sig.unstage().task

async def describe_configuration(self) -> Dict[str, Descriptor]:
return await merge_gathered_dicts(
[sig.describe() for sig in self._configuration_signals]
[sig.describe_configuration() for sig in self._configurables]
)

async def read_configuration(self) -> Dict[str, Reading]:
return await merge_gathered_dicts(
[sig.read() for sig in self._configuration_signals]
[sig.read_configuration() for sig in self._configurables]
)

async def describe(self) -> Dict[str, Descriptor]:
return await merge_gathered_dicts(
[sig.describe() for sig in self._read_signals + self._read_uncached_signals]
)
return await merge_gathered_dicts([sig.describe() for sig in self._readables])

async def read(self) -> Dict[str, Reading]:
return await merge_gathered_dicts(
[sig.read() for sig in self._read_signals]
+ [sig.read(cached=False) for sig in self._read_uncached_signals]
)
return await merge_gathered_dicts([sig.read() for sig in self._readables])

@property
def hints(self) -> Hints:
hints: Hints = {}
for new_hint in self._has_hints:
# Merge the existing and new hints, based on the type of the value.
# This avoids default dict merge behaviour that overrides the values;
# we want to combine them when they are Sequences, and ensure they are
# identical when string values.
for key, value in new_hint.hints.items():
if isinstance(value, str):
if key in hints:
assert (
hints[key] == value # type: ignore[literal-required]
), f"Hints key {key} value may not be overridden"
else:
hints[key] = value # type: ignore[literal-required]
elif isinstance(value, Sequence):
if key in hints:
for new_val in value:
assert (
new_val not in hints[key] # type: ignore[literal-required]
), f"Hint {key} {new_val} overrides existing hint"
hints[key] = ( # type: ignore[literal-required]
hints[key] + value # type: ignore[literal-required]
)
else:
hints[key] = value # type: ignore[literal-required]
else:
raise TypeError(
f"{new_hint.name}: Unknown type for value '{value}' "
f" for key '{key}'"
)

return hints

@contextmanager
def add_children_as_readables(
self,
wrapper: Optional[ReadableChildWrapper] = None,
) -> Generator[None, None, None]:
"""Context manager to wrap adding Devices
Add Devices to this class instance inside the Context Manager to automatically
add them to the correct fields, based on the Device's interfaces.
The provided wrapper class will be applied to all Devices and can be used to
specify their behaviour.
Parameters
----------
wrapper:
Wrapper class to apply to all Devices created inside the context manager.
See Also
--------
:func:`~StandardReadable.add_readables`
:class:`ConfigSignal`
:class:`HintedSignal`
:meth:`HintedSignal.uncached`
"""

dict_copy = self.__dict__.copy()

yield

# Set symmetric difference operator gives all newly added keys
new_keys = dict_copy.keys() ^ self.__dict__.keys()
new_values = [self.__dict__[key] for key in new_keys]

flattened_values = []
for value in new_values:
if isinstance(value, DeviceVector):
children = value.children()
flattened_values.extend([x[1] for x in children])
else:
flattened_values.append(value)

new_devices = list(filter(lambda x: isinstance(x, Device), flattened_values))
self.add_readables(new_devices, wrapper)

def add_readables(
self,
devices: Sequence[Device],
wrapper: Optional[ReadableChildWrapper] = None,
) -> None:
"""Add the given devices to the lists of known Devices
Add the provided Devices to the relevant fields, based on the Signal's
interfaces.
The provided wrapper class will be applied to all Devices and can be used to
specify their behaviour.
Parameters
----------
devices:
The devices to be added
wrapper:
Wrapper class to apply to all Devices created inside the context manager.
See Also
--------
:func:`~StandardReadable.add_children_as_readables`
:class:`ConfigSignal`
:class:`HintedSignal`
:meth:`HintedSignal.uncached`
"""

for readable in devices:
obj = readable
if wrapper:
obj = wrapper(readable)

if isinstance(obj, AsyncReadable):
self._readables += (obj,)

if isinstance(obj, AsyncConfigurable):
self._configurables += (obj,)

if isinstance(obj, AsyncStageable):
self._stageables += (obj,)

if isinstance(obj, HasHints):
self._has_hints += (obj,)


class ConfigSignal(AsyncConfigurable):
def __init__(self, signal: ReadableChild) -> None:
assert isinstance(signal, SignalR), f"Expected signal, got {signal}"
self.signal = signal

async def read_configuration(self) -> Dict[str, Reading]:
return await self.signal.read()

async def describe_configuration(self) -> Dict[str, Descriptor]:
return await self.signal.describe()


class HintedSignal(HasHints, AsyncReadable):
def __init__(self, signal: ReadableChild, allow_cache: bool = True) -> None:
assert isinstance(signal, SignalR), f"Expected signal, got {signal}"
self.signal = signal
self.cached = None if allow_cache else allow_cache
if allow_cache:
self.stage = signal.stage
self.unstage = signal.unstage

async def read(self) -> Dict[str, Reading]:
return await self.signal.read(cached=self.cached)

async def describe(self) -> Dict[str, Descriptor]:
return await self.signal.describe()

@property
def name(self) -> str:
return self.signal.name

@property
def hints(self) -> Hints:
return {"fields": [self.signal.name]}

@classmethod
def uncached(cls, signal: ReadableChild) -> "HintedSignal":
return cls(signal, allow_cache=False)
20 changes: 14 additions & 6 deletions src/ophyd_async/epics/areadetector/single_trigger_det.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@

from bluesky.protocols import Triggerable

from ophyd_async.core import AsyncStatus, SignalR, StandardReadable
from ophyd_async.core import (
AsyncStatus,
ConfigSignal,
HintedSignal,
SignalR,
StandardReadable,
)

from .drivers.ad_base import ADBase
from .utils import ImageMode
Expand All @@ -20,12 +26,14 @@ def __init__(
) -> None:
self.drv = drv
self.__dict__.update(plugins)
self.set_readable_signals(
# Can't subscribe to read signals as race between monitor coming back and
# caput callback on acquire
read_uncached=[self.drv.array_counter] + list(read_uncached),
config=[self.drv.acquire_time],

self.add_readables(
[self.drv.array_counter, *read_uncached],
wrapper=HintedSignal.uncached,
)

self.add_readables([self.drv.acquire_time], wrapper=ConfigSignal)

super().__init__(name=name)

@AsyncStatus.wrap
Expand Down
Loading

0 comments on commit 82b8a5b

Please sign in to comment.