Skip to content

Commit

Permalink
mypy ok
Browse files Browse the repository at this point in the history
  • Loading branch information
ebonnal committed Dec 3, 2023
1 parent 03aa35c commit 7f8ccf3
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 34 deletions.
2 changes: 1 addition & 1 deletion kioss/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from kioss._plan import APipe as Pipe
from kioss._plan import APipe, SourcePipe as Pipe
from kioss._util import LOGGER
from kioss import _plan, _visitor
_plan.APipe.ITERATOR_GENERATING_VISITOR_CLASS = _visitor.IteratorGeneratingVisitor
4 changes: 0 additions & 4 deletions kioss/_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,6 @@ def register_error_sample(error):

return samples

@classmethod
def __rshift__(cls, source: Callable[[], Iterator[T]]) -> "SourcePipe[T]":
return SourcePipe(source)

class SourcePipe(APipe[T]):
def __init__(self, source: Callable[[], Iterator[T]]):
"""
Expand Down
56 changes: 27 additions & 29 deletions tests/test_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,18 @@
import unittest
from collections import Counter

from typing import List, TypeVar
from typing import Iterator, List, TypeVar

from parameterized import parameterized # type: ignore

from kioss import Pipe, _util
from kioss import APipe, Pipe, _util

TEN_MS = 0.01
DELTA = 0.35
T = TypeVar("T")


def timepipe(pipe: Pipe):
def timepipe(pipe: APipe):
def iterate():
for _ in pipe:
pass
Expand All @@ -34,13 +34,13 @@ def ten_ms_identity(x: T) -> T:


class TestPipe(unittest.TestCase):
def test_init(self):
def test_init(self) -> None:
# from iterable
self.assertListEqual(Pipe(range(8).__iter__).collect(), list(range(8)))
# from iterator
self.assertListEqual(Pipe(range(8).__iter__).collect(), list(range(8)))

def test_chain(self):
def test_chain(self) -> None:
# test that the order is preserved
self.assertListEqual(
Pipe(range(2).__iter__)
Expand Down Expand Up @@ -138,7 +138,7 @@ def test_flatten(self, n_threads: int):

# exceptions in the middle on flattening is well catched, potential recursion issue too
class RaisesStopIterationWhenCalledForIter:
def __iter__(self):
def __iter__(self) -> None:
raise StopIteration()

def raise_for_4(x):
Expand Down Expand Up @@ -188,7 +188,7 @@ def store_error_types(error):
Pipe(lambda: map(int, "-")).flatten(n_threads=n_threads).collect, # type: ignore
)

def test_add(self):
def test_add(self) -> None:
self.assertListEqual(
list(
sum(
Expand Down Expand Up @@ -231,7 +231,7 @@ def test_map(self, n_threads: int):
{1, 3},
)

def test_map_threading_bench(self):
def test_map_threading_bench(self) -> None:
# non-threaded vs threaded execution time
pipe = Pipe(range(N).__iter__).map(ten_ms_identity)
self.assertAlmostEqual(timepipe(pipe), TEN_MS * N, delta=DELTA * (TEN_MS * N))
Expand All @@ -243,7 +243,7 @@ def test_map_threading_bench(self):
delta=DELTA * (TEN_MS * N) / n_threads,
)

def test_do(self):
def test_do(self) -> None:
l: List[int] = []

func = lambda x: x**2
Expand All @@ -266,14 +266,14 @@ def func_with_side_effect(x):
)
self.assertSetEqual(set(l), set(map(func, args)))

def test_filter(self):
def test_filter(self) -> None:
self.assertListEqual(
list(Pipe(range(8).__iter__).filter(lambda x: x % 2)), [1, 3, 5, 7]
list(Pipe(range(8).__iter__).filter(lambda x: x % 2 != 0)), [1, 3, 5, 7]
)

self.assertListEqual(list(Pipe(range(8).__iter__).filter(lambda _: False)), [])

def test_batch(self):
def test_batch(self) -> None:
self.assertListEqual(
Pipe(range(8).__iter__).batch(size=3).collect(),
[[0, 1, 2], [3, 4, 5], [6, 7]],
Expand Down Expand Up @@ -311,11 +311,6 @@ def store_errors(error):
.map(int)
.batch(2)
.catch(ValueError, when=store_errors)
.map(
lambda potential_error: [potential_error]
if isinstance(potential_error, Exception)
else potential_error
)
.map(iter)
.flatten()
.map(type)
Expand All @@ -337,7 +332,7 @@ def test_slow(self, n_threads: int):
delta=DELTA * (1 / freq * N),
)

def test_collect(self):
def test_collect(self) -> None:
self.assertListEqual(
Pipe(range(8).__iter__).collect(n_samples=6), list(range(6))
)
Expand All @@ -351,7 +346,7 @@ def test_collect(self):
delta=DELTA * TEN_MS * 8,
)

def test_time(self):
def test_time(self) -> None:
new_pipe = lambda: Pipe(range(8).__iter__).slow(64)
start_time = time.time()
new_pipe().collect()
Expand Down Expand Up @@ -439,7 +434,7 @@ def store_errors(error):
.collect,
)

def test_superintend(self):
def test_superintend(self) -> None:
self.assertListEqual(
Pipe("123".__iter__).map(int).superintend(n_samples=2), [1, 2]
)
Expand All @@ -458,7 +453,7 @@ def test_superintend(self):
lambda: superintend(raise_if_more_errors_than=0),
)

def test_log(self):
def test_log(self) -> None:
self.assertListEqual(
Pipe("123".__iter__)
.log("chars")
Expand All @@ -470,7 +465,7 @@ def test_log(self):
[[1, 2], [3]],
)

def test_partial_iteration(self):
def test_partial_iteration(self) -> None:
first_elem = next(
iter(
Pipe(([0] * N).__iter__)
Expand Down Expand Up @@ -498,9 +493,9 @@ def test_partial_iteration(self):
samples = list(itertools.islice(pipe, n))
self.assertListEqual(samples, [0] * n)

def test_invalid_source(self):
self.assertRaises(TypeError, lambda: Pipe(range(3)))
pipe_ok_at_construction = Pipe(lambda: range(3))
def test_invalid_source(self) -> None:
self.assertRaises(TypeError, lambda: Pipe(range(3))) # type: ignore
pipe_ok_at_construction: Pipe[int] = Pipe(lambda: range(3)) # type: ignore
self.assertRaises(TypeError, lambda: pipe_ok_at_construction.collect())

@parameterized.expand([[1], [2], [3]])
Expand All @@ -509,7 +504,7 @@ def test_invalid_flatten_upstream(self, n_threads: int):
TypeError, Pipe(range(3).__iter__).flatten(n_threads=n_threads).collect # type: ignore
)

def test_planning_and_execution_decoupling(self):
def test_planning_and_execution_decoupling(self) -> None:
a = Pipe(range(N).__iter__)
b = a.batch(size=N)
# test double execution
Expand All @@ -518,12 +513,15 @@ def test_planning_and_execution_decoupling(self):
# test b not affected by a execution
self.assertListEqual(b.collect(), [list(range(N))])

def test_generator_already_generating(self):
def test_generator_already_generating(self) -> None:
l: List[Iterator[int]] = [iter((ten_ms_identity(x) for x in range(N))) for _ in range(3)]
self.assertEqual(
Counter(
Pipe(
[(ten_ms_identity(x) for x in range(N)) for _ in range(3)].__iter__
).flatten(n_threads=2)
l.__iter__
)
.map(iter)
.flatten(n_threads=2)
),
Counter(list(range(N)) + list(range(N)) + list(range(N))),
)

0 comments on commit 7f8ccf3

Please sign in to comment.