Skip to content

Commit

Permalink
Make loss masking on prompt optional and disabled by default
Browse files Browse the repository at this point in the history
Signed-off-by: Piotr Żelasko <[email protected]>
  • Loading branch information
pzelasko committed Dec 4, 2024
1 parent 4b1bfce commit 2b75b09
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
3 changes: 2 additions & 1 deletion examples/asr/conf/speech_multitask/fast-conformer_aed.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@ spl_tokens:
model:
sample_rate: 16000
label_smoothing: 0.0
context_len_for_AR_decoding: 5 # Length of input prompt tokens. For example, in Canary models, we use [BOS,src_lang,task,tgt_lang,pnc] and thus the length is 5
use_loss_mask_for_prompt: false
log_prediction: true # enables logging sample predictions in the output during training

# Important ! Set the prompt format to the class you need
prompt_format: ??? # Options supported: ["canary"]
prompt_defaults: null

model_defaults:
asr_enc_hidden: 1024
Expand Down
17 changes: 8 additions & 9 deletions nemo/collections/asr/models/aed_multitask_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,11 +676,6 @@ def training_step(self, batch: PromptedAudioToTextMiniBatch, batch_nb):
input_ids, labels = batch.get_decoder_inputs_outputs()
input_ids_lens = batch.prompted_transcript_lens - 1

num_frames = batch.audio_lens.sum().float()
num_tokens = input_ids_lens.sum().float()
tot_frames = torch.as_tensor(batch.audio.numel(), device=num_frames.device, dtype=torch.float)
tot_tokens = torch.as_tensor(batch.prompted_transcript.numel(), device=num_frames.device, dtype=torch.float)

num_frames = batch.audio_lens.sum().float()
num_tokens = batch.prompted_transcript_lens.sum().float()
tot_frames = torch.as_tensor(batch.audio.numel(), device=num_frames.device, dtype=torch.float)
Expand All @@ -696,8 +691,10 @@ def training_step(self, batch: PromptedAudioToTextMiniBatch, batch_nb):
# Mask components: 1) discard padding & 2) discard prompt (notice the negation)
# For a full decoder sequence O with len M, the loss mask skips the first element,
# covering the remaining M-1 elements - hence we subtract 1 from prompt lens to account for the BOS token.
maxlen = batch.prompted_transcript.shape[1] - 1
loss_mask = lens_to_mask(input_ids_lens, maxlen) & ~lens_to_mask(batch.prompt_lens - 1, maxlen)
loss_mask = None
if self.cfg.get("use_loss_mask_for_prompt", False):
maxlen = batch.prompted_transcript.shape[1] - 1
loss_mask = lens_to_mask(input_ids_lens, maxlen) & ~lens_to_mask(batch.prompt_lens - 1, maxlen)
audio_loss = self.loss(log_probs=transf_log_probs, labels=labels, output_mask=loss_mask)

tensorboard_logs = {
Expand Down Expand Up @@ -726,8 +723,10 @@ def validation_pass(self, batch: PromptedAudioToTextMiniBatch, batch_idx, datalo
# Mask components: 1) discard padding & 2) discard prompt (notice the negation)
# For a full decoder sequence O with len M, the loss mask skips the first element,
# covering the remaining M-1 elements - hence we subtract 1 from prompt lens to account for the BOS token.
maxlen = batch.prompted_transcript.shape[1] - 1
loss_mask = lens_to_mask(input_ids_lens, maxlen) & ~lens_to_mask(batch.prompt_lens - 1, maxlen)
loss_mask = None
if self.cfg.get("use_loss_mask_for_prompt", False):
maxlen = batch.prompted_transcript.shape[1] - 1
loss_mask = lens_to_mask(input_ids_lens, maxlen) & ~lens_to_mask(batch.prompt_lens - 1, maxlen)
transf_loss = self.loss(log_probs=transf_log_probs, labels=labels, output_mask=loss_mask)
self.val_loss(loss=transf_loss, num_measurements=loss_mask.long().sum())
output_dict = {f'{eval_mode}_loss': transf_loss}
Expand Down

0 comments on commit 2b75b09

Please sign in to comment.