grad prior running mean
albertz committed Jan 3, 2025
1 parent d6078fb commit d9447de
Showing 2 changed files with 32 additions and 2 deletions.
22 changes: 20 additions & 2 deletions users/zeyer/experiments/exp2024_04_23_baselines/ctc.py
@@ -2097,11 +2097,13 @@ def __init__(
self.ctc_prior_type = config.value("ctc_prior_type", "batch")

static_prior = config.typed_value("static_prior")
- self.static_prior = None
+ self.static_prior = None  # in log prob, if set
if static_prior:
assert isinstance(static_prior, dict)
assert set(static_prior.keys()) == {"file", "type"}
v = numpy.loadtxt(static_prior["file"])
+ # The `type` is about what is stored in the file.
+ # We always store it in log prob here, so we potentially need to convert it.
if static_prior["type"] == "log_prob":
pass # already log prob
elif static_prior["type"] == "prob":
@@ -2142,6 +2144,17 @@ def __init__(
self.log_prob_normed_grad_exclude_blank = config.bool(
"log_prob_normed_grad_exclude_blank", self.out_blank_separated
)
+ self.grad_prior = None
+ if (
+     self.log_prob_normed_grad_opts
+     and self.log_prob_normed_grad_opts.get("prior_running_mean_momentum") is not None
+ ):
+     # Note: We might want to use the static_prior for this purpose here.
+     # However, there are some differences, and it would probably just cause confusion and potential bugs.
+     # - static_prior is in log space, but here we want std prob space.
+     # - static_prior is supposed to be static, also with non_critical_for_restore=True.
+     _grad_prior_dim = self.target_dim if self.log_prob_normed_grad_exclude_blank else self.wb_target_dim
+     self.grad_prior = rf.Parameter([_grad_prior_dim], auxiliary=True, initial=1.0 / _grad_prior_dim.dimension)

self.feature_batch_norm = None
if config.bool("feature_batch_norm", False):
@@ -2357,10 +2370,15 @@ def _maybe_apply_log_probs_normed_grad(self, log_probs: Tensor) -> Tensor:
func_opts = opts.pop("func")
assert isinstance(func_opts, dict)
func_opts = func_opts.copy()
- assert func_opts.get("class", "inv_prior") == "inv_prior"  # only case for now
+ assert func_opts.get("class", "inv_prior") == "inv_prior"  # only case for now: NormedGradientFuncInvPrior
func_opts.pop("class", None)
func = NormedGradientFuncInvPrior(**func_opts)

assert "prior_running_mean" not in opts # will be set by us here, and only when needed
if opts.get("prior_running_mean_momentum") is not None:
assert self.grad_prior is not None
opts["prior_running_mean"] = self.grad_prior.raw_tensor

assert log_probs.batch_dim_axis is not None and log_probs.feature_dim_axis is not None
log_probs_ = log_probs.copy_template()
log_probs_.raw_tensor = normed_gradient(
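Background for the hunks above (not part of the commit): the diff only creates the grad_prior parameter and passes its raw tensor as opts["prior_running_mean"]; the running-mean update itself and the inverse-prior gradient scaling live in normed_gradient() / NormedGradientFuncInvPrior, which this commit does not touch. The following is a minimal PyTorch sketch of what an exponential-moving-average prior in standard prob space and an "inv_prior" scaling could look like; the function names, the renormalization step, and the exact use of clamp_min / clamp_max / prior_exp are illustrative assumptions, not the repository's implementation.

import torch

def update_running_mean_prior(prior: torch.Tensor, probs: torch.Tensor, momentum: float = 0.001) -> None:
    # prior: [num_labels], kept in standard prob space (matching the rf.Parameter above,
    # initialized to 1 / num_labels); probs: [frames, num_labels] softmax outputs of the batch.
    with torch.no_grad():
        batch_mean = probs.mean(dim=0)  # average label posterior over the batch
        prior.mul_(1.0 - momentum).add_(momentum * batch_mean)  # exponential moving average

def inv_prior_grad_scale(grad: torch.Tensor, prior: torch.Tensor, *, prior_exp: float = 0.7,
                         clamp_min: float = 0.5, clamp_max: float = 1.1) -> torch.Tensor:
    # Scale per-label gradients by a clamped, exponentiated inverse prior.
    scale = prior.clamp_min(1e-8) ** (-prior_exp)
    scale = scale / scale.mean()  # renormalize so the average scale is 1 (cf. prior_renorm)
    return grad * scale.clamp(clamp_min, clamp_max)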
12 changes: 12 additions & 0 deletions users/zeyer/experiments/exp2024_04_23_baselines/ctc_claix2023.py
@@ -651,6 +651,18 @@ def py():
}
}
},
"-lpNormedGradC05_11P07NExp1_3": {
"log_prob_normed_grad": {
"prior_running_mean_momentum": 0.001,
"func": {
"clamp_min": 0.5,
"clamp_max": 1.1,
"scale_type": "inv_num_labels",
"prior_exp": 0.7,
"prior_renorm": True,
},
}
},
}.items():
ctc_train_exp(
f"n12-spm10k-auxAED-b150k{name}",
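For context, an assumption about the plumbing that this diff does not show: the new experiment entry above is appended to the experiment name and ends up in the RETURNN config, where the model code in ctc.py reads it back, roughly as sketched here.

# Hypothetical reading side in the model __init__ (the corresponding line is not part of this diff):
self.log_prob_normed_grad_opts = config.typed_value("log_prob_normed_grad", None)
# -> {"prior_running_mean_momentum": 0.001,
#     "func": {"clamp_min": 0.5, "clamp_max": 1.1, "scale_type": "inv_num_labels",
#              "prior_exp": 0.7, "prior_renorm": True}}
# With prior_running_mean_momentum set, __init__ creates self.grad_prior (first file above),
# and _maybe_apply_log_probs_normed_grad passes it to normed_gradient as opts["prior_running_mean"].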
