diff --git a/contrib/can/jumpstarter_driver_can/client_test.py b/contrib/can/jumpstarter_driver_can/client_test.py index 519dfc65..2dd93f47 100644 --- a/contrib/can/jumpstarter_driver_can/client_test.py +++ b/contrib/can/jumpstarter_driver_can/client_test.py @@ -8,7 +8,7 @@ from jumpstarter.common.utils import serve from jumpstarter_driver_can.common import IsotpParams -from jumpstarter_driver_can.driver import Can, Isotp +from jumpstarter_driver_can.driver import Can, Isotp, IsotpSocket def test_client_can_send_recv(request): @@ -261,3 +261,28 @@ def test_client_isotp(request, blocking_send, addresses): client1.stop() client2.stop() + + +@pytest.mark.parametrize("can_fd", [False, True]) +def test_client_isotp_socket(request, can_fd): + params = IsotpParams( + max_frame_size=2048, + blocking_send=False, + can_fd=can_fd, + ) + + with ( + serve(IsotpSocket(channel="vcan0", address=isotp.Address(rxid=1, txid=2), params=params)) as client1, + serve(IsotpSocket(channel="vcan0", address=isotp.Address(rxid=2, txid=1), params=params)) as client2, + ): + client1.start() + client2.start() + + message = randbytes(params.max_frame_size) + + client1.send(message, send_timeout=10) + + assert client2.recv(block=True, timeout=10) == message + + client1.stop() + client2.stop() diff --git a/contrib/can/jumpstarter_driver_can/common.py b/contrib/can/jumpstarter_driver_can/common.py index 34171bef..08394c5f 100644 --- a/contrib/can/jumpstarter_driver_can/common.py +++ b/contrib/can/jumpstarter_driver_can/common.py @@ -57,6 +57,23 @@ class IsotpParams(BaseModel): listen_mode: bool = False blocking_send: bool = False + def apply(self, socket): + socket.set_opts( + optflag=None, + frame_txtime=None, + ext_address=None, + txpad=self.tx_padding, + rxpad=None, + rx_ext_address=None, + tx_stmin=None, + ) + socket.set_fc_opts(bs=self.blocksize, stmin=self.stmin, wftmax=self.wftmax) + socket.set_ll_opts( + mtu=isotp.socket.LinkLayerProtocol.CAN_FD if self.can_fd else isotp.socket.LinkLayerProtocol.CAN, + tx_dl=self.tx_data_length, + tx_flags=None, + ) + class IsotpMessage(BaseModel): data: Optional[Base64Bytes] diff --git a/contrib/can/jumpstarter_driver_can/driver.py b/contrib/can/jumpstarter_driver_can/driver.py index c1664a99..df7a7d3f 100644 --- a/contrib/can/jumpstarter_driver_can/driver.py +++ b/contrib/can/jumpstarter_driver_can/driver.py @@ -171,3 +171,74 @@ def stop_sending(self) -> None: @validate_call(validate_return=True) def stop_receiving(self) -> None: self.stack.stop_receiving() + + +@dataclass(kw_only=True, config=ConfigDict(arbitrary_types_allowed=True)) +class IsotpSocket(Driver): + channel: str + address: isotp.Address + params: IsotpParams = field(default_factory=IsotpParams) + + sock: isotp.socket | None = field(init=False, default=None) + + @classmethod + def client(cls) -> str: + return "jumpstarter_driver_can.client.IsotpClient" + + @export + @validate_call(validate_return=True) + def start(self) -> None: + if self.sock: + raise ValueError("socket already started") + self.sock = isotp.socket() + self.params.apply(self.sock) + self.sock.bind(self.channel, self.address) + + @export + @validate_call(validate_return=True) + def stop(self) -> None: + if not self.sock: + raise ValueError("socket not started") + self.sock.close() + self.sock = None + + @export + @validate_call(validate_return=True) + def send( + self, msg: IsotpMessage, target_address_type: int | None = None, send_timeout: float | None = None + ) -> None: + if not self.sock: + raise ValueError("socket not started") + self.sock.send(msg.data) + + @export + @validate_call(validate_return=True) + def recv(self, block: bool = False, timeout: float | None = None) -> IsotpMessage: + if not self.sock: + raise ValueError("socket not started") + return IsotpMessage.model_construct(data=self.sock.recv()) + + @export + @validate_call(validate_return=True) + def available(self) -> bool: + raise NotImplementedError + + @export + @validate_call(validate_return=True) + def transmitting(self) -> bool: + raise NotImplementedError + + @export + @validate_call(validate_return=True) + def set_address(self, address: IsotpAddress | IsotpAsymmetricAddress) -> None: + raise NotImplementedError + + @export + @validate_call(validate_return=True) + def stop_sending(self) -> None: + raise NotImplementedError + + @export + @validate_call(validate_return=True) + def stop_receiving(self) -> None: + raise NotImplementedError