Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support filtering the messages on a receiver #303

Merged
merged 3 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

- **Experimental**: `Pipe`, which provides a pipe between two channels, by connecting a `Receiver` to a `Sender`.

- `Receiver`s now have a `filter` method that applies a filter function on the messages on a receiver.

## Bug Fixes

<!-- Here goes notable bug fixes that are worth a special mention or explanation -->
114 changes: 112 additions & 2 deletions src/frequenz/channels/_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,25 @@ def map(
"""
return _Mapper(receiver=self, mapping_function=mapping_function)

def filter(
self, filter_function: Callable[[ReceiverMessageT_co], bool], /
) -> Receiver[ReceiverMessageT_co]:
"""Apply a filter function on the messages on a receiver.

Tip:
The returned receiver type won't have all the methods of the original
receiver. If you need to access methods of the original receiver that are
not part of the `Receiver` interface you should save a reference to the
original receiver and use that instead.

Args:
filter_function: The function to be applied on incoming messages.

Returns:
A new receiver that applies the function on the received messages.
"""
return _Filter(receiver=self, filter_function=filter_function)


class ReceiverError(Error, Generic[ReceiverMessageT_co]):
"""An error that originated in a [Receiver][frequenz.channels.Receiver].
Expand Down Expand Up @@ -336,9 +355,100 @@ def consume(self) -> MappedMessageT_co: # noqa: DOC502
) # pylint: disable=protected-access

def __str__(self) -> str:
"""Return a string representation of the timer."""
"""Return a string representation of the mapper."""
return f"{type(self).__name__}:{self._receiver}:{self._mapping_function}"

def __repr__(self) -> str:
"""Return a string representation of the timer."""
"""Return a string representation of the mapper."""
return f"{type(self).__name__}({self._receiver!r}, {self._mapping_function!r})"


class _Sentinel:
"""A sentinel object to represent no value received yet."""

def __str__(self) -> str:
"""Return a string representation of this sentinel."""
return "<No message ready to be consumed>"

def __repr__(self) -> str:
"""Return a string representation of this sentinel."""
return "<No message ready to be consumed>"


_SENTINEL = _Sentinel()


class _Filter(Receiver[ReceiverMessageT_co], Generic[ReceiverMessageT_co]):
"""Apply a filter function on the messages on a receiver."""

def __init__(
self,
*,
receiver: Receiver[ReceiverMessageT_co],
filter_function: Callable[[ReceiverMessageT_co], bool],
) -> None:
"""Initialize this receiver filter.

Args:
receiver: The input receiver.
filter_function: The function to apply on the input data.
"""
self._receiver: Receiver[ReceiverMessageT_co] = receiver
"""The input receiver."""

self._filter_function: Callable[[ReceiverMessageT_co], bool] = filter_function
"""The function to apply on the input data."""

self._next_message: ReceiverMessageT_co | _Sentinel = _SENTINEL

self._recv_closed = False

async def ready(self) -> bool:
"""Wait until the receiver is ready with a message or an error.

Once a call to `ready()` has finished, the message should be read with
a call to `consume()` (`receive()` or iterated over). The receiver will
remain ready (this method will return immediately) until it is
consumed.

Returns:
Whether the receiver is still active.
"""
while await self._receiver.ready():
message = self._receiver.consume()
if self._filter_function(message):
self._next_message = message
return True
self._recv_closed = True
return False

def consume(self) -> ReceiverMessageT_co:
"""Return a transformed message once `ready()` is complete.

Returns:
The next message that was received.

Raises:
ReceiverStoppedError: If the receiver stopped producing messages.
ReceiverError: If there is a problem with the receiver.
"""
if self._recv_closed:
raise ReceiverStoppedError(self)
assert not isinstance(
self._next_message, _Sentinel
), "`consume()` must be preceded by a call to `ready()`"

message = self._next_message
self._next_message = _SENTINEL
return message
llucax marked this conversation as resolved.
Show resolved Hide resolved

def __str__(self) -> str:
"""Return a string representation of the filter."""
return f"{type(self).__name__}:{self._receiver}:{self._filter_function}"
llucax marked this conversation as resolved.
Show resolved Hide resolved

def __repr__(self) -> str:
"""Return a string representation of the filter."""
return (
f"<{type(self).__name__} receiver={self._receiver!r} "
f"filter={self._filter_function!r} next_message={self._next_message!r}>"
)
17 changes: 17 additions & 0 deletions tests/test_anycast.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,20 @@ async def test_anycast_map() -> None:

assert (await receiver.receive()) is False
assert (await receiver.receive()) is True


async def test_anycast_filter() -> None:
"""Ensure filter keeps only the messages that pass the filter."""
chan = Anycast[int](name="input-chan")
sender = chan.new_sender()

# filter out all numbers less than 10.
receiver: Receiver[int] = chan.new_receiver().filter(lambda num: num > 10)

await sender.send(8)
await sender.send(12)
await sender.send(5)
await sender.send(15)

assert (await receiver.receive()) == 12
assert (await receiver.receive()) == 15
17 changes: 17 additions & 0 deletions tests/test_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,23 @@ async def test_broadcast_map() -> None:
assert (await receiver.receive()) is True


async def test_broadcast_filter() -> None:
"""Ensure filter keeps only the messages that pass the filter."""
chan = Broadcast[int](name="input-chan")
sender = chan.new_sender()

# filter out all numbers less than 10.
receiver: Receiver[int] = chan.new_receiver().filter(lambda num: num > 10)

await sender.send(8)
await sender.send(12)
await sender.send(5)
await sender.send(15)

assert (await receiver.receive()) == 12
assert (await receiver.receive()) == 15


async def test_broadcast_receiver_drop() -> None:
"""Ensure deleted receivers get cleaned up."""
chan = Broadcast[int](name="input-chan")
Expand Down