Skip to content

Commit

Permalink
.group[by]: avoid calling perf_counter if interval is not set
Browse files Browse the repository at this point in the history
  • Loading branch information
ebonnal committed Dec 28, 2024
1 parent c89941e commit 926532e
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions streamable/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def __next__(self) -> U:
self._current_iterator_elem = noop_stopiteration(iter)(iterable_elem)


class _GroupIteratorInitMixin(Generic[T]):
class _GroupIterator(Generic[T]):
def __init__(
self,
iterator: Iterator[T],
Expand All @@ -161,12 +161,20 @@ def __init__(
validate_group_interval(interval)
self.iterator = iterator
self.size = size or cast(int, float("inf"))
self.interval = interval
self._interval_seconds = interval.total_seconds() if interval else float("inf")
self._to_be_raised: Optional[Exception] = None
self._last_group_yielded_at: float = 0

def _interval_seconds_have_elapsed(self) -> bool:
if not self.interval:
return False
return (
time.perf_counter() - self._last_group_yielded_at
) >= self._interval_seconds


class GroupIterator(_GroupIteratorInitMixin[T], Iterator[List[T]]):
class GroupIterator(_GroupIterator[T], Iterator[List[T]]):
def __init__(
self,
iterator: Iterator[T],
Expand All @@ -176,13 +184,8 @@ def __init__(
super().__init__(iterator, size, interval)
self._current_group: List[T] = []

def _interval_seconds_have_elapsed(self) -> bool:
return (
time.perf_counter() - self._last_group_yielded_at
) >= self._interval_seconds

def __next__(self) -> List[T]:
if not self._last_group_yielded_at:
if self.interval and not self._last_group_yielded_at:
self._last_group_yielded_at = time.perf_counter()
if self._to_be_raised:
e, self._to_be_raised = self._to_be_raised, None
Expand All @@ -198,11 +201,12 @@ def __next__(self) -> List[T]:
self._to_be_raised = e

group, self._current_group = self._current_group, []
self._last_group_yielded_at = time.perf_counter()
if self.interval:
self._last_group_yielded_at = time.perf_counter()
return group


class GroupbyIterator(_GroupIteratorInitMixin[T], Iterator[Tuple[U, List[T]]]):
class GroupbyIterator(_GroupIterator[T], Iterator[Tuple[U, List[T]]]):
def __init__(
self,
iterator: Iterator[T],
Expand All @@ -215,11 +219,6 @@ def __init__(
self._is_exhausted = False
self._groups_by: DefaultDict[U, List[T]] = defaultdict(list)

def _interval_seconds_have_elapsed(self) -> bool:
return (
time.perf_counter() - self._last_group_yielded_at
) >= self._interval_seconds

def _group_next_elem(self) -> None:
elem = next(self.iterator)
self._groups_by[self.by(elem)].append(elem)
Expand All @@ -244,11 +243,12 @@ def _pop_largest_group(self) -> Tuple[U, List[T]]:
return largest_group_key, self._groups_by.pop(largest_group_key)

def _return_group(self, group: Tuple[U, List[T]]) -> Tuple[U, List[T]]:
self._last_group_yielded_at = time.perf_counter()
if self.interval:
self._last_group_yielded_at = time.perf_counter()
return group

def __next__(self) -> Tuple[U, List[T]]:
if not self._last_group_yielded_at:
if self.interval and not self._last_group_yielded_at:
self._last_group_yielded_at = time.perf_counter()
if self._is_exhausted:
if self._groups_by:
Expand Down

0 comments on commit 926532e

Please sign in to comment.