Commit bc1c9c2

fix ctc batch prior (seq as well)
albertz committed Jan 4, 2025
1 parent 7f78dcc commit bc1c9c2
Showing 1 changed file with 9 additions and 3 deletions.
users/zeyer/experiments/exp2024_04_23_baselines/ctc.py (12 changes: 9 additions & 3 deletions)

@@ -2116,7 +2116,7 @@ def __init__(
             non_critical_for_restore=True,
         )
         self.prior_running_mean_momentum = config.typed_value("prior_running_mean_momentum", None)
-        self.prior_running_mean = None
+        self.prior_running_mean = None  # in std prob, if set
         if self.prior_running_mean_momentum is not None:
             self.prior_running_mean = rf.Parameter(
                 [self.wb_target_dim], auxiliary=True, initial=1.0 / self.wb_target_dim.dimension
@@ -2150,7 +2150,7 @@ def __init__(
         self.log_prob_normed_grad_exclude_blank = config.bool(
             "log_prob_normed_grad_exclude_blank", self.out_blank_separated
         )
-        self.grad_prior = None
+        self.grad_prior = None  # in std prob, if set
         if (
             self.log_prob_normed_grad_opts
             and self.log_prob_normed_grad_opts.get("prior_running_mean_momentum") is not None
@@ -2332,12 +2332,18 @@ def _update_running_stats():
         log_probs = log_probs_am * self.ctc_am_scale
         if self.ctc_prior_scale:
             if self.ctc_prior_type == "batch":
+                # Warning: this is sum, but we want mean!
                 log_prob_prior = rf.reduce_logsumexp(
                     log_probs_am, axis=[dim for dim in log_probs_am.dims if dim != self.wb_target_dim]
                 )
                 assert log_prob_prior.dims == (self.wb_target_dim,)
+            elif self.ctc_prior_type == "batch_fixed":
+                log_prob_prior = rf.reduce_logmeanexp(
+                    log_probs_am, axis=[dim for dim in log_probs_am.dims if dim != self.wb_target_dim]
+                )
+                assert log_prob_prior.dims == (self.wb_target_dim,)
             elif self.ctc_prior_type == "seq":
-                log_prob_prior = rf.reduce_logsumexp(
+                log_prob_prior = rf.reduce_logmeanexp(
                     log_probs_am, axis=[dim for dim in log_probs_am.dims if dim not in (batch_dim, self.wb_target_dim)]
                 )
                 assert log_prob_prior.dims_set == {batch_dim, self.wb_target_dim}
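A note on what the fix changes (an illustrative sketch, not part of the commit): reducing normalized AM log-probs with logsumexp over the batch/time axes gives the log of a sum of probabilities, so the resulting "prior" is shifted by log(N) relative to the intended average and no longer sums to 1 over the vocabulary. The new "batch_fixed" branch and the corrected "seq" branch therefore use logmeanexp. The sketch below reproduces the difference with plain PyTorch and made-up tensor shapes instead of the RETURNN rf API; the logmeanexp helper here is a hypothetical stand-in for rf.reduce_logmeanexp, and it ignores sequence-length masking, which reductions over dynamic dims in the actual code would account for.

import torch


def logmeanexp(x: torch.Tensor, dims: tuple) -> torch.Tensor:
    # Hypothetical stand-in for rf.reduce_logmeanexp:
    # log(mean(exp(x))) over the given dims, computed stably via logsumexp.
    n = 1
    for d in dims:
        n *= x.shape[d]
    return torch.logsumexp(x, dim=dims) - torch.log(torch.tensor(float(n)))


batch, time, vocab = 4, 50, 80  # made-up sizes for illustration
log_probs_am = torch.log_softmax(torch.randn(batch, time, vocab), dim=-1)

# Old "batch" prior: logsumexp over batch and time, off by log(batch * time).
prior_sum = torch.logsumexp(log_probs_am, dim=(0, 1))
# "batch_fixed" (and, per sequence, the corrected "seq"): mean over the reduced axes.
prior_mean = logmeanexp(log_probs_am, dims=(0, 1))

# The two differ exactly by the constant log(batch * time) ...
assert torch.allclose(prior_sum - prior_mean, torch.log(torch.tensor(float(batch * time))))
# ... and only the mean version is again a normalized distribution over the vocabulary.
print(torch.logsumexp(prior_mean, dim=-1))  # ~ 0.0
print(torch.logsumexp(prior_sum, dim=-1))   # ~ log(batch * time), not 0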
