-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #62 from jumpstarter-dev/higher-order-streams
Higher order streams
- Loading branch information
Showing
48 changed files
with
1,952 additions
and
735 deletions.
There are no files selected for viewing
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass | ||
from functools import cached_property | ||
from typing import Callable, List, Optional, Sequence, Tuple | ||
from uuid import UUID | ||
|
||
import can | ||
from can.bus import _SelfRemovingCyclicTask | ||
from pydantic import ConfigDict, validate_call | ||
|
||
from jumpstarter.client import DriverClient | ||
|
||
from .common import CanMessage | ||
|
||
|
||
@dataclass(kw_only=True) | ||
class RemoteCyclicSendTask(can.broadcastmanager.CyclicSendTaskABC): | ||
client: CanClient | ||
uuid: UUID | ||
|
||
def stop(self) -> None: | ||
self.client.call("_stop_task", self.uuid) | ||
|
||
|
||
@dataclass(kw_only=True) | ||
class CanClient(DriverClient, can.BusABC): | ||
def __post_init__(self): | ||
self._periodic_tasks: List[_SelfRemovingCyclicTask] = [] | ||
self._filters = None | ||
self._is_shutdown: bool = False | ||
|
||
super().__post_init__() | ||
|
||
@property | ||
@validate_call(validate_return=True) | ||
def state(self) -> can.BusState: | ||
return self.call("state") | ||
|
||
@state.setter | ||
@validate_call(validate_return=True) | ||
def state(self, value: can.BusState) -> None: | ||
self.call("state", value) | ||
|
||
@cached_property | ||
@validate_call(validate_return=True) | ||
def channel_info(self) -> str: | ||
return self.call("channel_info") | ||
|
||
@cached_property | ||
@validate_call(validate_return=True) | ||
def protocol(self) -> can.CanProtocol: | ||
return self.call("protocol") | ||
|
||
@validate_call(validate_return=True, config=ConfigDict(arbitrary_types_allowed=True)) | ||
def _recv_internal(self, timeout: Optional[float]) -> Tuple[Optional[can.Message], bool]: | ||
msg, filtered = self.call("_recv_internal", timeout) | ||
if msg: | ||
return can.Message(**CanMessage.model_validate(msg).__dict__), filtered | ||
return None, filtered | ||
|
||
@validate_call(validate_return=True, config=ConfigDict(arbitrary_types_allowed=True)) | ||
def send(self, msg: can.Message, timeout: Optional[float] = None) -> None: | ||
self.call("send", CanMessage.construct(msg), timeout) | ||
|
||
@validate_call(validate_return=True, config=ConfigDict(arbitrary_types_allowed=True)) | ||
def _send_periodic_internal( | ||
self, | ||
msgs: Sequence[can.Message], | ||
period: float, | ||
duration: Optional[float] = None, | ||
modifier_callback: Optional[Callable[[can.Message], None]] = None, | ||
) -> can.broadcastmanager.CyclicSendTaskABC: | ||
if modifier_callback: | ||
return super()._send_periodic_internal(msgs, period, duration, modifier_callback) | ||
else: | ||
msgs = [CanMessage.construct(msg) for msg in msgs] | ||
return RemoteCyclicSendTask(client=self, uuid=self.call("_send_periodic_internal", msgs, period, duration)) | ||
|
||
# python-can bug | ||
# https://docs.pydantic.dev/2.8/errors/usage_errors/#typed-dict-version | ||
# @validate_call(validate_return=True) | ||
def _apply_filters(self, filters: Optional[can.typechecking.CanFilters]) -> None: | ||
self.call("_apply_filters", filters) | ||
|
||
@validate_call(validate_return=True) | ||
def flush_tx_buffer(self) -> None: | ||
self.call("flush_tx_buffer") | ||
|
||
@validate_call(validate_return=True) | ||
def shutdown(self) -> None: | ||
self.call("shutdown") | ||
super().shutdown() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,207 @@ | ||
from itertools import islice | ||
from random import randbytes | ||
from threading import Semaphore | ||
|
||
import can | ||
import isotp | ||
import pytest | ||
|
||
from jumpstarter.common.utils import serve | ||
from jumpstarter_driver_can.driver import Can | ||
|
||
|
||
def test_client_can_send_recv(request): | ||
with ( | ||
serve(Can(channel=request.node.name, interface="virtual")) as client1, | ||
serve(Can(channel=request.node.name, interface="virtual")) as client2, | ||
client1, | ||
client2, | ||
): | ||
client1.send(can.Message(data=b"hello")) | ||
|
||
assert client2.recv().data == b"hello" | ||
|
||
with pytest.raises(NotImplementedError): | ||
# not implemented on virtual bus | ||
client1.flush_tx_buffer() | ||
|
||
|
||
def test_client_can_property(request): | ||
driver = Can(channel=request.node.name, interface="virtual") | ||
with serve(driver) as client, client: | ||
assert client.channel_info == driver.bus.channel_info | ||
assert client.state == driver.bus.state | ||
assert client.protocol == driver.bus.protocol | ||
|
||
with pytest.raises(NotImplementedError): | ||
# not implemented on virtual bus | ||
client.state = can.BusState.PASSIVE | ||
|
||
|
||
def test_client_can_iterator(request): | ||
with ( | ||
serve(Can(channel=request.node.name, interface="virtual")) as client1, | ||
serve(Can(channel=request.node.name, interface="virtual")) as client2, | ||
client1, | ||
client2, | ||
): | ||
client1.send(can.Message(data=b"a")) | ||
client1.send(can.Message(data=b"b")) | ||
client1.send(can.Message(data=b"c")) | ||
|
||
assert [msg.data for msg in islice(client2, 3)] == [b"a", b"b", b"c"] | ||
|
||
|
||
def test_client_can_filter(request): | ||
with ( | ||
serve(Can(channel=request.node.name, interface="virtual")) as client1, | ||
serve(Can(channel=request.node.name, interface="virtual")) as client2, | ||
client1, | ||
client2, | ||
): | ||
client2.set_filters([{"can_id": 0x1, "can_mask": 0x1, "extended": True}]) | ||
|
||
client1.send(can.Message(arbitration_id=0x0, data=b"a")) | ||
client1.send(can.Message(arbitration_id=0x1, data=b"b")) | ||
client1.send(can.Message(arbitration_id=0x2, data=b"c")) | ||
|
||
assert client2.recv().data == b"b" | ||
|
||
|
||
def test_client_can_notifier(request): | ||
with ( | ||
serve(Can(channel=request.node.name, interface="virtual")) as client1, | ||
serve(Can(channel=request.node.name, interface="virtual")) as client2, | ||
client1, | ||
client2, | ||
): | ||
sem = Semaphore(0) | ||
|
||
def listener(msg): | ||
assert msg.data == b"hello" | ||
sem.release() | ||
|
||
notifier = can.Notifier(client2, [listener]) | ||
|
||
client1.send(can.Message(data=b"hello")) | ||
|
||
sem.acquire() | ||
notifier.stop() | ||
|
||
|
||
def test_client_can_redirect(request): | ||
with ( | ||
serve(Can(channel=request.node.name, interface="virtual")) as client1, | ||
serve(Can(channel=request.node.name, interface="virtual")) as client2, | ||
client1, | ||
client2, | ||
): | ||
bus3 = can.interface.Bus(request.node.name + "_inner", interface="virtual") | ||
bus4 = can.interface.Bus(request.node.name + "_inner", interface="virtual") | ||
|
||
notifier = can.Notifier(client2, [can.RedirectReader(bus3)]) | ||
|
||
client1.send(can.Message(data=b"hello")) | ||
|
||
assert bus4.recv().data == b"hello" | ||
|
||
notifier.stop() | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"msgs, expected", | ||
[ | ||
([can.Message(data=b"a"), can.Message(data=b"b")], [(1, b"a"), (1, b"b"), (1, b"a"), (1, b"b")]), | ||
(can.Message(data=b"a"), [(1, b"a"), (1, b"a"), (1, b"a"), (1, b"a")]), | ||
], | ||
) | ||
def test_client_can_send_periodic_local(request, msgs, expected): | ||
with ( | ||
serve(Can(channel=request.node.name, interface="virtual")) as client1, | ||
serve(Can(channel=request.node.name, interface="virtual")) as client2, | ||
client1, | ||
client2, | ||
): | ||
|
||
def modifier_callback(msg): | ||
msg.arbitration_id = 1 | ||
|
||
client1.send_periodic( | ||
msgs=msgs, | ||
period=0.1, | ||
duration=1, | ||
store_task=True, | ||
modifier_callback=modifier_callback, | ||
) | ||
|
||
assert [(msg.arbitration_id, msg.data) for msg in islice(client2, 4)] == expected | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"msgs, expected", | ||
[ | ||
([can.Message(data=b"a"), can.Message(data=b"b")], [(0, b"a"), (0, b"b"), (0, b"a"), (0, b"b")]), | ||
(can.Message(data=b"a"), [(0, b"a"), (0, b"a"), (0, b"a"), (0, b"a")]), | ||
], | ||
) | ||
def test_client_can_send_periodic_remote(request, msgs, expected): | ||
with ( | ||
serve(Can(channel=request.node.name, interface="virtual")) as client1, | ||
serve(Can(channel=request.node.name, interface="virtual")) as client2, | ||
client1, | ||
client2, | ||
): | ||
client1.send_periodic( | ||
msgs=msgs, | ||
period=0.1, | ||
duration=1, | ||
store_task=True, | ||
) | ||
|
||
assert [(msg.arbitration_id, msg.data) for msg in islice(client2, 4)] == expected | ||
|
||
|
||
@pytest.mark.parametrize("tx_data_length", [8, 64]) | ||
@pytest.mark.parametrize("blocking_send", [False, True]) | ||
def test_client_can_isotp(request, tx_data_length, blocking_send): | ||
with ( | ||
serve(Can(channel=request.node.name, interface="virtual")) as client1, | ||
serve(Can(channel=request.node.name, interface="virtual")) as client2, | ||
client1, | ||
client2, | ||
): | ||
notifier1 = can.Notifier(client1, []) | ||
notifier2 = can.Notifier(client2, []) | ||
|
||
params = { | ||
"max_frame_size": 2048, | ||
"tx_data_length": tx_data_length, | ||
"blocking_send": blocking_send, | ||
} | ||
|
||
transport1 = isotp.NotifierBasedCanStack( | ||
client1, | ||
notifier1, | ||
address=isotp.Address(rxid=1, txid=2), | ||
params=params, | ||
) | ||
transport2 = isotp.NotifierBasedCanStack( | ||
client2, | ||
notifier2, | ||
address=isotp.Address(rxid=2, txid=1), | ||
params=params, | ||
) | ||
|
||
transport1.start() | ||
transport2.start() | ||
|
||
message = randbytes(params["max_frame_size"]) | ||
|
||
transport1.send(message, send_timeout=10) | ||
assert transport2.recv(block=True, timeout=10) == message | ||
|
||
transport1.stop() | ||
transport2.stop() | ||
|
||
notifier1.stop() | ||
notifier2.stop() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from typing import Optional | ||
|
||
from pydantic import Base64Bytes, BaseModel | ||
|
||
|
||
class CanMessage(BaseModel): | ||
timestamp: float | ||
arbitration_id: int | ||
is_extended_id: bool | ||
is_remote_frame: bool | ||
is_error_frame: bool | ||
channel: Optional[int | str] | ||
dlc: Optional[int] | ||
data: Optional[Base64Bytes] | ||
is_fd: bool | ||
is_rx: bool | ||
bitrate_switch: bool | ||
error_state_indicator: bool | ||
|
||
@classmethod | ||
def construct(cls, msg): | ||
return cls.model_construct( | ||
timestamp=msg.timestamp, | ||
arbitration_id=msg.arbitration_id, | ||
is_extended_id=msg.is_extended_id, | ||
is_remote_frame=msg.is_remote_frame, | ||
is_error_frame=msg.is_error_frame, | ||
channel=msg.channel, | ||
dlc=msg.dlc, | ||
data=msg.data, | ||
is_fd=msg.is_fd, | ||
is_rx=msg.is_rx, | ||
bitrate_switch=msg.bitrate_switch, | ||
error_state_indicator=msg.error_state_indicator, | ||
) |
Oops, something went wrong.