Skip to content

Commit

Permalink
0.15.1: .catch: add 'when' param
Browse files Browse the repository at this point in the history
  • Loading branch information
ebonnal committed Jul 4, 2024
1 parent 13e3624 commit 6e5b3ae
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 8 deletions.
2 changes: 2 additions & 0 deletions streamable/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@
def catch(
iterator: Iterator[T],
kind: Type[Exception] = Exception,
when: Callable[[Exception], Any] = bool,
finally_raise: bool = False,
) -> Iterator[T]:
return CatchingIterator(
iterator,
kind,
when,
finally_raise=finally_raise,
)

Expand Down
4 changes: 3 additions & 1 deletion streamable/iters.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@ def __init__(
self,
iterator: Iterator[T],
kind: Type[Exception],
when: Callable[[Exception], Any],
finally_raise: bool,
) -> None:
self.iterator = iterator
self.kind = kind
self.when = when
self.finally_raise = finally_raise
self._to_be_finally_raised: Optional[Exception] = None

Expand All @@ -51,7 +53,7 @@ def __next__(self) -> T:
raise exception
raise
except Exception as exception:
if isinstance(exception, self.kind):
if isinstance(exception, self.kind) and self.when(exception):
if self._to_be_finally_raised is None:
self._to_be_finally_raised = exception
continue
Expand Down
14 changes: 8 additions & 6 deletions streamable/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,19 +105,21 @@ def accept(self, visitor: "Visitor[V]") -> V:
def catch(
self,
kind: Type[Exception] = Exception,
when: Callable[[Exception], Any] = bool,
finally_raise: bool = False,
) -> "Stream[T]":
"""
Catches the upstream exceptions if they are instances of `kind`.
Catches the upstream exceptions if they are instances of `kind` and they satisfy the `when` predicate.
Args:
kind (Type[Exception], optional): The type of exceptions to catch (default is all non-exit exceptions).
kind (Type[Exception], optional): The type of exceptions to catch (default is base Exception).
when (Callable[[Exception], Any], optional): An additional condition that must be satisfied (`when(exception)` must be Truthy) to catch the exception (Always satisfied by default).
finally_raise (bool, optional): If True the first catched exception is raised when upstream's iteration ends (default is False).
Returns:
Stream[T]: A stream of upstream elements catching the eligible exceptions.
"""
return CatchStream(self, kind, finally_raise)
return CatchStream(self, kind, when, finally_raise)

def filter(self, keep: Callable[[T], Any] = bool) -> "Stream[T]":
"""
Expand Down Expand Up @@ -369,19 +371,19 @@ def __init__(
self,
upstream: Stream[T],
kind: Type[Exception],
when: Callable[[Exception], Any],
finally_raise: bool,
) -> None:
super().__init__(upstream)
self.kind = kind
self.when = when
self.finally_raise = finally_raise

def accept(self, visitor: "Visitor[V]") -> V:
return visitor.visit_catch_stream(self)

def __repr__(self) -> str:
call = (
f"catch({friendly_string(self.kind)}, finally_raise={self.finally_raise})"
)
call = f"catch({friendly_string(self.kind)}, when={friendly_string(self.when)}, finally_raise={self.finally_raise})"
return f"{repr(self.upstream)}\\\n.{call}"


Expand Down
1 change: 1 addition & 0 deletions streamable/visitors/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def visit_catch_stream(self, stream: CatchStream[T]) -> Iterator[T]:
return functions.catch(
stream.upstream.accept(self),
stream.kind,
stream.when,
finally_raise=stream.finally_raise,
)

Expand Down
10 changes: 10 additions & 0 deletions tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,16 @@ def f(i):
):
next(iterator)

with self.assertRaises(
TypeError,
msg="`catch` does not catch if `when` not satisfied",
):
list(
Stream(map(throw, [ValueError, TypeError])).catch(
Exception, when=lambda exception: "ValueError" in repr(exception)
)
)

def test_observe(self) -> None:
value_error_rainsing_stream: Stream[List[int]] = (
Stream("123--567")
Expand Down
2 changes: 1 addition & 1 deletion version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.15.0"
__version__ = "0.15.1"

0 comments on commit 6e5b3ae

Please sign in to comment.