From d9447de3bcda4cf8de369a601ee822c00c596a59 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Fri, 3 Jan 2025 02:30:05 +0100
Subject: [PATCH] grad prior running mean

---
 .../exp2024_04_23_baselines/ctc.py           | 22 +++++++++++++++++--
 .../exp2024_04_23_baselines/ctc_claix2023.py | 12 ++++++++++
 2 files changed, 32 insertions(+), 2 deletions(-)

diff --git a/users/zeyer/experiments/exp2024_04_23_baselines/ctc.py b/users/zeyer/experiments/exp2024_04_23_baselines/ctc.py
index 7c8a86b70..6ddcbf750 100644
--- a/users/zeyer/experiments/exp2024_04_23_baselines/ctc.py
+++ b/users/zeyer/experiments/exp2024_04_23_baselines/ctc.py
@@ -2097,11 +2097,13 @@ def __init__(
         self.ctc_prior_type = config.value("ctc_prior_type", "batch")
 
         static_prior = config.typed_value("static_prior")
-        self.static_prior = None
+        self.static_prior = None  # in log prob, if set
         if static_prior:
             assert isinstance(static_prior, dict)
             assert set(static_prior.keys()) == {"file", "type"}
             v = numpy.loadtxt(static_prior["file"])
+            # The `type` is about what is stored in the file.
+            # We always store it in log prob here, so we potentially need to convert it.
             if static_prior["type"] == "log_prob":
                 pass  # already log prob
             elif static_prior["type"] == "prob":
@@ -2142,6 +2144,17 @@ def __init__(
         self.log_prob_normed_grad_exclude_blank = config.bool(
             "log_prob_normed_grad_exclude_blank", self.out_blank_separated
         )
+        self.grad_prior = None
+        if (
+            self.log_prob_normed_grad_opts
+            and self.log_prob_normed_grad_opts.get("prior_running_mean_momentum") is not None
+        ):
+            # Note: We might want to use the static_prior for this purpose here.
+            # However, there are some differences, and it would probably just cause confusion and potential bugs.
+            # - static_prior is in log space, but here we want std prob space.
+            # - static_prior is supposed to be static, also with non_critical_for_restore=True.
+            _grad_prior_dim = self.target_dim if self.log_prob_normed_grad_exclude_blank else self.wb_target_dim
+            self.grad_prior = rf.Parameter([_grad_prior_dim], auxiliary=True, initial=1.0 / _grad_prior_dim.dimension)
 
         self.feature_batch_norm = None
         if config.bool("feature_batch_norm", False):
@@ -2357,10 +2370,15 @@ def _maybe_apply_log_probs_normed_grad(self, log_probs: Tensor) -> Tensor:
         func_opts = opts.pop("func")
         assert isinstance(func_opts, dict)
         func_opts = func_opts.copy()
-        assert func_opts.get("class", "inv_prior") == "inv_prior"  # only case for now
+        assert func_opts.get("class", "inv_prior") == "inv_prior"  # only case for now: NormedGradientFuncInvPrior
        func_opts.pop("class", None)
         func = NormedGradientFuncInvPrior(**func_opts)
 
+        assert "prior_running_mean" not in opts  # will be set by us here, and only when needed
+        if opts.get("prior_running_mean_momentum") is not None:
+            assert self.grad_prior is not None
+            opts["prior_running_mean"] = self.grad_prior.raw_tensor
+
         assert log_probs.batch_dim_axis is not None and log_probs.feature_dim_axis is not None
         log_probs_ = log_probs.copy_template()
         log_probs_.raw_tensor = normed_gradient(
diff --git a/users/zeyer/experiments/exp2024_04_23_baselines/ctc_claix2023.py b/users/zeyer/experiments/exp2024_04_23_baselines/ctc_claix2023.py
index d9c20eda6..7bbaa0fa7 100644
--- a/users/zeyer/experiments/exp2024_04_23_baselines/ctc_claix2023.py
+++ b/users/zeyer/experiments/exp2024_04_23_baselines/ctc_claix2023.py
@@ -651,6 +651,18 @@ def py():
                 }
             }
         },
+        "-lpNormedGradC05_11P07NExp1_3": {
+            "log_prob_normed_grad": {
+                "prior_running_mean_momentum": 0.001,
+                "func": {
+                    "clamp_min": 0.5,
+                    "clamp_max": 1.1,
+                    "scale_type": "inv_num_labels",
+                    "prior_exp": 0.7,
+                    "prior_renorm": True,
+                },
+            }
+        },
     }.items():
         ctc_train_exp(
             f"n12-spm10k-auxAED-b150k{name}",
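
Note (not part of the patch itself): the actual update of the running-mean prior happens
inside normed_gradient / NormedGradientFuncInvPrior, which this patch does not modify; the
patch only allocates the auxiliary grad_prior parameter and passes its raw tensor in as
"prior_running_mean". As a rough, hypothetical illustration of the intended mechanics, the
standalone PyTorch sketch below shows (1) an exponential moving average of the frame-level
posteriors with a small momentum (0.001 as in the new "-lpNormedGradC05_11P07NExp1_3"
config) and (2) an inverse-prior gradient scale using the clamp_min / clamp_max / prior_exp /
prior_renorm options from that config. The function names here are made up for the sketch
and do not exist in the repo; the real implementation may differ in detail.

    import torch


    def update_prior_running_mean(prior: torch.Tensor, log_probs: torch.Tensor, momentum: float) -> None:
        """EMA update of the prior, kept in std prob space (cf. the comment in the patch)."""
        with torch.no_grad():
            # Mean posterior over batch and time -> shape [num_labels].
            batch_mean = log_probs.exp().mean(dim=(0, 1))
            prior.mul_(1.0 - momentum).add_(batch_mean, alpha=momentum)


    def inv_prior_grad_scale(
        prior: torch.Tensor,
        *,
        clamp_min: float = 0.5,
        clamp_max: float = 1.1,
        prior_exp: float = 0.7,
        prior_renorm: bool = True,
    ) -> torch.Tensor:
        """Per-label factor, inverse to the (smoothed) prior, clamped to a safe range."""
        p = prior.clamp_min(1e-10) ** prior_exp
        if prior_renorm:
            p = p / p.sum()
        # Assumption: "inv_num_labels" scale type means a uniform prior (1/num_labels)
        # maps to factor 1.0, so we multiply by num_labels before inverting.
        scale = 1.0 / (p * prior.numel())
        return scale.clamp(min=clamp_min, max=clamp_max)


    if __name__ == "__main__":
        num_labels = 5
        prior = torch.full([num_labels], 1.0 / num_labels)  # init as in the patch: 1/dim
        log_probs = torch.log_softmax(torch.randn(2, 7, num_labels), dim=-1)  # [B, T, num_labels]
        update_prior_running_mean(prior, log_probs, momentum=0.001)
        print(inv_prior_grad_scale(prior))  # would rescale d(loss)/d(log_probs) per label

With such a small momentum, the prior reacts slowly to batch-to-batch fluctuations, which is
presumably why it is stored as an auxiliary parameter (so it survives checkpointing) rather
than recomputed per batch.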