diff --git a/streamable/functions.py b/streamable/functions.py index 78e0da7..1a7d3db 100644 --- a/streamable/functions.py +++ b/streamable/functions.py @@ -36,11 +36,13 @@ def catch( iterator: Iterator[T], kind: Type[Exception] = Exception, + when: Callable[[Exception], Any] = bool, finally_raise: bool = False, ) -> Iterator[T]: return CatchingIterator( iterator, kind, + when, finally_raise=finally_raise, ) diff --git a/streamable/iters.py b/streamable/iters.py index 8ada7e7..f1bc71e 100644 --- a/streamable/iters.py +++ b/streamable/iters.py @@ -33,10 +33,12 @@ def __init__( self, iterator: Iterator[T], kind: Type[Exception], + when: Callable[[Exception], Any], finally_raise: bool, ) -> None: self.iterator = iterator self.kind = kind + self.when = when self.finally_raise = finally_raise self._to_be_finally_raised: Optional[Exception] = None @@ -51,7 +53,7 @@ def __next__(self) -> T: raise exception raise except Exception as exception: - if isinstance(exception, self.kind): + if isinstance(exception, self.kind) and self.when(exception): if self._to_be_finally_raised is None: self._to_be_finally_raised = exception continue diff --git a/streamable/stream.py b/streamable/stream.py index 513392d..272acf5 100644 --- a/streamable/stream.py +++ b/streamable/stream.py @@ -105,19 +105,21 @@ def accept(self, visitor: "Visitor[V]") -> V: def catch( self, kind: Type[Exception] = Exception, + when: Callable[[Exception], Any] = bool, finally_raise: bool = False, ) -> "Stream[T]": """ - Catches the upstream exceptions if they are instances of `kind`. + Catches the upstream exceptions if they are instances of `kind` and they satisfy the `when` predicate. Args: - kind (Type[Exception], optional): The type of exceptions to catch (default is all non-exit exceptions). + kind (Type[Exception], optional): The type of exceptions to catch (default is base Exception). + when (Callable[[Exception], Any], optional): An additional condition that must be satisfied (`when(exception)` must be Truthy) to catch the exception (Always satisfied by default). finally_raise (bool, optional): If True the first catched exception is raised when upstream's iteration ends (default is False). Returns: Stream[T]: A stream of upstream elements catching the eligible exceptions. """ - return CatchStream(self, kind, finally_raise) + return CatchStream(self, kind, when, finally_raise) def filter(self, keep: Callable[[T], Any] = bool) -> "Stream[T]": """ @@ -369,19 +371,19 @@ def __init__( self, upstream: Stream[T], kind: Type[Exception], + when: Callable[[Exception], Any], finally_raise: bool, ) -> None: super().__init__(upstream) self.kind = kind + self.when = when self.finally_raise = finally_raise def accept(self, visitor: "Visitor[V]") -> V: return visitor.visit_catch_stream(self) def __repr__(self) -> str: - call = ( - f"catch({friendly_string(self.kind)}, finally_raise={self.finally_raise})" - ) + call = f"catch({friendly_string(self.kind)}, when={friendly_string(self.when)}, finally_raise={self.finally_raise})" return f"{repr(self.upstream)}\\\n.{call}" diff --git a/streamable/visitors/iterator.py b/streamable/visitors/iterator.py index 44d5609..53131dd 100644 --- a/streamable/visitors/iterator.py +++ b/streamable/visitors/iterator.py @@ -26,6 +26,7 @@ def visit_catch_stream(self, stream: CatchStream[T]) -> Iterator[T]: return functions.catch( stream.upstream.accept(self), stream.kind, + stream.when, finally_raise=stream.finally_raise, ) diff --git a/tests/test_stream.py b/tests/test_stream.py index 9044f44..33461fc 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -972,6 +972,16 @@ def f(i): ): next(iterator) + with self.assertRaises( + TypeError, + msg="`catch` does not catch if `when` not satisfied", + ): + list( + Stream(map(throw, [ValueError, TypeError])).catch( + Exception, when=lambda exception: "ValueError" in repr(exception) + ) + ) + def test_observe(self) -> None: value_error_rainsing_stream: Stream[List[int]] = ( Stream("123--567") diff --git a/version.py b/version.py index 9da2f8f..903e77c 100644 --- a/version.py +++ b/version.py @@ -1 +1 @@ -__version__ = "0.15.0" +__version__ = "0.15.1"