diff --git a/users/zeyer/experiments/exp2024_04_23_baselines/model_ext/ctc_sep_net.py b/users/zeyer/experiments/exp2024_04_23_baselines/model_ext/ctc_sep_net.py index 16989cc8f..0973f4eb4 100644 --- a/users/zeyer/experiments/exp2024_04_23_baselines/model_ext/ctc_sep_net.py +++ b/users/zeyer/experiments/exp2024_04_23_baselines/model_ext/ctc_sep_net.py @@ -641,6 +641,20 @@ def ctc_training_with_sep_net( use_normalized_loss=use_normalized_loss, ) + sep_loss = ctc_loss( + logits=sep_log_probs, + logits_normalized=True, + targets=targets, + input_spatial_dim=enc_spatial_dim, + targets_spatial_dim=targets_spatial_dim, + blank_index=model.blank_idx, + ) + sep_loss.mark_as_loss( + "sep_ctc", + custom_inv_norm_factor=targets_spatial_dim.get_size_tensor(), + use_normalized_loss=use_normalized_loss, + ) + if model.decoder: # potentially also other types but just assume # noinspection PyTypeChecker