Skip to content

Commit

Permalink
Merge pull request #62 from jumpstarter-dev/higher-order-streams
Browse files Browse the repository at this point in the history
Higher order streams
  • Loading branch information
kirkbrauer authored Sep 4, 2024
2 parents 1b0b8e1 + 0252313 commit 3069e09
Show file tree
Hide file tree
Showing 48 changed files with 1,952 additions and 735 deletions.
Empty file added contrib/can/README.md
Empty file.
Empty file.
93 changes: 93 additions & 0 deletions contrib/can/jumpstarter_driver_can/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from __future__ import annotations

from dataclasses import dataclass
from functools import cached_property
from typing import Callable, List, Optional, Sequence, Tuple
from uuid import UUID

import can
from can.bus import _SelfRemovingCyclicTask
from pydantic import ConfigDict, validate_call

from jumpstarter.client import DriverClient

from .common import CanMessage


@dataclass(kw_only=True)
class RemoteCyclicSendTask(can.broadcastmanager.CyclicSendTaskABC):
client: CanClient
uuid: UUID

def stop(self) -> None:
self.client.call("_stop_task", self.uuid)


@dataclass(kw_only=True)
class CanClient(DriverClient, can.BusABC):
def __post_init__(self):
self._periodic_tasks: List[_SelfRemovingCyclicTask] = []
self._filters = None
self._is_shutdown: bool = False

super().__post_init__()

@property
@validate_call(validate_return=True)
def state(self) -> can.BusState:
return self.call("state")

@state.setter
@validate_call(validate_return=True)
def state(self, value: can.BusState) -> None:
self.call("state", value)

@cached_property
@validate_call(validate_return=True)
def channel_info(self) -> str:
return self.call("channel_info")

@cached_property
@validate_call(validate_return=True)
def protocol(self) -> can.CanProtocol:
return self.call("protocol")

@validate_call(validate_return=True, config=ConfigDict(arbitrary_types_allowed=True))
def _recv_internal(self, timeout: Optional[float]) -> Tuple[Optional[can.Message], bool]:
msg, filtered = self.call("_recv_internal", timeout)
if msg:
return can.Message(**CanMessage.model_validate(msg).__dict__), filtered
return None, filtered

@validate_call(validate_return=True, config=ConfigDict(arbitrary_types_allowed=True))
def send(self, msg: can.Message, timeout: Optional[float] = None) -> None:
self.call("send", CanMessage.construct(msg), timeout)

@validate_call(validate_return=True, config=ConfigDict(arbitrary_types_allowed=True))
def _send_periodic_internal(
self,
msgs: Sequence[can.Message],
period: float,
duration: Optional[float] = None,
modifier_callback: Optional[Callable[[can.Message], None]] = None,
) -> can.broadcastmanager.CyclicSendTaskABC:
if modifier_callback:
return super()._send_periodic_internal(msgs, period, duration, modifier_callback)
else:
msgs = [CanMessage.construct(msg) for msg in msgs]
return RemoteCyclicSendTask(client=self, uuid=self.call("_send_periodic_internal", msgs, period, duration))

# python-can bug
# https://docs.pydantic.dev/2.8/errors/usage_errors/#typed-dict-version
# @validate_call(validate_return=True)
def _apply_filters(self, filters: Optional[can.typechecking.CanFilters]) -> None:
self.call("_apply_filters", filters)

@validate_call(validate_return=True)
def flush_tx_buffer(self) -> None:
self.call("flush_tx_buffer")

@validate_call(validate_return=True)
def shutdown(self) -> None:
self.call("shutdown")
super().shutdown()
207 changes: 207 additions & 0 deletions contrib/can/jumpstarter_driver_can/client_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
from itertools import islice
from random import randbytes
from threading import Semaphore

import can
import isotp
import pytest

from jumpstarter.common.utils import serve
from jumpstarter_driver_can.driver import Can


def test_client_can_send_recv(request):
with (
serve(Can(channel=request.node.name, interface="virtual")) as client1,
serve(Can(channel=request.node.name, interface="virtual")) as client2,
client1,
client2,
):
client1.send(can.Message(data=b"hello"))

assert client2.recv().data == b"hello"

with pytest.raises(NotImplementedError):
# not implemented on virtual bus
client1.flush_tx_buffer()


def test_client_can_property(request):
driver = Can(channel=request.node.name, interface="virtual")
with serve(driver) as client, client:
assert client.channel_info == driver.bus.channel_info
assert client.state == driver.bus.state
assert client.protocol == driver.bus.protocol

with pytest.raises(NotImplementedError):
# not implemented on virtual bus
client.state = can.BusState.PASSIVE


def test_client_can_iterator(request):
with (
serve(Can(channel=request.node.name, interface="virtual")) as client1,
serve(Can(channel=request.node.name, interface="virtual")) as client2,
client1,
client2,
):
client1.send(can.Message(data=b"a"))
client1.send(can.Message(data=b"b"))
client1.send(can.Message(data=b"c"))

assert [msg.data for msg in islice(client2, 3)] == [b"a", b"b", b"c"]


def test_client_can_filter(request):
with (
serve(Can(channel=request.node.name, interface="virtual")) as client1,
serve(Can(channel=request.node.name, interface="virtual")) as client2,
client1,
client2,
):
client2.set_filters([{"can_id": 0x1, "can_mask": 0x1, "extended": True}])

client1.send(can.Message(arbitration_id=0x0, data=b"a"))
client1.send(can.Message(arbitration_id=0x1, data=b"b"))
client1.send(can.Message(arbitration_id=0x2, data=b"c"))

assert client2.recv().data == b"b"


def test_client_can_notifier(request):
with (
serve(Can(channel=request.node.name, interface="virtual")) as client1,
serve(Can(channel=request.node.name, interface="virtual")) as client2,
client1,
client2,
):
sem = Semaphore(0)

def listener(msg):
assert msg.data == b"hello"
sem.release()

notifier = can.Notifier(client2, [listener])

client1.send(can.Message(data=b"hello"))

sem.acquire()
notifier.stop()


def test_client_can_redirect(request):
with (
serve(Can(channel=request.node.name, interface="virtual")) as client1,
serve(Can(channel=request.node.name, interface="virtual")) as client2,
client1,
client2,
):
bus3 = can.interface.Bus(request.node.name + "_inner", interface="virtual")
bus4 = can.interface.Bus(request.node.name + "_inner", interface="virtual")

notifier = can.Notifier(client2, [can.RedirectReader(bus3)])

client1.send(can.Message(data=b"hello"))

assert bus4.recv().data == b"hello"

notifier.stop()


@pytest.mark.parametrize(
"msgs, expected",
[
([can.Message(data=b"a"), can.Message(data=b"b")], [(1, b"a"), (1, b"b"), (1, b"a"), (1, b"b")]),
(can.Message(data=b"a"), [(1, b"a"), (1, b"a"), (1, b"a"), (1, b"a")]),
],
)
def test_client_can_send_periodic_local(request, msgs, expected):
with (
serve(Can(channel=request.node.name, interface="virtual")) as client1,
serve(Can(channel=request.node.name, interface="virtual")) as client2,
client1,
client2,
):

def modifier_callback(msg):
msg.arbitration_id = 1

client1.send_periodic(
msgs=msgs,
period=0.1,
duration=1,
store_task=True,
modifier_callback=modifier_callback,
)

assert [(msg.arbitration_id, msg.data) for msg in islice(client2, 4)] == expected


@pytest.mark.parametrize(
"msgs, expected",
[
([can.Message(data=b"a"), can.Message(data=b"b")], [(0, b"a"), (0, b"b"), (0, b"a"), (0, b"b")]),
(can.Message(data=b"a"), [(0, b"a"), (0, b"a"), (0, b"a"), (0, b"a")]),
],
)
def test_client_can_send_periodic_remote(request, msgs, expected):
with (
serve(Can(channel=request.node.name, interface="virtual")) as client1,
serve(Can(channel=request.node.name, interface="virtual")) as client2,
client1,
client2,
):
client1.send_periodic(
msgs=msgs,
period=0.1,
duration=1,
store_task=True,
)

assert [(msg.arbitration_id, msg.data) for msg in islice(client2, 4)] == expected


@pytest.mark.parametrize("tx_data_length", [8, 64])
@pytest.mark.parametrize("blocking_send", [False, True])
def test_client_can_isotp(request, tx_data_length, blocking_send):
with (
serve(Can(channel=request.node.name, interface="virtual")) as client1,
serve(Can(channel=request.node.name, interface="virtual")) as client2,
client1,
client2,
):
notifier1 = can.Notifier(client1, [])
notifier2 = can.Notifier(client2, [])

params = {
"max_frame_size": 2048,
"tx_data_length": tx_data_length,
"blocking_send": blocking_send,
}

transport1 = isotp.NotifierBasedCanStack(
client1,
notifier1,
address=isotp.Address(rxid=1, txid=2),
params=params,
)
transport2 = isotp.NotifierBasedCanStack(
client2,
notifier2,
address=isotp.Address(rxid=2, txid=1),
params=params,
)

transport1.start()
transport2.start()

message = randbytes(params["max_frame_size"])

transport1.send(message, send_timeout=10)
assert transport2.recv(block=True, timeout=10) == message

transport1.stop()
transport2.stop()

notifier1.stop()
notifier2.stop()
35 changes: 35 additions & 0 deletions contrib/can/jumpstarter_driver_can/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Optional

from pydantic import Base64Bytes, BaseModel


class CanMessage(BaseModel):
timestamp: float
arbitration_id: int
is_extended_id: bool
is_remote_frame: bool
is_error_frame: bool
channel: Optional[int | str]
dlc: Optional[int]
data: Optional[Base64Bytes]
is_fd: bool
is_rx: bool
bitrate_switch: bool
error_state_indicator: bool

@classmethod
def construct(cls, msg):
return cls.model_construct(
timestamp=msg.timestamp,
arbitration_id=msg.arbitration_id,
is_extended_id=msg.is_extended_id,
is_remote_frame=msg.is_remote_frame,
is_error_frame=msg.is_error_frame,
channel=msg.channel,
dlc=msg.dlc,
data=msg.data,
is_fd=msg.is_fd,
is_rx=msg.is_rx,
bitrate_switch=msg.bitrate_switch,
error_state_indicator=msg.error_state_indicator,
)
Loading

0 comments on commit 3069e09

Please sign in to comment.