Skip to content

Commit

Permalink
better check
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Jan 10, 2025
1 parent e258e29 commit dd72c1a
Showing 1 changed file with 3 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,6 @@ def finish(self, state: LMState):
assert seq_len <= max_seq_len
results = fl_decoder.decode(emissions_ptr, seq_len, model.wb_target_dim.dimension)
hyps_per_batch = [result.tokens for result in results]
assert all(len(hyp) == seq_len for hyp in hyps_per_batch)
scores_per_batch = [result.score for result in results]
best_word_seq = [
model.wb_target_dim.vocab.id_to_label(label_idx) if label_idx >= 0 else str(label_idx)
Expand All @@ -725,6 +724,9 @@ def finish(self, state: LMState):
f" LM recalc whole seq count {fl_lm._count_recalc_whole_seq}"
f" mem usage {dev_s}: {' '.join(_collect_mem_stats())}"
)
assert all(
len(hyp) == seq_len for hyp in hyps_per_batch
), f"seq_len {seq_len}, hyps lens {[len(hyp) for hyp in hyps_per_batch]}"
if len(results) >= n_best:
hyps_per_batch = hyps_per_batch[:n_best]
scores_per_batch = scores_per_batch[:n_best]
Expand Down

0 comments on commit dd72c1a

Please sign in to comment.