diff --git a/README.md b/README.md index 0377871..544468f 100644 --- a/README.md +++ b/README.md @@ -130,16 +130,12 @@ assert list(pokemon_names) == ['bulbasaur', 'ivysaur', 'venusaur'] ### process-based concurrency -> Set `via_processes`: +> Set `via="process"`: ```python if __name__ == "__main__": state: List[int] = [] - n_integers: int = ( - integers - .map(state.append, concurrency=4, via_processes=True) - .count() - ) + n_integers: int = integers.map(state.append, concurrency=4, via="process").count() assert n_integers == 10 assert state == [] # main process's state not mutated ``` @@ -199,7 +195,7 @@ assert list(self_printing_integers) == list(integers) # triggers the printing ### process-based concurrency -> Like `.map` it has an optional `via_processes` parameter. +> Like for `.map`, set the parameter `via="process"`. ### async-based concurrency diff --git a/streamable/functions.py b/streamable/functions.py index 8b4a835..ea9d351 100644 --- a/streamable/functions.py +++ b/streamable/functions.py @@ -1,5 +1,6 @@ import builtins import datetime +from contextlib import suppress from typing import ( Any, Callable, @@ -13,6 +14,9 @@ cast, ) +with suppress(ImportError): + from typing import Literal + from streamable.iters import ( AsyncConcurrentMappingIterable, ByKeyGroupingIterator, @@ -105,7 +109,7 @@ def map( iterator: Iterator[T], concurrency: int = 1, ordered: bool = True, - via_processes: bool = False, + via: "Literal['thread', 'process']" = "thread", ) -> Iterator[U]: validate_iterator(iterator) validate_concurrency(concurrency) @@ -123,7 +127,7 @@ def map( concurrency=concurrency, buffer_size=concurrency, ordered=ordered, - via_processes=via_processes, + via=via, ) ) ) diff --git a/streamable/iters.py b/streamable/iters.py index 596ef93..7540215 100644 --- a/streamable/iters.py +++ b/streamable/iters.py @@ -27,6 +27,9 @@ cast, ) +with suppress(ImportError): + from typing import Literal + from streamable.util.functiontools import catch_and_raise_as T = TypeVar("T") @@ -417,19 +420,19 @@ def __init__( concurrency: int, buffer_size: int, ordered: bool, - via_processes: bool, + via: "Literal['thread', 'process']", ) -> None: super().__init__(iterator, buffer_size, ordered) self.transformation = transformation self.concurrency = concurrency self.executor: Executor - self.via_processes = via_processes + self.via = via def _context_manager(self) -> ContextManager: - if self.via_processes: - self.executor = ProcessPoolExecutor(max_workers=self.concurrency) - else: + if self.via == "thread": self.executor = ThreadPoolExecutor(max_workers=self.concurrency) + if self.via == "process": + self.executor = ProcessPoolExecutor(max_workers=self.concurrency) return self.executor # picklable diff --git a/streamable/stream.py b/streamable/stream.py index df0eeba..ac9619a 100644 --- a/streamable/stream.py +++ b/streamable/stream.py @@ -1,5 +1,6 @@ import datetime import logging +from contextlib import suppress from multiprocessing import get_logger from typing import ( TYPE_CHECKING, @@ -21,6 +22,9 @@ overload, ) +with suppress(ImportError): + from typing import Literal + from streamable.util.constants import NO_REPLACEMENT from streamable.util.validationtools import ( validate_concurrency, @@ -29,6 +33,7 @@ validate_throttle_interval, validate_throttle_per_period, validate_truncate_args, + validate_via, ) # fmt: off @@ -235,7 +240,7 @@ def foreach( effect: Callable[[T], Any], concurrency: int = 1, ordered: bool = True, - via_processes: bool = False, + via: "Literal['thread', 'process']" = "thread", ) -> "Stream[T]": """ For each upstream element, yields it after having called `effect` on it. @@ -245,12 +250,13 @@ def foreach( effect (Callable[[T], Any]): The function to be applied to each element as a side effect. concurrency (int): Represents both the number of threads used to concurrently apply the `effect` and the size of the buffer containing not-yet-yielded elements. If the buffer is full, the iteration over the upstream is stopped until some elements are yielded out of the buffer. (default is 1, meaning no multithreading). ordered (bool): If `concurrency` > 1, whether to preserve the order of upstream elements or to yield them as soon as they are processed (default preserves order). - via_processes (bool): If `concurrency` > 1, applies `effect` concurrently using processes instead of threads (default is threads). + via ("thread" or "process"): If `concurrency` > 1, whether to apply `transformation` using processes or threads (default via threads). Returns: Stream[T]: A stream of upstream elements, unchanged. """ validate_concurrency(concurrency) - return ForeachStream(self, effect, concurrency, ordered, via_processes) + validate_via(via) + return ForeachStream(self, effect, concurrency, ordered, via) def aforeach( self, @@ -302,7 +308,7 @@ def map( transformation: Callable[[T], U], concurrency: int = 1, ordered: bool = True, - via_processes: bool = False, + via: "Literal['thread', 'process']" = "thread", ) -> "Stream[U]": """ Applies `transformation` on upstream elements and yields the results. @@ -311,12 +317,13 @@ def map( transformation (Callable[[T], R]): The function to be applied to each element. concurrency (int): Represents both the number of threads used to concurrently apply `transformation` and the size of the buffer containing not-yet-yielded results. If the buffer is full, the iteration over the upstream is stopped until some results are yielded out of the buffer. (default is 1, meaning no multithreading). ordered (bool): If `concurrency` > 1, whether to preserve the order of upstream elements or to yield them as soon as they are processed (default preserves order). - via_processes (bool): If `concurrency` > 1, applies `transformation` concurrently using processes instead of threads (default via threads). + via ("thread" or "process"): If `concurrency` > 1, whether to apply `transformation` using processes or threads (default via threads). Returns: Stream[R]: A stream of transformed elements. """ validate_concurrency(concurrency) - return MapStream(self, transformation, concurrency, ordered, via_processes) + validate_via(via) + return MapStream(self, transformation, concurrency, ordered, via) def amap( self, @@ -464,13 +471,13 @@ def __init__( effect: Callable[[T], Any], concurrency: int, ordered: bool, - via_processes: bool, + via: "Literal['thread', 'process']", ) -> None: super().__init__(upstream) self._effect = effect self._concurrency = concurrency self._ordered = ordered - self._via_processes = via_processes + self._via = via def accept(self, visitor: "Visitor[V]") -> V: return visitor.visit_foreach_stream(self) @@ -517,13 +524,13 @@ def __init__( transformation: Callable[[T], U], concurrency: int, ordered: bool, - via_processes: bool, + via: "Literal['thread', 'process']", ) -> None: super().__init__(upstream) self._transformation = transformation self._concurrency = concurrency self._ordered = ordered - self._via_processes = via_processes + self._via = via def accept(self, visitor: "Visitor[V]") -> V: return visitor.visit_map_stream(self) diff --git a/streamable/util/validationtools.py b/streamable/util/validationtools.py index 046b648..fea0e5e 100644 --- a/streamable/util/validationtools.py +++ b/streamable/util/validationtools.py @@ -12,31 +12,36 @@ def validate_iterator(iterator: Iterator): ) -def validate_concurrency(concurrency: int): +def validate_concurrency(concurrency: int) -> None: if concurrency < 1: raise ValueError( f"`concurrency` should be greater or equal to 1, but got {concurrency}." ) -def validate_group_size(size: Optional[int]): +def validate_via(via: str) -> None: + if via not in ["thread", "process"]: + raise TypeError(f"`via` should be 'thread' or 'process', but got {repr(via)}.") + + +def validate_group_size(size: Optional[int]) -> None: if size is not None and size < 1: raise ValueError(f"`size` should be None or >= 1 but got {size}.") -def validate_group_interval(interval: Optional[datetime.timedelta]): +def validate_group_interval(interval: Optional[datetime.timedelta]) -> None: if interval is not None and interval <= datetime.timedelta(0): raise ValueError(f"`interval` should be positive but got {repr(interval)}.") -def validate_throttle_per_period(per_period_arg_name: str, value: int): +def validate_throttle_per_period(per_period_arg_name: str, value: int) -> None: if value < 1: raise ValueError( f"`{per_period_arg_name}` is the maximum number of elements to yield {' '.join(per_period_arg_name.split('_'))}, it must be >= 1 but got {value}." ) -def validate_throttle_interval(interval: datetime.timedelta): +def validate_throttle_interval(interval: datetime.timedelta) -> None: if interval < datetime.timedelta(0): raise ValueError( f"`interval` is the minimum span of time between yields, it must not be negative but got {repr(interval)}." @@ -45,7 +50,7 @@ def validate_throttle_interval(interval: datetime.timedelta): def validate_truncate_args( count: Optional[int] = None, when: Optional[Callable[[T], Any]] = None -): +) -> None: if count is None: if when is None: raise ValueError(f"`count` and `when` can't be both None.") diff --git a/streamable/visitors/iterator.py b/streamable/visitors/iterator.py index 008cc45..533fba1 100644 --- a/streamable/visitors/iterator.py +++ b/streamable/visitors/iterator.py @@ -53,7 +53,7 @@ def visit_foreach_stream(self, stream: ForeachStream[T]) -> Iterator[T]: sidify(stream._effect), stream._concurrency, stream._ordered, - stream._via_processes, + stream._via, ) ) @@ -84,7 +84,7 @@ def visit_map_stream(self, stream: MapStream[U, T]) -> Iterator[T]: stream.upstream.accept(IteratorVisitor[U]()), concurrency=stream._concurrency, ordered=stream._ordered, - via_processes=stream._via_processes, + via=stream._via, ) def visit_amap_stream(self, stream: AMapStream[U, T]) -> Iterator[T]: diff --git a/streamable/visitors/representation.py b/streamable/visitors/representation.py index 00e8f14..8ee590c 100644 --- a/streamable/visitors/representation.py +++ b/streamable/visitors/representation.py @@ -49,7 +49,7 @@ def visit_flatten_stream(self, stream: FlattenStream[T]) -> str: def visit_foreach_stream(self, stream: ForeachStream[T]) -> str: self.methods_reprs.append( - f"foreach({self.to_string(stream._effect)}, concurrency={self.to_string(stream._concurrency)}, ordered={self.to_string(stream._ordered)}, via_processes={self.to_string(stream._via_processes)})" + f"foreach({self.to_string(stream._effect)}, concurrency={self.to_string(stream._concurrency)}, ordered={self.to_string(stream._ordered)}, via={self.to_string(stream._via)})" ) return stream.upstream.accept(self) @@ -67,7 +67,7 @@ def visit_group_stream(self, stream: GroupStream[U]) -> str: def visit_map_stream(self, stream: MapStream[U, T]) -> str: self.methods_reprs.append( - f"map({self.to_string(stream._transformation)}, concurrency={self.to_string(stream._concurrency)}, ordered={self.to_string(stream._ordered)}, via_processes={self.to_string(stream._via_processes)})" + f"map({self.to_string(stream._transformation)}, concurrency={self.to_string(stream._concurrency)}, ordered={self.to_string(stream._ordered)}, via={self.to_string(stream._via)})" ) return stream.upstream.accept(self) diff --git a/tests/test_readme.py b/tests/test_readme.py index 7a67d1c..9b335d6 100644 --- a/tests/test_readme.py +++ b/tests/test_readme.py @@ -58,7 +58,7 @@ def test_map_example(self) -> None: def test_process_concurrent_map_example(self) -> None: state: List[int] = [] n_integers: int = integers.map( - state.append, concurrency=4, via_processes=True + state.append, concurrency=4, via="process" ).count() assert n_integers == 10 assert state == [] # main process's state not mutated diff --git a/tests/test_stream.py b/tests/test_stream.py index e69cf49..73e68aa 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -242,9 +242,9 @@ class CustomCallable: self.assertEqual( """( Stream(range(0, 256)) - .map(, concurrency=1, ordered=True, via_processes=True) + .map(, concurrency=1, ordered=True, via='process') )""", - str(Stream(src).map(lambda _: _, via_processes=True)), + str(Stream(src).map(lambda _: _, via="process")), msg="`repr` should work as expected on a stream with 1 operation", ) self.assertEqual( @@ -252,11 +252,11 @@ class CustomCallable: Stream(range(0, 256)) .truncate(count=1024, when=) .filter(bool) - .map(, concurrency=1, ordered=True, via_processes=False) + .map(, concurrency=1, ordered=True, via='thread') .filter(star(bool)) - .foreach(, concurrency=1, ordered=True, via_processes=False) + .foreach(, concurrency=1, ordered=True, via='thread') .aforeach(async_identity, concurrency=1, ordered=True) - .map(CustomCallable(...), concurrency=1, ordered=True, via_processes=False) + .map(CustomCallable(...), concurrency=1, ordered=True, via='thread') .amap(async_identity, concurrency=1, ordered=True) .group(size=100, by=None, interval=None) .observe('groups') @@ -322,12 +322,13 @@ def test_sanitize_concurrency(self, method, args) -> None: stream = Stream(src) with self.assertRaises( TypeError, - msg=f"{method} should be raising TypeError for non-int concurrency.", + msg=f"`{method}` should be raising TypeError for non-int concurrency.", ): method(stream, *args, concurrency="1") with self.assertRaises( - ValueError, msg=f"{method} should be raising ValueError for concurrency=0." + ValueError, + msg=f"`{method}` should be raising ValueError for concurrency=0.", ): method(stream, *args, concurrency=0) @@ -338,6 +339,20 @@ def test_sanitize_concurrency(self, method, args) -> None: msg=f"It must be ok to call {method} with concurrency={concurrency}.", ) + @parameterized.expand( + [ + (Stream.map,), + (Stream.foreach,), + ] + ) + def test_sanitize_via(self, method) -> None: + with self.assertRaisesRegex( + TypeError, + "`via` should be 'thread' or 'process', but got 'foo'.", + msg=f"`{method}` must raise a TypeError for invalid via", + ): + method(Stream(src), identity, via="foo") + @parameterized.expand( [ [1], @@ -376,16 +391,16 @@ def local_identity(x): "Can't pickle", msg="process-based concurrency should not be able to serialize a lambda or a local func", ): - list(Stream(src).map(f, concurrency=2, via_processes=True)) + list(Stream(src).map(f, concurrency=2, via="process")) sleeps = [0.01, 1, 0.01] state: List[str] = [] expected_result_list: List[str] = list(order_mutation(map(str, sleeps))) stream = ( Stream(sleeps) - .foreach(identity_sleep, concurrency=2, ordered=ordered, via_processes=True) - .map(str, concurrency=2, ordered=True, via_processes=True) - .foreach(state.append, concurrency=2, ordered=True, via_processes=True) + .foreach(identity_sleep, concurrency=2, ordered=ordered, via="process") + .map(str, concurrency=2, ordered=True, via="process") + .foreach(state.append, concurrency=2, ordered=True, via="process") .foreach(lambda _: state.append(""), concurrency=1, ordered=True) ) self.assertListEqual( diff --git a/version.py b/version.py index eee25d1..ee6b65f 100644 --- a/version.py +++ b/version.py @@ -1,2 +1,2 @@ # to show the CHANGELOG: git log -- version.py -__version__ = "1.2.2" +__version__ = "1.3.0-rc"