PT: prevent overflow when tracking padding in minibatches
NeoLegends committed Dec 2, 2024
1 parent e8cc3b3 commit 40566e2
Showing 1 changed file with 4 additions and 2 deletions.
returnn/torch/engine.py (4 additions, 2 deletions)
@@ -386,12 +386,13 @@ def train_epoch(self):
             if not _has_data[0]:
                 break
 
+            # convert values from torch int32 to Python ints to prevent overflow
             keys_w_seq_len = [k for k in extern_data_raw if f"{k}:seq_len" in extern_data_raw]
             total_data_size_packed += NumbersDict(
-                {k: sum(extern_data_raw[f"{k}:seq_len"]) for k in keys_w_seq_len},
+                {k: int(sum(extern_data_raw[f"{k}:seq_len"])) for k in keys_w_seq_len},
             )
             total_data_size_padded += NumbersDict(
-                {k: util.prod(extern_data_raw[k].shape[:2]) for k in keys_w_seq_len},
+                {k: int(util.prod(extern_data_raw[k].shape[:2])) for k in keys_w_seq_len},
             )
 
             num_seqs_ = (
@@ -523,6 +524,7 @@ def _debug_func() -> torch.Tensor:
         total_padding_ratio = NumbersDict.constant_like(1.0, total_data_size_packed) - (
             total_data_size_packed / total_data_size_padded
         )
+        assert 0.0 <= total_padding_ratio.min_value() <= total_padding_ratio.max_value() <= 1.0
         pad_str = ", ".join(f"{k}: {v:.1%}" for k, v in total_padding_ratio.items())
         print(
             f"Epoch {self.epoch}: Trained {step_idx} steps, {hms(elapsed)} elapsed "
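Why the int(...) casts matter: Python's builtin sum() over a torch tensor returns a 0-dim tensor, not a Python int, and (per the commit message) the seq-len tensors here are int32, so accumulating the per-batch values over an epoch can wrap around silently once the running total passes 2**31 - 1. A minimal standalone sketch of both behaviors (not from the repository; the tensor values are made up):

import torch

# seq-len tensors arrive as torch int32 (per the commit message)
seq_len = torch.tensor([500, 300, 700], dtype=torch.int32)

# builtin sum() over a tensor yields a 0-dim int32 tensor, not a Python int
batch_total = sum(seq_len)
print(batch_total.dtype)  # torch.int32

# int32 tensor arithmetic wraps silently past 2**31 - 1 ...
big = torch.tensor(2_000_000_000, dtype=torch.int32)
print(big + big)  # tensor(-294967296, dtype=torch.int32)

# ... while Python ints are arbitrary-precision, hence the int(...) casts
print(int(big) + int(big))  # 4000000000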

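For context on the new assert: per data key, the padding ratio is 1 - packed/padded, where packed sums the true sequence lengths and padded is the allocated batch_size * max_len (the prod over shape[:2] in the diff). Since 0 < packed <= padded, the ratio must land in [0, 1]; once the totals overflowed int32, this invariant broke, which the assert now catches. A plain-dict sketch of the same arithmetic (NumbersDict is RETURNN's per-key numeric dict; the concrete totals below are hypothetical):

# hypothetical per-key epoch totals, mirroring the diff's bookkeeping
total_packed = {"data": 9_500_000, "classes": 320_000}   # sums of true seq lens
total_padded = {"data": 12_000_000, "classes": 400_000}  # sums of batch * max_len

padding_ratio = {k: 1.0 - total_packed[k] / total_padded[k] for k in total_packed}
assert 0.0 <= min(padding_ratio.values()) <= max(padding_ratio.values()) <= 1.0
print(", ".join(f"{k}: {v:.1%}" for k, v in padding_ratio.items()))
# data: 20.8%, classes: 20.0%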