Skip to content

Commit

Permalink
Support filtering the messages on a receiver (#303)
Browse files Browse the repository at this point in the history
A new `filter` method is added to the `Receiver` interface, which allows
the application of a filter function on the messages on a receiver.

Example:

```python
async for message in receiver.filter(lambda num: num % 2):
    print(f"An even number: {message}")
```
  • Loading branch information
shsms authored Jul 3, 2024
2 parents c7e6096 + 068b04e commit 779e807
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 2 deletions.
2 changes: 2 additions & 0 deletions RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

- **Experimental**: `Pipe`, which provides a pipe between two channels, by connecting a `Receiver` to a `Sender`.

- `Receiver`s now have a `filter` method that applies a filter function on the messages on a receiver.

## Bug Fixes

<!-- Here goes notable bug fixes that are worth a special mention or explanation -->
114 changes: 112 additions & 2 deletions src/frequenz/channels/_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,25 @@ def map(
"""
return _Mapper(receiver=self, mapping_function=mapping_function)

def filter(
self, filter_function: Callable[[ReceiverMessageT_co], bool], /
) -> Receiver[ReceiverMessageT_co]:
"""Apply a filter function on the messages on a receiver.
Tip:
The returned receiver type won't have all the methods of the original
receiver. If you need to access methods of the original receiver that are
not part of the `Receiver` interface you should save a reference to the
original receiver and use that instead.
Args:
filter_function: The function to be applied on incoming messages.
Returns:
A new receiver that applies the function on the received messages.
"""
return _Filter(receiver=self, filter_function=filter_function)


class ReceiverError(Error, Generic[ReceiverMessageT_co]):
"""An error that originated in a [Receiver][frequenz.channels.Receiver].
Expand Down Expand Up @@ -336,9 +355,100 @@ def consume(self) -> MappedMessageT_co: # noqa: DOC502
) # pylint: disable=protected-access

def __str__(self) -> str:
"""Return a string representation of the timer."""
"""Return a string representation of the mapper."""
return f"{type(self).__name__}:{self._receiver}:{self._mapping_function}"

def __repr__(self) -> str:
"""Return a string representation of the timer."""
"""Return a string representation of the mapper."""
return f"{type(self).__name__}({self._receiver!r}, {self._mapping_function!r})"


class _Sentinel:
"""A sentinel object to represent no value received yet."""

def __str__(self) -> str:
"""Return a string representation of this sentinel."""
return "<No message ready to be consumed>"

def __repr__(self) -> str:
"""Return a string representation of this sentinel."""
return "<No message ready to be consumed>"


_SENTINEL = _Sentinel()


class _Filter(Receiver[ReceiverMessageT_co], Generic[ReceiverMessageT_co]):
"""Apply a filter function on the messages on a receiver."""

def __init__(
self,
*,
receiver: Receiver[ReceiverMessageT_co],
filter_function: Callable[[ReceiverMessageT_co], bool],
) -> None:
"""Initialize this receiver filter.
Args:
receiver: The input receiver.
filter_function: The function to apply on the input data.
"""
self._receiver: Receiver[ReceiverMessageT_co] = receiver
"""The input receiver."""

self._filter_function: Callable[[ReceiverMessageT_co], bool] = filter_function
"""The function to apply on the input data."""

self._next_message: ReceiverMessageT_co | _Sentinel = _SENTINEL

self._recv_closed = False

async def ready(self) -> bool:
"""Wait until the receiver is ready with a message or an error.
Once a call to `ready()` has finished, the message should be read with
a call to `consume()` (`receive()` or iterated over). The receiver will
remain ready (this method will return immediately) until it is
consumed.
Returns:
Whether the receiver is still active.
"""
while await self._receiver.ready():
message = self._receiver.consume()
if self._filter_function(message):
self._next_message = message
return True
self._recv_closed = True
return False

def consume(self) -> ReceiverMessageT_co:
"""Return a transformed message once `ready()` is complete.
Returns:
The next message that was received.
Raises:
ReceiverStoppedError: If the receiver stopped producing messages.
ReceiverError: If there is a problem with the receiver.
"""
if self._recv_closed:
raise ReceiverStoppedError(self)
assert not isinstance(
self._next_message, _Sentinel
), "`consume()` must be preceded by a call to `ready()`"

message = self._next_message
self._next_message = _SENTINEL
return message

def __str__(self) -> str:
"""Return a string representation of the filter."""
return f"{type(self).__name__}:{self._receiver}:{self._filter_function}"

def __repr__(self) -> str:
"""Return a string representation of the filter."""
return (
f"<{type(self).__name__} receiver={self._receiver!r} "
f"filter={self._filter_function!r} next_message={self._next_message!r}>"
)
17 changes: 17 additions & 0 deletions tests/test_anycast.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,20 @@ async def test_anycast_map() -> None:

assert (await receiver.receive()) is False
assert (await receiver.receive()) is True


async def test_anycast_filter() -> None:
"""Ensure filter keeps only the messages that pass the filter."""
chan = Anycast[int](name="input-chan")
sender = chan.new_sender()

# filter out all numbers less than 10.
receiver: Receiver[int] = chan.new_receiver().filter(lambda num: num > 10)

await sender.send(8)
await sender.send(12)
await sender.send(5)
await sender.send(15)

assert (await receiver.receive()) == 12
assert (await receiver.receive()) == 15
17 changes: 17 additions & 0 deletions tests/test_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,23 @@ async def test_broadcast_map() -> None:
assert (await receiver.receive()) is True


async def test_broadcast_filter() -> None:
"""Ensure filter keeps only the messages that pass the filter."""
chan = Broadcast[int](name="input-chan")
sender = chan.new_sender()

# filter out all numbers less than 10.
receiver: Receiver[int] = chan.new_receiver().filter(lambda num: num > 10)

await sender.send(8)
await sender.send(12)
await sender.send(5)
await sender.send(15)

assert (await receiver.receive()) == 12
assert (await receiver.receive()) == 15


async def test_broadcast_receiver_drop() -> None:
"""Ensure deleted receivers get cleaned up."""
chan = Broadcast[int](name="input-chan")
Expand Down

0 comments on commit 779e807

Please sign in to comment.