Skip to content

Commit

Permalink
superintend: add raise_when_more_errors_than; catch: replace ignore b…
Browse files Browse the repository at this point in the history
…y when
  • Loading branch information
ebonnal committed Nov 30, 2023
1 parent 65aee0c commit e018b28
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 83 deletions.
14 changes: 9 additions & 5 deletions kioss/_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time
from datetime import datetime
from typing import (
Callable,
Iterator,
List,
Optional,
Expand Down Expand Up @@ -158,19 +159,22 @@ def __next__(self) -> List[T]:

class CatchingIteratorWrapper(IteratorWrapper[T]):
def __init__(
self, iterator: Iterator[T], classes: Tuple[Type[Exception]], ignore: bool
self,
iterator: Iterator[T],
classes: Tuple[Type[Exception]],
when: Optional[Callable[[Exception], bool]],
) -> None:
super().__init__(iterator)
self.classes = classes
self.ignore = ignore
self.when = when

def __next__(self) -> T:
try:
return next(self.iterator)
except StopIteration:
raise
except self.classes as e:
if self.ignore:
except self.classes as exception:
if self.when is None or self.when(exception):
return next(self) # TODO fix recursion issue
else:
return e
raise exception
63 changes: 42 additions & 21 deletions kioss/_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,18 +144,22 @@ def slow(self, freq: float) -> "APipe[T]":
"""
return SlowPipe[T](self, freq)

def catch(self, *classes: Type[Exception], ignore=False) -> "APipe[T]":
def catch(
self,
*classes: Type[Exception],
when: Optional[Callable[[Exception], bool]] = None,
) -> "APipe[T]":
"""
Any error whose class is exception_class or a subclass of it will be catched and yielded.
Args:
exception_class (Type[Exception]): The class of exceptions to catch
ingore (bool): If True then the encountered exception_class errors will be skipped.
classes (Type[Exception]): The class of exceptions to catch
when (Callable[[Exception], bool], optional): catches an exception whose type is in `classes` only if this predicate function is None or evaluates to True.
Returns:
Pipe[T]: A new Pipe instance with error handling capability.
"""
return CatchPipe[T](self, classes, ignore)
return CatchPipe[T](self, classes, when)

def log(self, what: str = "elements") -> "APipe[T]":
"""
Expand All @@ -181,40 +185,54 @@ def collect(self, n_samples: int = float("inf")) -> List[T]:
"""
return [elem for i, elem in enumerate(self) if i < n_samples]

def superintend(self, n_samples: int = 0, n_error_samples: int = 8) -> List[T]:
def superintend(
self,
n_samples: int = 0,
n_error_samples: int = 8,
raise_when_more_errors_than: int = 0,
) -> List[T]:
"""
Superintend the Pipe: iterate over the pipe until it is exhausted and raise a RuntimeError if any exceptions occur during iteration.
Superintend the Pipe:
- iterates over it until it is exhausted,
- logs
- catches exceptions log a sample of them at the end of the iteration
- raises the first encountered error if more exception than `raise_when_more_errors_than` are catched during iteration.
- else returns a sample of the output elements
Args:
n_samples (int, optional): The maximum number of elements to collect in the list (default is infinity).
n_error_samples (int, optional): The maximum number of error samples to log (default is 8).
raise_when_more_errors_than (int, optional): An error will be raised if the number of encountered errors is more than this threshold (default is 0).
Returns:
List[T]: A list containing the elements of the Pipe truncate to the first `n_samples` ones.
Raises:
RuntimeError: If any exception is catched during iteration.
RuntimeError: If more exception than `raise_when_more_errors_than` are catched during iteration.
"""
if not isinstance(self, LogPipe):
plan = self.log("output elements")
else:
plan = self
error_samples: List[Exception] = []
samples = (
plan.catch(Exception, ignore=False)
.do(
lambda elem: error_samples.append(elem)
if isinstance(elem, Exception) and len(error_samples) < n_error_samples
else None
)
.filter(lambda elem: not isinstance(elem, Exception))
.collect(n_samples=n_samples)
errors_count = 0

def register_error_sample(error):
nonlocal errors_count
errors_count += 1
if len(error_samples) < n_error_samples:
error_samples.append(error)
return True

samples = plan.catch(Exception, when=register_error_sample).collect(
n_samples=n_samples
)
if len(error_samples):
if errors_count > 0:
logging.error(
"first %s error samples: %s\nWill now raise the first of them:",
n_error_samples,
list(map(repr, error_samples)),
)
raise error_samples[0]
if raise_when_more_errors_than < errors_count:
raise error_samples[0]

return samples

Expand Down Expand Up @@ -321,15 +339,18 @@ def __iter__(self) -> Iterator[T]:

class CatchPipe(APipe[T]):
def __init__(
self, upstream: APipe[T], classes: Tuple[Type[Exception]], ignore: bool
self,
upstream: APipe[T],
classes: Tuple[Type[Exception]],
when: Optional[Callable[[Exception], bool]],
):
super().__init__(upstream)
self.classes = classes
self.ignore = ignore
self.when = when

def __iter__(self) -> Iterator[T]:
return _exec.CatchingIteratorWrapper(
iter(self.upstream), self.classes, self.ignore
iter(self.upstream), self.classes, self.when
)


Expand Down
114 changes: 57 additions & 57 deletions tests/test_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from kioss import Pipe, _util

TEN_MS = 0.01
DELTA = 0.3
DELTA = 0.35
T = TypeVar("T")


Expand Down Expand Up @@ -162,12 +162,19 @@ def raise_for_4(x):
.map(iter)
.flatten(n_threads=n_threads)
)
error_types = set()

def store_error_types(error):
error_types.add(type(error))
return True

set(get_pipe().catch(Exception, when=store_error_types))
self.assertSetEqual(
set(get_pipe().catch(Exception, ignore=False).map(type)),
{int, ValueError, TypeError, AssertionError, RuntimeError},
error_types,
{ValueError, TypeError, AssertionError, RuntimeError},
)
self.assertSetEqual(
set(get_pipe().catch(Exception, ignore=True)),
set(get_pipe().catch(Exception)),
set(range(7)),
)

Expand Down Expand Up @@ -206,7 +213,7 @@ def test_map(self, n_threads: int):
.map(ten_ms_identity, n_threads=n_threads)
.map(lambda x: x if 1 / x else x)
.map(func, n_threads=n_threads)
.catch(ZeroDivisionError, ignore=True)
.catch(ZeroDivisionError)
.map(
ten_ms_identity, n_threads=n_threads
) # check that the ZeroDivisionError is bypass the call to func
Expand All @@ -218,7 +225,7 @@ def test_map(self, n_threads: int):
Pipe([[1], [], [3]].__iter__)
.map(iter)
.map(next, n_threads=n_threads)
.catch(RuntimeError, ignore=True)
.catch(RuntimeError)
),
{1, 3},
)
Expand Down Expand Up @@ -285,26 +292,24 @@ def test_batch(self):
)
# assert batch gracefully yields if next elem throw exception
self.assertListEqual(
Pipe("01234-56789".__iter__)
.map(int)
.batch(2)
.catch(ValueError, ignore=True)
.collect(),
Pipe("01234-56789".__iter__).map(int).batch(2).catch(ValueError).collect(),
[[0, 1], [2, 3], [4], [5, 6], [7, 8], [9]],
)
self.assertListEqual(
Pipe("0123-56789".__iter__)
.map(int)
.batch(2)
.catch(ValueError, ignore=True)
.collect(),
Pipe("0123-56789".__iter__).map(int).batch(2).catch(ValueError).collect(),
[[0, 1], [2, 3], [5, 6], [7, 8], [9]],
)
errors = set()

def store_errors(error):
errors.add(error)
return True

self.assertListEqual(
Pipe("0123-56789".__iter__)
.map(int)
.batch(2)
.catch(ValueError, ignore=False)
.catch(ValueError, when=store_errors)
.map(
lambda potential_error: [potential_error]
if isinstance(potential_error, Exception)
Expand All @@ -314,8 +319,10 @@ def test_batch(self):
.flatten()
.map(type)
.collect(),
[int, int, int, int, ValueError, int, int, int, int, int],
[int, int, int, int, int, int, int, int, int],
)
self.assertEqual(len(errors), 1)
self.assertIsInstance(next(iter(errors)), ValueError)

@parameterized.expand([[1], [2], [3]])
def test_slow(self, n_threads: int):
Expand Down Expand Up @@ -355,57 +362,45 @@ def test_time(self):
@parameterized.expand([[1], [2], [3]])
def test_catch(self, n_threads: int):
# ignore = True
errors = set()

def store_errors(error):
errors.add(error)
return True

self.assertSetEqual(
set(
Pipe(["1", "r", "2"].__iter__)
.map(int, n_threads=n_threads)
.catch(Exception, ignore=False)
.catch(Exception, when=store_errors)
.map(type)
),
{int, ValueError, int},
{int},
)
self.assertEqual(len(errors), 1)
self.assertIsInstance(next(iter(errors)), ValueError)

# ignore = False
self.assertSetEqual(
set(
Pipe(["1", "r", "2"].__iter__)
.map(int, n_threads=n_threads)
.catch(Exception)
.map(type)
),
{int, ValueError, int},
)
self.assertSetEqual(
set(
self.assertListEqual(
list(
Pipe(["1", "r", "2"].__iter__)
.map(int, n_threads=n_threads)
.catch(ValueError)
.map(type)
),
{int, ValueError, int},
[int, int],
)
# chain catches
self.assertSetEqual(
set(
self.assertListEqual(
list(
Pipe(["1", "r", "2"].__iter__)
.map(int, n_threads=n_threads)
.catch(TypeError)
.catch(ValueError)
.catch(TypeError)
.map(type)
),
{int, ValueError, int},
)
self.assertDictEqual(
dict(
Counter(
Pipe(["1", "r", "2"].__iter__)
.map(int, n_threads=n_threads)
.catch(ValueError)
.map(type) # , n_threads=n_threads)
.collect()
)
),
dict(Counter([int, ValueError, int])),
[int, int],
)

# raises
Expand All @@ -427,12 +422,22 @@ def test_catch(self, n_threads: int):
)

def test_superintend(self):
self.assertListEqual(
Pipe("123".__iter__).map(int).superintend(n_samples=2), [1, 2]
)

# errors
superintend = Pipe("12-3".__iter__).map(int).superintend
self.assertRaises(
ValueError,
Pipe("12-3".__iter__).map(int).superintend,
superintend,
)
self.assertListEqual(
Pipe("123".__iter__).map(int).superintend(n_samples=2), [1, 2]
# does not raise with sufficient threshold
superintend(raise_when_more_errors_than=1)
# raise with insufficient threshold
self.assertRaises(
ValueError,
lambda: superintend(raise_when_more_errors_than=0),
)

def test_log(self):
Expand Down Expand Up @@ -482,13 +487,8 @@ def test_invalid_source(self):

@parameterized.expand([[1], [2], [3]])
def test_invalid_flatten_upstream(self, n_threads: int):
self.assertEqual(
Pipe(range(3).__iter__)
.flatten(n_threads=n_threads)
.catch(TypeError) # important to check potential infinite recursion
.map(type)
.collect(),
[TypeError] * 3,
self.assertRaises(
TypeError, Pipe(range(3).__iter__).flatten(n_threads=n_threads).collect
)

def test_planning_and_execution_decoupling(self):
Expand Down

0 comments on commit e018b28

Please sign in to comment.