Skip to content

Commit

Permalink
PT: add optimizing bucket batching
Browse files Browse the repository at this point in the history
  • Loading branch information
NeoLegends committed Dec 5, 2024
1 parent 7ee6484 commit 15b19cf
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions returnn/torch/data/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,33 @@ def __getitem__(self, index):
raise Exception(f"{self.__class__.__name__}.__getitem__ is not supported")


class OptimizingBucketOrderingIterDataPipe(BucketOrderingIterDataPipe):
def __iter__(self):
seq_lens = []
for batch in iter(super()):
seq_lens.extend(int(d[self._length_key].shape[0]) for d in batch)
yield batch
seq_lens.sort()

new_upper_seq_lens = [
max(1, v)
for v in [
*seq_lens[:: len(seq_lens) // len(self._max_seq_lens)][1:],
self._max_seq_lens[-1], # keep old upper bound
]
]
old_batch_sizes = [seq_len * bsize for seq_len, bsize in zip(self._max_seq_lens, self._max_bucket_sizes)]
new_bucket_sizes = [
max(1, round(bsize / seq_len)) for seq_len, bsize in zip(new_upper_seq_lens, old_batch_sizes)
]

self._max_bucket_sizes = new_bucket_sizes
self._max_seq_lens = new_upper_seq_lens

cfg_str = ", ".join(f"{limit}: {size}" for limit, size in zip(new_upper_seq_lens, new_bucket_sizes))
print(f"optimized bucket batching configuration: {cfg_str}", file=log.v3)


def get_batching_iterable_dataset_from_config(
*, dataset: torch.utils.data.IterableDataset, config: Config, train: bool
) -> torch.utils.data.IterableDataset:
Expand Down

0 comments on commit 15b19cf

Please sign in to comment.