Skip to content

Commit

Permalink
Speed up device creation and connection in mock mode (#641)
Browse files Browse the repository at this point in the history
  • Loading branch information
coretl authored Nov 11, 2024
1 parent f8fae43 commit 3d9f508
Show file tree
Hide file tree
Showing 15 changed files with 288 additions and 245 deletions.
2 changes: 2 additions & 0 deletions src/ophyd_async/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
DEFAULT_TIMEOUT,
CalculatableTimeout,
Callback,
LazyMock,
NotConnected,
Reference,
StrictEnum,
Expand Down Expand Up @@ -176,6 +177,7 @@
"DEFAULT_TIMEOUT",
"CalculatableTimeout",
"Callback",
"LazyMock",
"CALCULATE_TIMEOUT",
"NotConnected",
"Reference",
Expand Down
119 changes: 70 additions & 49 deletions src/ophyd_async/core/_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,15 @@
import asyncio
import sys
from collections.abc import Coroutine, Iterator, Mapping, MutableMapping
from functools import cached_property
from logging import LoggerAdapter, getLogger
from typing import Any, TypeVar
from unittest.mock import Mock

from bluesky.protocols import HasName
from bluesky.run_engine import call_in_bluesky_event_loop, in_bluesky_event_loop

from ._protocol import Connectable
from ._utils import DEFAULT_TIMEOUT, NotConnected, wait_for_connection

_device_mocks: dict[Device, Mock] = {}
from ._utils import DEFAULT_TIMEOUT, LazyMock, NotConnected, wait_for_connection


class DeviceConnector:
Expand All @@ -37,25 +35,23 @@ def create_children_from_annotations(self, device: Device):
during ``__init__``.
"""

async def connect(
self,
device: Device,
mock: bool | Mock,
timeout: float,
force_reconnect: bool,
):
async def connect_mock(self, device: Device, mock: LazyMock):
# Connect serially, no errors to gather up as in mock mode
for name, child_device in device.children():
await child_device.connect(mock=mock.child(name))

async def connect_real(self, device: Device, timeout: float, force_reconnect: bool):
"""Used during ``Device.connect``.
This is called when a previous connect has not been done, or has been
done in a different mock more. It should connect the Device and all its
children.
"""
coros = {}
for name, child_device in device.children():
child_mock = getattr(mock, name) if mock else mock # Mock() or False
coros[name] = child_device.connect(
mock=child_mock, timeout=timeout, force_reconnect=force_reconnect
)
# Connect in parallel, gathering up NotConnected errors
coros = {
name: child_device.connect(timeout=timeout, force_reconnect=force_reconnect)
for name, child_device in device.children()
}
await wait_for_connection(**coros)


Expand All @@ -67,9 +63,8 @@ class Device(HasName, Connectable):
parent: Device | None = None
# None if connect hasn't started, a Task if it has
_connect_task: asyncio.Task | None = None
# If not None, then this is the mock arg of the previous connect
# to let us know if we can reuse an existing connection
_connect_mock_arg: bool | None = None
# The mock if we have connected in mock mode
_mock: LazyMock | None = None

def __init__(
self, name: str = "", connector: DeviceConnector | None = None
Expand All @@ -83,10 +78,18 @@ def name(self) -> str:
"""Return the name of the Device"""
return self._name

@cached_property
def _child_devices(self) -> dict[str, Device]:
return {}

def children(self) -> Iterator[tuple[str, Device]]:
for attr_name, attr in self.__dict__.items():
if attr_name != "parent" and isinstance(attr, Device):
yield attr_name, attr
yield from self._child_devices.items()

@cached_property
def log(self) -> LoggerAdapter:
return LoggerAdapter(
getLogger("ophyd_async.devices"), {"ophyd_async_device_name": self.name}
)

def set_name(self, name: str):
"""Set ``self.name=name`` and each ``self.child.name=name+"-child"``.
Expand All @@ -97,28 +100,33 @@ def set_name(self, name: str):
New name to set
"""
self._name = name
# Ensure self.log is recreated after a name change
self.log = LoggerAdapter(
getLogger("ophyd_async.devices"), {"ophyd_async_device_name": self.name}
)
# Ensure logger is recreated after a name change
if "log" in self.__dict__:
del self.log
for child_name, child in self.children():
child_name = f"{self.name}-{child_name.strip('_')}" if self.name else ""
child.set_name(child_name)

def __setattr__(self, name: str, value: Any) -> None:
# Bear in mind that this function is called *a lot*, so
# we need to make sure nothing expensive happens in it...
if name == "parent":
if self.parent not in (value, None):
raise TypeError(
f"Cannot set the parent of {self} to be {value}: "
f"it is already a child of {self.parent}"
)
elif isinstance(value, Device):
# ...hence not doing an isinstance check for attributes we
# know not to be Devices
elif name not in _not_device_attrs and isinstance(value, Device):
value.parent = self
return super().__setattr__(name, value)
self._child_devices[name] = value
# ...and avoiding the super call as we know it resolves to `object`
return object.__setattr__(self, name, value)

async def connect(
self,
mock: bool | Mock = False,
mock: bool | LazyMock = False,
timeout: float = DEFAULT_TIMEOUT,
force_reconnect: bool = False,
) -> None:
Expand All @@ -133,26 +141,39 @@ async def connect(
timeout:
Time to wait before failing with a TimeoutError.
"""
uses_mock = bool(mock)
can_use_previous_connect = (
uses_mock is self._connect_mock_arg
and self._connect_task
and not (self._connect_task.done() and self._connect_task.exception())
)
if mock is True:
mock = Mock() # create a new Mock if one not provided
if force_reconnect or not can_use_previous_connect:
self._connect_mock_arg = uses_mock
if self._connect_mock_arg:
_device_mocks[self] = mock
coro = self._connector.connect(
device=self, mock=mock, timeout=timeout, force_reconnect=force_reconnect
if mock:
# Always connect in mock mode serially
if isinstance(mock, LazyMock):
# Use the provided mock
self._mock = mock
elif not self._mock:
# Make one
self._mock = LazyMock()
await self._connector.connect_mock(self, self._mock)
else:
# Try to cache the connect in real mode
can_use_previous_connect = (
self._mock is None
and self._connect_task
and not (self._connect_task.done() and self._connect_task.exception())
)
self._connect_task = asyncio.create_task(coro)

assert self._connect_task, "Connect task not created, this shouldn't happen"
# Wait for it to complete
await self._connect_task
if force_reconnect or not can_use_previous_connect:
self._mock = None
coro = self._connector.connect_real(self, timeout, force_reconnect)
self._connect_task = asyncio.create_task(coro)
assert self._connect_task, "Connect task not created, this shouldn't happen"
# Wait for it to complete
await self._connect_task


_not_device_attrs = {
"_name",
"_children",
"_connector",
"_timeout",
"_mock",
"_connect_task",
}


DeviceT = TypeVar("DeviceT", bound=Device)
Expand Down
17 changes: 10 additions & 7 deletions src/ophyd_async/core/_mock_signal_backend.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import asyncio
from collections.abc import Callable
from functools import cached_property
from unittest.mock import AsyncMock, Mock
from unittest.mock import AsyncMock

from bluesky.protocols import Descriptor, Reading

from ._signal_backend import SignalBackend, SignalDatatypeT
from ._soft_signal_backend import SoftSignalBackend
from ._utils import Callback
from ._utils import Callback, LazyMock


class MockSignalBackend(SignalBackend[SignalDatatypeT]):
Expand All @@ -16,7 +16,7 @@ class MockSignalBackend(SignalBackend[SignalDatatypeT]):
def __init__(
self,
initial_backend: SignalBackend[SignalDatatypeT],
mock: Mock,
mock: LazyMock,
) -> None:
if isinstance(initial_backend, MockSignalBackend):
raise ValueError("Cannot make a MockSignalBackend for a MockSignalBackend")
Expand All @@ -34,19 +34,22 @@ def __init__(

# use existing Mock if provided
self.mock = mock
self.put_mock = AsyncMock(name="put", spec=Callable)
self.mock.attach_mock(self.put_mock, "put")

super().__init__(datatype=self.initial_backend.datatype)

@cached_property
def put_mock(self) -> AsyncMock:
put_mock = AsyncMock(name="put", spec=Callable)
self.mock().attach_mock(put_mock, "put")
return put_mock

def set_value(self, value: SignalDatatypeT):
self.soft_backend.set_value(value)

def source(self, name: str, read: bool) -> str:
return f"mock+{self.initial_backend.source(name, read)}"

async def connect(self, timeout: float) -> None:
pass
raise RuntimeError("It is not possible to connect a MockSignalBackend")

@cached_property
def put_proceeds(self) -> asyncio.Event:
Expand Down
25 changes: 14 additions & 11 deletions src/ophyd_async/core/_mock_signal_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,26 @@
from contextlib import asynccontextmanager, contextmanager
from unittest.mock import AsyncMock, Mock

from ._device import Device, _device_mocks
from ._device import Device
from ._mock_signal_backend import MockSignalBackend
from ._signal import Signal, SignalR, _mock_signal_backends
from ._signal import Signal, SignalConnector, SignalR
from ._soft_signal_backend import SignalDatatypeT
from ._utils import LazyMock


def get_mock(device: Device | Signal) -> Mock:
mock = device._mock # noqa: SLF001
assert isinstance(mock, LazyMock), f"Device {device} not connected in mock mode"
return mock()


def _get_mock_signal_backend(signal: Signal) -> MockSignalBackend:
assert (
signal in _mock_signal_backends
connector = signal._connector # noqa: SLF001
assert isinstance(connector, SignalConnector), f"Expected Signal, got {signal}"
assert isinstance(
connector.backend, MockSignalBackend
), f"Signal {signal} not connected in mock mode"
return _mock_signal_backends[signal]
return connector.backend


def set_mock_value(signal: Signal[SignalDatatypeT], value: SignalDatatypeT):
Expand Down Expand Up @@ -45,12 +54,6 @@ def get_mock_put(signal: Signal) -> AsyncMock:
return _get_mock_signal_backend(signal).put_mock


def get_mock(device: Device | Signal) -> Mock:
if isinstance(device, Signal):
return _get_mock_signal_backend(device).mock
return _device_mocks[device]


def reset_mock_put_calls(signal: Signal):
backend = _get_mock_signal_backend(signal)
backend.put_mock.reset_mock()
Expand Down
46 changes: 22 additions & 24 deletions src/ophyd_async/core/_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import functools
from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping
from typing import Any, Generic, cast
from unittest.mock import Mock

from bluesky.protocols import (
Locatable,
Expand All @@ -30,9 +29,14 @@
)
from ._soft_signal_backend import SoftSignalBackend
from ._status import AsyncStatus
from ._utils import CALCULATE_TIMEOUT, DEFAULT_TIMEOUT, CalculatableTimeout, Callback, T

_mock_signal_backends: dict[Device, MockSignalBackend] = {}
from ._utils import (
CALCULATE_TIMEOUT,
DEFAULT_TIMEOUT,
CalculatableTimeout,
Callback,
LazyMock,
T,
)


async def _wait_for(coro: Awaitable[T], timeout: float | None, source: str) -> T:
Expand All @@ -54,26 +58,28 @@ class SignalConnector(DeviceConnector):
def __init__(self, backend: SignalBackend):
self.backend = self._init_backend = backend

async def connect(
self,
device: Device,
mock: bool | Mock,
timeout: float,
force_reconnect: bool,
):
if mock:
self.backend = MockSignalBackend(self._init_backend, mock)
_mock_signal_backends[device] = self.backend
else:
self.backend = self._init_backend
async def connect_mock(self, device: Device, mock: LazyMock):
self.backend = MockSignalBackend(self._init_backend, mock)

async def connect_real(self, device: Device, timeout: float, force_reconnect: bool):
self.backend = self._init_backend
device.log.debug(f"Connecting to {self.backend.source(device.name, read=True)}")
await self.backend.connect(timeout)


class _ChildrenNotAllowed(dict[str, Device]):
def __setitem__(self, key: str, value: Device) -> None:
raise AttributeError(
f"Cannot add Device or Signal child {key}={value} of Signal, "
"make a subclass of Device instead"
)


class Signal(Device, Generic[SignalDatatypeT]):
"""A Device with the concept of a value, with R, RW, W and X flavours"""

_connector: SignalConnector
_child_devices = _ChildrenNotAllowed() # type: ignore

def __init__(
self,
Expand All @@ -89,14 +95,6 @@ def source(self) -> str:
"""Like ca://PV_PREFIX:SIGNAL, or "" if not set"""
return self._connector.backend.source(self.name, read=True)

def __setattr__(self, name: str, value: Any) -> None:
if name != "parent" and isinstance(value, Device):
raise AttributeError(
f"Cannot add Device or Signal {value} as a child of Signal {self}, "
"make a subclass of Device instead"
)
return super().__setattr__(name, value)


class _SignalCache(Generic[SignalDatatypeT]):
def __init__(self, backend: SignalBackend[SignalDatatypeT], signal: Signal):
Expand Down
Loading

0 comments on commit 3d9f508

Please sign in to comment.