Skip to content

Commit

Permalink
v1.3.0: .map/.foreach: set executor type `via: Literal["thread", …
Browse files Browse the repository at this point in the history
…"process"]` (#33)

* `.map`/`.foreach`: set executor type `via: Literal["thread", "process"]`

* validationtools: add return type hints

* 1.3.0: `.map`/`.foreach`: set executor type `via: Literal["thread", "process"]`
  • Loading branch information
ebonnal authored Oct 16, 2024
1 parent 6aba16c commit f1077bb
Show file tree
Hide file tree
Showing 10 changed files with 77 additions and 47 deletions.
10 changes: 3 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,16 +130,12 @@ assert list(pokemon_names) == ['bulbasaur', 'ivysaur', 'venusaur']

### process-based concurrency

> Set `via_processes`:
> Set `via="process"`:
```python
if __name__ == "__main__":
state: List[int] = []
n_integers: int = (
integers
.map(state.append, concurrency=4, via_processes=True)
.count()
)
n_integers: int = integers.map(state.append, concurrency=4, via="process").count()
assert n_integers == 10
assert state == [] # main process's state not mutated
```
Expand Down Expand Up @@ -199,7 +195,7 @@ assert list(self_printing_integers) == list(integers) # triggers the printing
### process-based concurrency

> Like `.map` it has an optional `via_processes` parameter.
> Like for `.map`, set the parameter `via="process"`.
### async-based concurrency

Expand Down
8 changes: 6 additions & 2 deletions streamable/functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import builtins
import datetime
from contextlib import suppress
from typing import (
Any,
Callable,
Expand All @@ -13,6 +14,9 @@
cast,
)

with suppress(ImportError):
from typing import Literal

from streamable.iters import (
AsyncConcurrentMappingIterable,
ByKeyGroupingIterator,
Expand Down Expand Up @@ -105,7 +109,7 @@ def map(
iterator: Iterator[T],
concurrency: int = 1,
ordered: bool = True,
via_processes: bool = False,
via: "Literal['thread', 'process']" = "thread",
) -> Iterator[U]:
validate_iterator(iterator)
validate_concurrency(concurrency)
Expand All @@ -123,7 +127,7 @@ def map(
concurrency=concurrency,
buffer_size=concurrency,
ordered=ordered,
via_processes=via_processes,
via=via,
)
)
)
Expand Down
13 changes: 8 additions & 5 deletions streamable/iters.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
cast,
)

with suppress(ImportError):
from typing import Literal

from streamable.util.functiontools import catch_and_raise_as

T = TypeVar("T")
Expand Down Expand Up @@ -417,19 +420,19 @@ def __init__(
concurrency: int,
buffer_size: int,
ordered: bool,
via_processes: bool,
via: "Literal['thread', 'process']",
) -> None:
super().__init__(iterator, buffer_size, ordered)
self.transformation = transformation
self.concurrency = concurrency
self.executor: Executor
self.via_processes = via_processes
self.via = via

def _context_manager(self) -> ContextManager:
if self.via_processes:
self.executor = ProcessPoolExecutor(max_workers=self.concurrency)
else:
if self.via == "thread":
self.executor = ThreadPoolExecutor(max_workers=self.concurrency)
if self.via == "process":
self.executor = ProcessPoolExecutor(max_workers=self.concurrency)
return self.executor

# picklable
Expand Down
27 changes: 17 additions & 10 deletions streamable/stream.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import logging
from contextlib import suppress
from multiprocessing import get_logger
from typing import (
TYPE_CHECKING,
Expand All @@ -21,6 +22,9 @@
overload,
)

with suppress(ImportError):
from typing import Literal

from streamable.util.constants import NO_REPLACEMENT
from streamable.util.validationtools import (
validate_concurrency,
Expand All @@ -29,6 +33,7 @@
validate_throttle_interval,
validate_throttle_per_period,
validate_truncate_args,
validate_via,
)

# fmt: off
Expand Down Expand Up @@ -235,7 +240,7 @@ def foreach(
effect: Callable[[T], Any],
concurrency: int = 1,
ordered: bool = True,
via_processes: bool = False,
via: "Literal['thread', 'process']" = "thread",
) -> "Stream[T]":
"""
For each upstream element, yields it after having called `effect` on it.
Expand All @@ -245,12 +250,13 @@ def foreach(
effect (Callable[[T], Any]): The function to be applied to each element as a side effect.
concurrency (int): Represents both the number of threads used to concurrently apply the `effect` and the size of the buffer containing not-yet-yielded elements. If the buffer is full, the iteration over the upstream is stopped until some elements are yielded out of the buffer. (default is 1, meaning no multithreading).
ordered (bool): If `concurrency` > 1, whether to preserve the order of upstream elements or to yield them as soon as they are processed (default preserves order).
via_processes (bool): If `concurrency` > 1, applies `effect` concurrently using processes instead of threads (default is threads).
via ("thread" or "process"): If `concurrency` > 1, whether to apply `transformation` using processes or threads (default via threads).
Returns:
Stream[T]: A stream of upstream elements, unchanged.
"""
validate_concurrency(concurrency)
return ForeachStream(self, effect, concurrency, ordered, via_processes)
validate_via(via)
return ForeachStream(self, effect, concurrency, ordered, via)

def aforeach(
self,
Expand Down Expand Up @@ -302,7 +308,7 @@ def map(
transformation: Callable[[T], U],
concurrency: int = 1,
ordered: bool = True,
via_processes: bool = False,
via: "Literal['thread', 'process']" = "thread",
) -> "Stream[U]":
"""
Applies `transformation` on upstream elements and yields the results.
Expand All @@ -311,12 +317,13 @@ def map(
transformation (Callable[[T], R]): The function to be applied to each element.
concurrency (int): Represents both the number of threads used to concurrently apply `transformation` and the size of the buffer containing not-yet-yielded results. If the buffer is full, the iteration over the upstream is stopped until some results are yielded out of the buffer. (default is 1, meaning no multithreading).
ordered (bool): If `concurrency` > 1, whether to preserve the order of upstream elements or to yield them as soon as they are processed (default preserves order).
via_processes (bool): If `concurrency` > 1, applies `transformation` concurrently using processes instead of threads (default via threads).
via ("thread" or "process"): If `concurrency` > 1, whether to apply `transformation` using processes or threads (default via threads).
Returns:
Stream[R]: A stream of transformed elements.
"""
validate_concurrency(concurrency)
return MapStream(self, transformation, concurrency, ordered, via_processes)
validate_via(via)
return MapStream(self, transformation, concurrency, ordered, via)

def amap(
self,
Expand Down Expand Up @@ -464,13 +471,13 @@ def __init__(
effect: Callable[[T], Any],
concurrency: int,
ordered: bool,
via_processes: bool,
via: "Literal['thread', 'process']",
) -> None:
super().__init__(upstream)
self._effect = effect
self._concurrency = concurrency
self._ordered = ordered
self._via_processes = via_processes
self._via = via

def accept(self, visitor: "Visitor[V]") -> V:
return visitor.visit_foreach_stream(self)
Expand Down Expand Up @@ -517,13 +524,13 @@ def __init__(
transformation: Callable[[T], U],
concurrency: int,
ordered: bool,
via_processes: bool,
via: "Literal['thread', 'process']",
) -> None:
super().__init__(upstream)
self._transformation = transformation
self._concurrency = concurrency
self._ordered = ordered
self._via_processes = via_processes
self._via = via

def accept(self, visitor: "Visitor[V]") -> V:
return visitor.visit_map_stream(self)
Expand Down
17 changes: 11 additions & 6 deletions streamable/util/validationtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,36 @@ def validate_iterator(iterator: Iterator):
)


def validate_concurrency(concurrency: int):
def validate_concurrency(concurrency: int) -> None:
if concurrency < 1:
raise ValueError(
f"`concurrency` should be greater or equal to 1, but got {concurrency}."
)


def validate_group_size(size: Optional[int]):
def validate_via(via: str) -> None:
if via not in ["thread", "process"]:
raise TypeError(f"`via` should be 'thread' or 'process', but got {repr(via)}.")


def validate_group_size(size: Optional[int]) -> None:
if size is not None and size < 1:
raise ValueError(f"`size` should be None or >= 1 but got {size}.")


def validate_group_interval(interval: Optional[datetime.timedelta]):
def validate_group_interval(interval: Optional[datetime.timedelta]) -> None:
if interval is not None and interval <= datetime.timedelta(0):
raise ValueError(f"`interval` should be positive but got {repr(interval)}.")


def validate_throttle_per_period(per_period_arg_name: str, value: int):
def validate_throttle_per_period(per_period_arg_name: str, value: int) -> None:
if value < 1:
raise ValueError(
f"`{per_period_arg_name}` is the maximum number of elements to yield {' '.join(per_period_arg_name.split('_'))}, it must be >= 1 but got {value}."
)


def validate_throttle_interval(interval: datetime.timedelta):
def validate_throttle_interval(interval: datetime.timedelta) -> None:
if interval < datetime.timedelta(0):
raise ValueError(
f"`interval` is the minimum span of time between yields, it must not be negative but got {repr(interval)}."
Expand All @@ -45,7 +50,7 @@ def validate_throttle_interval(interval: datetime.timedelta):

def validate_truncate_args(
count: Optional[int] = None, when: Optional[Callable[[T], Any]] = None
):
) -> None:
if count is None:
if when is None:
raise ValueError(f"`count` and `when` can't be both None.")
Expand Down
4 changes: 2 additions & 2 deletions streamable/visitors/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def visit_foreach_stream(self, stream: ForeachStream[T]) -> Iterator[T]:
sidify(stream._effect),
stream._concurrency,
stream._ordered,
stream._via_processes,
stream._via,
)
)

Expand Down Expand Up @@ -84,7 +84,7 @@ def visit_map_stream(self, stream: MapStream[U, T]) -> Iterator[T]:
stream.upstream.accept(IteratorVisitor[U]()),
concurrency=stream._concurrency,
ordered=stream._ordered,
via_processes=stream._via_processes,
via=stream._via,
)

def visit_amap_stream(self, stream: AMapStream[U, T]) -> Iterator[T]:
Expand Down
4 changes: 2 additions & 2 deletions streamable/visitors/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def visit_flatten_stream(self, stream: FlattenStream[T]) -> str:

def visit_foreach_stream(self, stream: ForeachStream[T]) -> str:
self.methods_reprs.append(
f"foreach({self.to_string(stream._effect)}, concurrency={self.to_string(stream._concurrency)}, ordered={self.to_string(stream._ordered)}, via_processes={self.to_string(stream._via_processes)})"
f"foreach({self.to_string(stream._effect)}, concurrency={self.to_string(stream._concurrency)}, ordered={self.to_string(stream._ordered)}, via={self.to_string(stream._via)})"
)
return stream.upstream.accept(self)

Expand All @@ -67,7 +67,7 @@ def visit_group_stream(self, stream: GroupStream[U]) -> str:

def visit_map_stream(self, stream: MapStream[U, T]) -> str:
self.methods_reprs.append(
f"map({self.to_string(stream._transformation)}, concurrency={self.to_string(stream._concurrency)}, ordered={self.to_string(stream._ordered)}, via_processes={self.to_string(stream._via_processes)})"
f"map({self.to_string(stream._transformation)}, concurrency={self.to_string(stream._concurrency)}, ordered={self.to_string(stream._ordered)}, via={self.to_string(stream._via)})"
)
return stream.upstream.accept(self)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_readme.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_map_example(self) -> None:
def test_process_concurrent_map_example(self) -> None:
state: List[int] = []
n_integers: int = integers.map(
state.append, concurrency=4, via_processes=True
state.append, concurrency=4, via="process"
).count()
assert n_integers == 10
assert state == [] # main process's state not mutated
Expand Down
37 changes: 26 additions & 11 deletions tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,21 +242,21 @@ class CustomCallable:
self.assertEqual(
"""(
Stream(range(0, 256))
.map(<lambda>, concurrency=1, ordered=True, via_processes=True)
.map(<lambda>, concurrency=1, ordered=True, via='process')
)""",
str(Stream(src).map(lambda _: _, via_processes=True)),
str(Stream(src).map(lambda _: _, via="process")),
msg="`repr` should work as expected on a stream with 1 operation",
)
self.assertEqual(
"""(
Stream(range(0, 256))
.truncate(count=1024, when=<lambda>)
.filter(bool)
.map(<lambda>, concurrency=1, ordered=True, via_processes=False)
.map(<lambda>, concurrency=1, ordered=True, via='thread')
.filter(star(bool))
.foreach(<lambda>, concurrency=1, ordered=True, via_processes=False)
.foreach(<lambda>, concurrency=1, ordered=True, via='thread')
.aforeach(async_identity, concurrency=1, ordered=True)
.map(CustomCallable(...), concurrency=1, ordered=True, via_processes=False)
.map(CustomCallable(...), concurrency=1, ordered=True, via='thread')
.amap(async_identity, concurrency=1, ordered=True)
.group(size=100, by=None, interval=None)
.observe('groups')
Expand Down Expand Up @@ -322,12 +322,13 @@ def test_sanitize_concurrency(self, method, args) -> None:
stream = Stream(src)
with self.assertRaises(
TypeError,
msg=f"{method} should be raising TypeError for non-int concurrency.",
msg=f"`{method}` should be raising TypeError for non-int concurrency.",
):
method(stream, *args, concurrency="1")

with self.assertRaises(
ValueError, msg=f"{method} should be raising ValueError for concurrency=0."
ValueError,
msg=f"`{method}` should be raising ValueError for concurrency=0.",
):
method(stream, *args, concurrency=0)

Expand All @@ -338,6 +339,20 @@ def test_sanitize_concurrency(self, method, args) -> None:
msg=f"It must be ok to call {method} with concurrency={concurrency}.",
)

@parameterized.expand(
[
(Stream.map,),
(Stream.foreach,),
]
)
def test_sanitize_via(self, method) -> None:
with self.assertRaisesRegex(
TypeError,
"`via` should be 'thread' or 'process', but got 'foo'.",
msg=f"`{method}` must raise a TypeError for invalid via",
):
method(Stream(src), identity, via="foo")

@parameterized.expand(
[
[1],
Expand Down Expand Up @@ -376,16 +391,16 @@ def local_identity(x):
"Can't pickle",
msg="process-based concurrency should not be able to serialize a lambda or a local func",
):
list(Stream(src).map(f, concurrency=2, via_processes=True))
list(Stream(src).map(f, concurrency=2, via="process"))

sleeps = [0.01, 1, 0.01]
state: List[str] = []
expected_result_list: List[str] = list(order_mutation(map(str, sleeps)))
stream = (
Stream(sleeps)
.foreach(identity_sleep, concurrency=2, ordered=ordered, via_processes=True)
.map(str, concurrency=2, ordered=True, via_processes=True)
.foreach(state.append, concurrency=2, ordered=True, via_processes=True)
.foreach(identity_sleep, concurrency=2, ordered=ordered, via="process")
.map(str, concurrency=2, ordered=True, via="process")
.foreach(state.append, concurrency=2, ordered=True, via="process")
.foreach(lambda _: state.append(""), concurrency=1, ordered=True)
)
self.assertListEqual(
Expand Down
2 changes: 1 addition & 1 deletion version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# to show the CHANGELOG: git log -- version.py
__version__ = "1.2.2"
__version__ = "1.3.0-rc"

0 comments on commit f1077bb

Please sign in to comment.