Skip to content

Commit

Permalink
0.4.5: fix flatten infinite recursion when TypeError catched
Browse files Browse the repository at this point in the history
  • Loading branch information
ebonnal committed Oct 18, 2023
1 parent 0f9c02f commit 26304f3
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
20 changes: 14 additions & 6 deletions kioss/_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,16 +172,24 @@ def __init__(self, iterator: Iterator[Iterator[R]]) -> None:
super().__init__(iterator)
self.current_iterator_elem = iter([])

@staticmethod
def _sanitize_input(expected_iterator_elem):
if not isinstance(expected_iterator_elem, Iterator):
raise TypeError(
f"Flattened elements must be iterators, but got {type(expected_iterator_elem)}"
)
return expected_iterator_elem

def __next__(self) -> R:
try:
return next(self.current_iterator_elem)
return next(
FlatteningIteratorWrapper._sanitize_input(self.current_iterator_elem)
)
except StopIteration:
while True:
self.current_iterator_elem = super().__next__()
if not isinstance(self.current_iterator_elem, Iterator):
raise TypeError(
f"Flattened elements must be iterators, but got {type(self.current_iterator_elem)}"
)
self.current_iterator_elem = FlatteningIteratorWrapper._sanitize_input(
super().__next__()
)
try:
return next(self.current_iterator_elem)
except StopIteration:
Expand Down
12 changes: 10 additions & 2 deletions tests/test_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,8 +472,16 @@ def test_partial_iteration(self):
def test_invalid_source(self):
self.assertRaises(TypeError, lambda: Pipe(range(3)))

def test_invalid_flatten_upstream(self):
self.assertRaises(TypeError, Pipe(lambda: range(3)).flatten().collect)
@parameterized.expand([[1], [2], [3]])
def test_invalid_flatten_upstream(self, n_threads: int):
self.assertEqual(
Pipe(lambda: range(3))
.flatten(n_threads=n_threads)
.catch(TypeError) # important to check potential infinite recursion
.map(type)
.collect(),
[TypeError] * 3,
)

def test_planning_and_execution_decoupling(self):
a = Pipe(lambda: iter(range(N)))
Expand Down

0 comments on commit 26304f3

Please sign in to comment.