Skip to content

Commit

Permalink
Fix remaining test
Browse files Browse the repository at this point in the history
  • Loading branch information
coretl committed Nov 12, 2024
1 parent ffe4afb commit 6ff13d8
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 36 deletions.
8 changes: 7 additions & 1 deletion src/ophyd_async/core/_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,14 @@ def create_children_from_annotations(self, device: Device):

async def connect_mock(self, device: Device, mock: LazyMock):
# Connect serially, no errors to gather up as in mock mode
exceptions: dict[str, Exception] = {}
for name, child_device in device.children():
await child_device.connect(mock=mock.child(name))
try:
await child_device.connect(mock=mock.child(name))
except Exception as e:
exceptions[name] = e
if exceptions:
raise NotConnected.with_other_exceptions_logged(exceptions)

async def connect_real(self, device: Device, timeout: float, force_reconnect: bool):
"""Used during ``Device.connect``.
Expand Down
36 changes: 17 additions & 19 deletions src/ophyd_async/core/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,10 @@

import asyncio
import logging
from collections.abc import Awaitable, Callable, Iterable, Sequence
from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence
from dataclasses import dataclass
from enum import Enum, EnumMeta
from typing import (
Any,
Generic,
Literal,
ParamSpec,
TypeVar,
get_args,
get_origin,
)
from typing import Any, Generic, Literal, ParamSpec, TypeVar, get_args, get_origin
from unittest.mock import Mock

import numpy as np
Expand All @@ -22,7 +14,7 @@
P = ParamSpec("P")
Callback = Callable[[T], None]
DEFAULT_TIMEOUT = 10.0
ErrorText = str | dict[str, Exception]
ErrorText = str | Mapping[str, Exception]


class StrictEnum(str, Enum):
Expand Down Expand Up @@ -100,6 +92,19 @@ def format_error_string(self, indent="") -> str:
def __str__(self) -> str:
return self.format_error_string(indent="")

@classmethod
def with_other_exceptions_logged(
cls, exceptions: Mapping[str, Exception]
) -> NotConnected:
for name, exception in exceptions.items():
if not isinstance(exception, NotConnected):
logging.exception(
f"device `{name}` raised unexpected exception "
f"{type(exception).__name__}",
exc_info=exception,
)
return NotConnected(exceptions)


@dataclass(frozen=True)
class WatcherUpdate(Generic[T]):
Expand Down Expand Up @@ -137,14 +142,7 @@ async def wait_for_connection(**coros: Awaitable[None]):
exceptions[name] = result

if exceptions:
for name, exception in exceptions.items():
if not isinstance(exception, NotConnected):
logging.exception(
f"device `{name}` raised unexpected exception "
f"{type(exception).__name__}",
exc_info=exception,
)
raise NotConnected(exceptions)
raise NotConnected.with_other_exceptions_logged(exceptions)


def get_dtype(datatype: type) -> np.dtype:
Expand Down
25 changes: 10 additions & 15 deletions tests/core/test_device_save_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from unittest.mock import patch

import numpy as np
import numpy.typing as npt
import pytest
import yaml
from bluesky.run_engine import RunEngine
Expand Down Expand Up @@ -71,17 +70,16 @@ def __init__(self, name: str):
self.pv_str: SignalRW = epics_signal_rw(str, "PV2")
self.pv_enum_str: SignalRW = epics_signal_rw(MyEnum, "PV3")
self.pv_enum: SignalRW = epics_signal_rw(MyEnum, "PV4")
self.pv_array_int8 = epics_signal_rw(npt.NDArray[np.int8], "PV5")
self.pv_array_uint8 = epics_signal_rw(npt.NDArray[np.uint8], "PV6")
self.pv_array_int16 = epics_signal_rw(npt.NDArray[np.int16], "PV7")
self.pv_array_uint16 = epics_signal_rw(npt.NDArray[np.uint16], "PV8")
self.pv_array_int32 = epics_signal_rw(npt.NDArray[np.int32], "PV9")
self.pv_array_uint32 = epics_signal_rw(npt.NDArray[np.uint32], "PV10")
self.pv_array_int64 = epics_signal_rw(npt.NDArray[np.int64], "PV11")
self.pv_array_uint64 = epics_signal_rw(npt.NDArray[np.uint64], "PV12")
self.pv_array_float32 = epics_signal_rw(npt.NDArray[np.float32], "PV13")
self.pv_array_float64 = epics_signal_rw(npt.NDArray[np.float64], "PV14")
self.pv_array_npstr = epics_signal_rw(npt.NDArray[np.str_], "PV15")
self.pv_array_int8 = epics_signal_rw(Array1D[np.int8], "PV5")
self.pv_array_uint8 = epics_signal_rw(Array1D[np.uint8], "PV6")
self.pv_array_int16 = epics_signal_rw(Array1D[np.int16], "PV7")
self.pv_array_uint16 = epics_signal_rw(Array1D[np.uint16], "PV8")
self.pv_array_int32 = epics_signal_rw(Array1D[np.int32], "PV9")
self.pv_array_uint32 = epics_signal_rw(Array1D[np.uint32], "PV10")
self.pv_array_int64 = epics_signal_rw(Array1D[np.int64], "PV11")
self.pv_array_uint64 = epics_signal_rw(Array1D[np.uint64], "PV12")
self.pv_array_float32 = epics_signal_rw(Array1D[np.float32], "PV13")
self.pv_array_float64 = epics_signal_rw(Array1D[np.float64], "PV14")
self.pv_array_str = epics_signal_rw(Sequence[str], "PV16")
self.pv_protocol_device_abstraction = epics_signal_rw(Table, "pva://PV17")
super().__init__(name)
Expand Down Expand Up @@ -168,9 +166,6 @@ async def test_save_device_all_types(
)

await pv.set(data)
await device_all_types.pv_array_npstr.set(
np.array(["one", "two", "three"], dtype=np.str_),
)
await device_all_types.pv_array_str.set(
["one", "two", "three"],
)
Expand Down
1 change: 0 additions & 1 deletion tests/test_data/test_yaml_save.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
pv_array_int32: [-2147483648, 2147483647, 0, 1, 2, 3, 4]
pv_array_int64: [-9223372036854775808, 9223372036854775807, 0, 1, 2, 3, 4]
pv_array_int8: [-128, 127, 0, 1, 2, 3, 4]
pv_array_npstr: [one, two, three]
pv_array_str:
- one
- two
Expand Down

0 comments on commit 6ff13d8

Please sign in to comment.