diff --git a/users/zeyer/experiments/exp2024_04_23_baselines/ctc.py b/users/zeyer/experiments/exp2024_04_23_baselines/ctc.py
index afca16694..86af14ded 100644
--- a/users/zeyer/experiments/exp2024_04_23_baselines/ctc.py
+++ b/users/zeyer/experiments/exp2024_04_23_baselines/ctc.py
@@ -2170,6 +2170,13 @@ def __init__(
         # - static_prior is supposed to be static, also with non_critical_for_restore=True.
         _grad_prior_dim = self.target_dim if self.log_prob_normed_grad_exclude_blank else self.wb_target_dim
         self.grad_prior = rf.Parameter([_grad_prior_dim], auxiliary=True, initial=1.0 / _grad_prior_dim.dimension)
+        if self.prior_running_mean_per_layer:
+            for i in enc_aux_logits:
+                setattr(
+                    self,
+                    f"grad_prior_{i}",
+                    rf.Parameter([_grad_prior_dim], auxiliary=True, initial=1.0 / _grad_prior_dim.dimension),
+                )
 
         self.feature_batch_norm = None
         if config.bool("feature_batch_norm", False):
@@ -2311,7 +2318,7 @@ def log_probs_wb_from_logits(self, logits: Tensor, *, aux_layer: Optional[int] =
                 logits, axis=self.wb_target_dim, out_dims=[self.target_dim, dummy_blank_feat_dim]
             )
             log_probs_wo_blank = rf.log_softmax(logits_wo_blank, axis=self.target_dim)
-            log_probs_wo_blank = self._maybe_apply_on_log_probs(log_probs_wo_blank)
+            log_probs_wo_blank = self._maybe_apply_on_log_probs(log_probs_wo_blank, aux_layer=aux_layer)
             if self.blank_logit_shift:
                 logits_blank += self.blank_logit_shift
             log_probs_blank = rf.log_sigmoid(logits_blank)
@@ -2338,7 +2345,7 @@ def _update_running_stats():
 
             rf.cond(rf.get_run_ctx().train_flag, _update_running_stats, lambda: None)
 
-        log_probs = self._maybe_apply_on_log_probs(log_probs)
+        log_probs = self._maybe_apply_on_log_probs(log_probs, aux_layer=aux_layer)
         if self.ctc_am_scale == 1 and self.ctc_prior_scale == 0:  # fast path
             return log_probs
         log_probs_am = log_probs
@@ -2387,16 +2394,17 @@ def _update_running_stats():
 
         log_probs -= log_prob_prior * self.ctc_prior_scale
         return log_probs
 
-    def _maybe_apply_on_log_probs(self, log_probs: Tensor) -> Tensor:
+    def _maybe_apply_on_log_probs(self, log_probs: Tensor, *, aux_layer: Optional[int] = None) -> Tensor:
         """
         :param log_probs: either with blank or without blank
+        :param aux_layer:
         :return: log probs, maybe some smoothing applied (all on gradients so far, not on log probs itself)
         """
         assert log_probs.feature_dim in (self.wb_target_dim, self.target_dim)
         if not self.out_blank_separated:
             assert log_probs.feature_dim == self.wb_target_dim
-        log_probs = self._maybe_apply_log_probs_normed_grad(log_probs)
+        log_probs = self._maybe_apply_log_probs_normed_grad(log_probs, aux_layer=aux_layer)
 
         if self.ctc_label_smoothing_exclude_blank:
             if self.out_blank_separated:
@@ -2412,7 +2420,7 @@ def _maybe_apply_on_log_probs(self, log_probs: Tensor) -> Tensor:
 
         return log_probs
 
-    def _maybe_apply_log_probs_normed_grad(self, log_probs: Tensor) -> Tensor:
+    def _maybe_apply_log_probs_normed_grad(self, log_probs: Tensor, *, aux_layer: Optional[int] = None) -> Tensor:
         if not self.log_prob_normed_grad_opts:
             return log_probs
 
@@ -2440,7 +2448,10 @@ def _maybe_apply_log_probs_normed_grad(self, log_probs: Tensor) -> Tensor:
         assert "prior_running_mean" not in opts  # will be set by us here, and only when needed
         if opts.get("prior_running_mean_momentum") is not None:
             assert self.grad_prior is not None
-            opts["prior_running_mean"] = self.grad_prior.raw_tensor
+            grad_prior = self.grad_prior
+            if self.prior_running_mean_per_layer and aux_layer is not None:
+                grad_prior = getattr(self, f"grad_prior_{aux_layer}")
+            opts["prior_running_mean"] = grad_prior.raw_tensor
 
         assert log_probs.batch_dim_axis is not None and log_probs.feature_dim_axis is not None
         log_probs_ = log_probs.copy_template()
diff --git a/users/zeyer/experiments/exp2024_04_23_baselines/ctc_claix2023.py b/users/zeyer/experiments/exp2024_04_23_baselines/ctc_claix2023.py
index db2150d74..04f40c60f 100644
--- a/users/zeyer/experiments/exp2024_04_23_baselines/ctc_claix2023.py
+++ b/users/zeyer/experiments/exp2024_04_23_baselines/ctc_claix2023.py
@@ -676,6 +676,19 @@ def py():
                 },
             }
         },
+        "-lpNormedGradC05_11P07NExpL1_3": {
+            "prior_running_mean_per_layer": True,
+            "log_prob_normed_grad": {
+                "prior_running_mean_momentum": 0.001,
+                "func": {
+                    "clamp_min": 0.5,
+                    "clamp_max": 1.1,
+                    "scale_type": "inv_num_labels",
+                    "prior_exp": 0.7,
+                    "prior_renorm": True,
+                },
+            },
+        },
     }.items():
         ctc_train_exp(
             f"n12-spm10k-auxAED-b150k{name}",
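For reference, a minimal, self-contained sketch of the idea the patch implements: one running-mean ("gradient prior") per auxiliary output layer, updated with a small momentum and selected by aux_layer, analogous to the new grad_prior_{i} parameters and the prior_running_mean_momentum option above. This is illustrative only, not code from this repository; the class and method names (PerLayerGradPrior, update, get) and the NumPy-based update are assumptions.

# Hypothetical sketch only -- not the RETURNN/i6 code above. It keeps one
# exponential-moving-average prior per aux layer (cf. grad_prior_{i}) and picks
# the right one by aux_layer, with momentum like "prior_running_mean_momentum".
import numpy as np


class PerLayerGradPrior:  # name made up for illustration
    def __init__(self, num_labels, aux_layers=(4, 8), momentum=0.001):
        self.momentum = momentum
        # Uniform init, analogous to rf.Parameter(..., initial=1.0 / dim) in the diff.
        self.priors = {layer: np.full(num_labels, 1.0 / num_labels) for layer in (*aux_layers, None)}

    def update(self, probs, *, aux_layer=None):
        # EMA update from a (frames, labels) batch of label probabilities.
        prior = self.priors[aux_layer]
        prior += self.momentum * (probs.mean(axis=0) - prior)

    def get(self, *, aux_layer=None):
        # None selects the prior for the final (non-auxiliary) output.
        return self.priors[aux_layer]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    tracker = PerLayerGradPrior(num_labels=5)
    tracker.update(rng.dirichlet(np.ones(5), size=100), aux_layer=4)
    print(tracker.get(aux_layer=4))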