grad prior per layer
albertz committed Jan 5, 2025
1 parent 1fb1106 commit 3e4ec87
Showing 2 changed files with 30 additions and 6 deletions.
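
The change itself is simple: when prior_running_mean_per_layer is enabled, the model keeps one gradient-prior parameter per auxiliary loss layer in addition to the shared one, and the normed-grad hook then picks the parameter matching the layer it is applied to. Below is a minimal plain-Python sketch of that pattern (illustration only, not RETURNN code; the toy class and plain lists stand in for the rf.Parameter tensors in the diff that follows).

class _PerLayerPriorSketch:
    """Toy stand-in for the model: one shared grad prior plus one per aux layer."""

    def __init__(self, num_labels: int, enc_aux_logits: list, prior_running_mean_per_layer: bool):
        self.prior_running_mean_per_layer = prior_running_mean_per_layer
        # Shared prior with uniform init (mirrors initial=1.0 / _grad_prior_dim.dimension).
        self.grad_prior = [1.0 / num_labels] * num_labels
        if self.prior_running_mean_per_layer:
            for i in enc_aux_logits:
                # One extra prior per aux layer, registered as attribute grad_prior_<i>.
                setattr(self, f"grad_prior_{i}", [1.0 / num_labels] * num_labels)

    def pick_grad_prior(self, aux_layer=None):
        # Mirrors the lookup in _maybe_apply_log_probs_normed_grad: default to the
        # shared prior, switch to the per-layer one when an aux layer is given.
        grad_prior = self.grad_prior
        if self.prior_running_mean_per_layer and aux_layer is not None:
            grad_prior = getattr(self, f"grad_prior_{aux_layer}")
        return grad_prior


model = _PerLayerPriorSketch(num_labels=5, enc_aux_logits=[4, 8], prior_running_mean_per_layer=True)
assert model.pick_grad_prior(aux_layer=8) is model.grad_prior_8
assert model.pick_grad_prior() is model.grad_prior

In the actual commit, the same selection is enabled by threading aux_layer through log_probs_wb_from_logits and _maybe_apply_on_log_probs, as shown in the diff.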
users/zeyer/experiments/exp2024_04_23_baselines/ctc.py (23 changes: 17 additions & 6 deletions)
@@ -2170,6 +2170,13 @@ def __init__(
# - static_prior is supposed to be static, also with non_critical_for_restore=True.
_grad_prior_dim = self.target_dim if self.log_prob_normed_grad_exclude_blank else self.wb_target_dim
self.grad_prior = rf.Parameter([_grad_prior_dim], auxiliary=True, initial=1.0 / _grad_prior_dim.dimension)
if self.prior_running_mean_per_layer:
for i in enc_aux_logits:
setattr(
self,
f"grad_prior_{i}",
rf.Parameter([_grad_prior_dim], auxiliary=True, initial=1.0 / _grad_prior_dim.dimension),
)

self.feature_batch_norm = None
if config.bool("feature_batch_norm", False):
@@ -2311,7 +2318,7 @@ def log_probs_wb_from_logits(self, logits: Tensor, *, aux_layer: Optional[int] =
logits, axis=self.wb_target_dim, out_dims=[self.target_dim, dummy_blank_feat_dim]
)
log_probs_wo_blank = rf.log_softmax(logits_wo_blank, axis=self.target_dim)
log_probs_wo_blank = self._maybe_apply_on_log_probs(log_probs_wo_blank)
log_probs_wo_blank = self._maybe_apply_on_log_probs(log_probs_wo_blank, aux_layer=aux_layer)
if self.blank_logit_shift:
logits_blank += self.blank_logit_shift
log_probs_blank = rf.log_sigmoid(logits_blank)
@@ -2338,7 +2345,7 @@ def _update_running_stats():

rf.cond(rf.get_run_ctx().train_flag, _update_running_stats, lambda: None)

log_probs = self._maybe_apply_on_log_probs(log_probs)
log_probs = self._maybe_apply_on_log_probs(log_probs, aux_layer=aux_layer)
if self.ctc_am_scale == 1 and self.ctc_prior_scale == 0: # fast path
return log_probs
log_probs_am = log_probs
@@ -2387,16 +2394,17 @@ def _update_running_stats():
log_probs -= log_prob_prior * self.ctc_prior_scale
return log_probs

def _maybe_apply_on_log_probs(self, log_probs: Tensor) -> Tensor:
def _maybe_apply_on_log_probs(self, log_probs: Tensor, *, aux_layer: Optional[int] = None) -> Tensor:
"""
:param log_probs: either with blank or without blank
:param aux_layer:
:return: log probs, maybe some smoothing applied (all on gradients so far, not on log probs itself)
"""
assert log_probs.feature_dim in (self.wb_target_dim, self.target_dim)
if not self.out_blank_separated:
assert log_probs.feature_dim == self.wb_target_dim

log_probs = self._maybe_apply_log_probs_normed_grad(log_probs)
log_probs = self._maybe_apply_log_probs_normed_grad(log_probs, aux_layer=aux_layer)

if self.ctc_label_smoothing_exclude_blank:
if self.out_blank_separated:
@@ -2412,7 +2420,7 @@ def _maybe_apply_on_log_probs(self, log_probs: Tensor) -> Tensor:

return log_probs

def _maybe_apply_log_probs_normed_grad(self, log_probs: Tensor) -> Tensor:
def _maybe_apply_log_probs_normed_grad(self, log_probs: Tensor, *, aux_layer: Optional[int] = None) -> Tensor:
if not self.log_prob_normed_grad_opts:
return log_probs

@@ -2440,7 +2448,10 @@ def _maybe_apply_log_probs_normed_grad(self, log_probs: Tensor) -> Tensor:
assert "prior_running_mean" not in opts # will be set by us here, and only when needed
if opts.get("prior_running_mean_momentum") is not None:
assert self.grad_prior is not None
opts["prior_running_mean"] = self.grad_prior.raw_tensor
grad_prior = self.grad_prior
if self.prior_running_mean_per_layer and aux_layer is not None:
grad_prior = getattr(self, f"grad_prior_{aux_layer}")
opts["prior_running_mean"] = grad_prior.raw_tensor

assert log_probs.batch_dim_axis is not None and log_probs.feature_dim_axis is not None
log_probs_ = log_probs.copy_template()
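The raw tensor of the selected parameter is handed to the normed-grad function via opts["prior_running_mean"]. How that running mean is updated is not part of this diff; as a hedged illustration only, assuming prior_running_mean_momentum acts as a standard exponential-moving-average factor (0.001 in the experiment below), each per-layer prior would track its layer's output distribution roughly like this:

import numpy as np

def update_running_mean(prior, batch_mean, momentum=0.001):
    # Generic EMA update (assumption, not taken from this diff): with a small
    # momentum the prior only moves slowly toward the current batch distribution.
    return (1.0 - momentum) * prior + momentum * batch_mean

prior = np.full(5, 1.0 / 5)                       # uniform init, as for the Parameter above
batch_mean = np.array([0.4, 0.2, 0.2, 0.1, 0.1])  # e.g. mean (softmax) output of one aux layer
prior = update_running_mean(prior, batch_mean)
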
users/zeyer/experiments/exp2024_04_23_baselines/ctc_claix2023.py (13 changes: 13 additions & 0 deletions)
@@ -676,6 +676,19 @@ def py():
},
}
},
"-lpNormedGradC05_11P07NExpL1_3": {
"prior_running_mean_per_layer": True,
"log_prob_normed_grad": {
"prior_running_mean_momentum": 0.001,
"func": {
"clamp_min": 0.5,
"clamp_max": 1.1,
"scale_type": "inv_num_labels",
"prior_exp": 0.7,
"prior_renorm": True,
},
},
},
}.items():
ctc_train_exp(
f"n12-spm10k-auxAED-b150k{name}",
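For reference, the new experiment entry plugs into the existing sweep pattern in ctc_claix2023.py: the dict key is appended to the base experiment name and the dict value is the config overlay that enables the per-layer running-mean prior. A simplified, self-contained sketch of that loop (the config_updates keyword and the stub ctc_train_exp are assumptions for illustration; the real entry point lives in ctc_claix2023.py):

def ctc_train_exp(name, config_updates=None):
    # Stub standing in for the real training-experiment entry point.
    print("would schedule:", name, config_updates)

for name, opts in {
    "-lpNormedGradC05_11P07NExpL1_3": {
        "prior_running_mean_per_layer": True,
        "log_prob_normed_grad": {
            "prior_running_mean_momentum": 0.001,
            "func": {
                "clamp_min": 0.5,
                "clamp_max": 1.1,
                "scale_type": "inv_num_labels",
                "prior_exp": 0.7,
                "prior_renorm": True,
            },
        },
    },
}.items():
    ctc_train_exp(f"n12-spm10k-auxAED-b150k{name}", config_updates=opts)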
