diff --git a/returnn/torch/data/pipeline.py b/returnn/torch/data/pipeline.py
index bb14425e6..ad4e39428 100644
--- a/returnn/torch/data/pipeline.py
+++ b/returnn/torch/data/pipeline.py
@@ -373,6 +373,40 @@ def __getitem__(self, index):
         raise Exception(f"{self.__class__.__name__}.__getitem__ is not supported")
 
 
+class OptimizingBucketOrderingIterDataPipe(BucketOrderingIterDataPipe):
+    """
+    Like :class:`BucketOrderingIterDataPipe`, but collects the sequence lengths observed during iteration
+    and afterwards re-estimates the bucket upper bounds and bucket sizes from them for the next epoch.
+    """
+
+    def __iter__(self):
+        seq_lens = []
+        for batch in super().__iter__():
+            seq_lens.extend(int(d[self._length_key].shape[0]) for d in batch)
+            yield batch
+        seq_lens.sort()
+
+        # Take evenly spaced quantiles of the observed lengths as the new bucket upper bounds.
+        new_upper_seq_lens = [
+            max(1, v)
+            for v in [
+                *seq_lens[:: max(1, len(seq_lens) // len(self._max_seq_lens))][1:],
+                self._max_seq_lens[-1],  # keep old upper bound
+            ]
+        ]
+        # Rescale the bucket sizes such that each bucket keeps roughly its previous volume in frames.
+        old_batch_sizes = [seq_len * bsize for seq_len, bsize in zip(self._max_seq_lens, self._max_bucket_sizes)]
+        new_bucket_sizes = [
+            max(1, round(bsize / seq_len)) for seq_len, bsize in zip(new_upper_seq_lens, old_batch_sizes)
+        ]
+
+        self._max_bucket_sizes = new_bucket_sizes
+        self._max_seq_lens = new_upper_seq_lens
+
+        cfg_str = ", ".join(f"{limit}: {size}" for limit, size in zip(new_upper_seq_lens, new_bucket_sizes))
+        print(f"optimized bucket batching configuration: {cfg_str}", file=log.v3)
+
+
 def get_batching_iterable_dataset_from_config(
     *, dataset: torch.utils.data.IterableDataset, config: Config, train: bool
 ) -> torch.utils.data.IterableDataset:
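For clarity, here is a minimal standalone sketch (not part of the patch) of the re-estimation step above: given the sequence lengths observed in one epoch and the current bucket configuration, it derives the new upper bounds and bucket sizes the same way the data pipe does. The toy data and variable names are made up for illustration.

```python
# Sketch only: mirrors the re-estimation logic of OptimizingBucketOrderingIterDataPipe
# on toy data; all values and names below are illustrative, not taken from the pipeline.
observed_seq_lens = sorted([7, 12, 15, 23, 31, 42, 44, 58, 63, 71, 88, 95])
max_seq_lens = [25, 50, 100]  # current bucket upper bounds (frames per sequence)
max_bucket_sizes = [8, 4, 2]  # current max number of sequences per bucket

# New upper bounds: evenly spaced quantiles of the observed lengths,
# keeping the largest old bound so no sequence exceeds every bucket limit.
stride = max(1, len(observed_seq_lens) // len(max_seq_lens))
new_upper_seq_lens = [max(1, v) for v in [*observed_seq_lens[::stride][1:], max_seq_lens[-1]]]

# New bucket sizes: keep each bucket's previous volume (upper bound * size) roughly constant.
old_batch_sizes = [s * b for s, b in zip(max_seq_lens, max_bucket_sizes)]
new_bucket_sizes = [max(1, round(b / s)) for s, b in zip(new_upper_seq_lens, old_batch_sizes)]

print(dict(zip(new_upper_seq_lens, new_bucket_sizes)))  # -> {31: 6, 63: 3, 100: 2}
```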