diff --git a/streamable/iterators.py b/streamable/iterators.py index 38f1e5e..8e821b3 100644 --- a/streamable/iterators.py +++ b/streamable/iterators.py @@ -149,7 +149,7 @@ def __next__(self) -> U: self._current_iterator_elem = noop_stopiteration(iter)(iterable_elem) -class _GroupIteratorInitMixin(Generic[T]): +class _GroupIterator(Generic[T]): def __init__( self, iterator: Iterator[T], @@ -161,12 +161,20 @@ def __init__( validate_group_interval(interval) self.iterator = iterator self.size = size or cast(int, float("inf")) + self.interval = interval self._interval_seconds = interval.total_seconds() if interval else float("inf") self._to_be_raised: Optional[Exception] = None self._last_group_yielded_at: float = 0 + def _interval_seconds_have_elapsed(self) -> bool: + if not self.interval: + return False + return ( + time.perf_counter() - self._last_group_yielded_at + ) >= self._interval_seconds + -class GroupIterator(_GroupIteratorInitMixin[T], Iterator[List[T]]): +class GroupIterator(_GroupIterator[T], Iterator[List[T]]): def __init__( self, iterator: Iterator[T], @@ -176,13 +184,8 @@ def __init__( super().__init__(iterator, size, interval) self._current_group: List[T] = [] - def _interval_seconds_have_elapsed(self) -> bool: - return ( - time.perf_counter() - self._last_group_yielded_at - ) >= self._interval_seconds - def __next__(self) -> List[T]: - if not self._last_group_yielded_at: + if self.interval and not self._last_group_yielded_at: self._last_group_yielded_at = time.perf_counter() if self._to_be_raised: e, self._to_be_raised = self._to_be_raised, None @@ -198,11 +201,12 @@ def __next__(self) -> List[T]: self._to_be_raised = e group, self._current_group = self._current_group, [] - self._last_group_yielded_at = time.perf_counter() + if self.interval: + self._last_group_yielded_at = time.perf_counter() return group -class GroupbyIterator(_GroupIteratorInitMixin[T], Iterator[Tuple[U, List[T]]]): +class GroupbyIterator(_GroupIterator[T], Iterator[Tuple[U, List[T]]]): def __init__( self, iterator: Iterator[T], @@ -215,11 +219,6 @@ def __init__( self._is_exhausted = False self._groups_by: DefaultDict[U, List[T]] = defaultdict(list) - def _interval_seconds_have_elapsed(self) -> bool: - return ( - time.perf_counter() - self._last_group_yielded_at - ) >= self._interval_seconds - def _group_next_elem(self) -> None: elem = next(self.iterator) self._groups_by[self.by(elem)].append(elem) @@ -244,11 +243,12 @@ def _pop_largest_group(self) -> Tuple[U, List[T]]: return largest_group_key, self._groups_by.pop(largest_group_key) def _return_group(self, group: Tuple[U, List[T]]) -> Tuple[U, List[T]]: - self._last_group_yielded_at = time.perf_counter() + if self.interval: + self._last_group_yielded_at = time.perf_counter() return group def __next__(self) -> Tuple[U, List[T]]: - if not self._last_group_yielded_at: + if self.interval and not self._last_group_yielded_at: self._last_group_yielded_at = time.perf_counter() if self._is_exhausted: if self._groups_by: