diff --git a/users/zeyer/experiments/exp2024_04_23_baselines/ctc.py b/users/zeyer/experiments/exp2024_04_23_baselines/ctc.py
index 4cd663a34..4e32e2258 100644
--- a/users/zeyer/experiments/exp2024_04_23_baselines/ctc.py
+++ b/users/zeyer/experiments/exp2024_04_23_baselines/ctc.py
@@ -2116,7 +2116,7 @@ def __init__(
             non_critical_for_restore=True,
         )
         self.prior_running_mean_momentum = config.typed_value("prior_running_mean_momentum", None)
-        self.prior_running_mean = None
+        self.prior_running_mean = None  # in std prob, if set
         if self.prior_running_mean_momentum is not None:
             self.prior_running_mean = rf.Parameter(
                 [self.wb_target_dim], auxiliary=True, initial=1.0 / self.wb_target_dim.dimension
@@ -2150,7 +2150,7 @@ def __init__(
         self.log_prob_normed_grad_exclude_blank = config.bool(
             "log_prob_normed_grad_exclude_blank", self.out_blank_separated
         )
-        self.grad_prior = None
+        self.grad_prior = None  # in std prob, if set
         if (
             self.log_prob_normed_grad_opts
             and self.log_prob_normed_grad_opts.get("prior_running_mean_momentum") is not None
@@ -2332,12 +2332,18 @@ def _update_running_stats():
         log_probs = log_probs_am * self.ctc_am_scale
         if self.ctc_prior_scale:
             if self.ctc_prior_type == "batch":
+                # Warning: this is sum, but we want mean!
                 log_prob_prior = rf.reduce_logsumexp(
                     log_probs_am, axis=[dim for dim in log_probs_am.dims if dim != self.wb_target_dim]
                 )
                 assert log_prob_prior.dims == (self.wb_target_dim,)
+            elif self.ctc_prior_type == "batch_fixed":
+                log_prob_prior = rf.reduce_logmeanexp(
+                    log_probs_am, axis=[dim for dim in log_probs_am.dims if dim != self.wb_target_dim]
+                )
+                assert log_prob_prior.dims == (self.wb_target_dim,)
             elif self.ctc_prior_type == "seq":
-                log_prob_prior = rf.reduce_logsumexp(
+                log_prob_prior = rf.reduce_logmeanexp(
                     log_probs_am, axis=[dim for dim in log_probs_am.dims if dim not in (batch_dim, self.wb_target_dim)]
                 )
                 assert log_prob_prior.dims_set == {batch_dim, self.wb_target_dim}
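
Note (not part of the diff): the fix swaps a sum over frames for a mean when estimating the label prior, i.e. logsumexp vs. logmeanexp over the batch/time dims. A minimal NumPy sketch of the relation the new `batch_fixed`/`seq` branches rely on, assuming only the standard definition logmeanexp(x) = logsumexp(x) - log(N); the shapes and the `logmeanexp` helper below are illustrative, not from the repo:

    import numpy as np
    from scipy.special import logsumexp

    def logmeanexp(x, axis):
        """log(mean(exp(x))) over the given axes = logsumexp(x) - log(N)."""
        n = np.prod([x.shape[a] for a in axis])
        return logsumexp(x, axis=axis) - np.log(n)

    # Fake AM output: log_probs_am with dims [batch, time, classes],
    # each frame a proper distribution over classes.
    log_probs_am = np.log(np.random.dirichlet(np.ones(5), size=(2, 7)))

    prior_sum = logsumexp(log_probs_am, axis=(0, 1))    # old "batch": sum over B*T frames
    prior_mean = logmeanexp(log_probs_am, axis=(0, 1))  # "batch_fixed": mean over B*T frames

    # The sum variant over-counts by a constant log(B*T) per class ...
    assert np.allclose(prior_mean, prior_sum - np.log(2 * 7))
    # ... and only the mean is itself a normalized distribution.
    assert np.isclose(logsumexp(prior_mean), 0.0)

Since the offset log(B*T) is constant across classes, the sum variant does not change the relative prior between labels, but once scaled by `ctc_prior_scale` and subtracted from the AM scores it shifts all log probs by a batch-size-dependent constant, which is presumably why the old `batch` branch is kept (with a warning) and the corrected behavior is added under a new `batch_fixed` prior type.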