diff --git a/users/mueller/experiments/ctc_baseline/ctc.py b/users/mueller/experiments/ctc_baseline/ctc.py index 598200dd7..b7ea2eda1 100644 --- a/users/mueller/experiments/ctc_baseline/ctc.py +++ b/users/mueller/experiments/ctc_baseline/ctc.py @@ -71,6 +71,7 @@ def py(): blank_prior = True prior_gradient = False LM_order = 2 + top_k = 1 self_train_subset = 18000 if train_small: @@ -149,13 +150,7 @@ def py(): } if self_training_rounds > 0 else None for am, lm, prior in [ - (0.5, 0.5, 0.5), - # (0.5, 0.3, 0.5), - # (0.5, 0.2, 0.5), - # (0.5, 0.1, 0.5), - # (0.5, 0.05, 0.5), - # (0.5, 0.0, 0.5), - # (0.3, 0.2, 0.5), + (1.0, 0.0, 0.2) ]: if use_sum_criterion: training_scales = { @@ -170,6 +165,7 @@ def py(): sum_str = f"-full_sum" + \ (f"_p{str(training_scales['prior']).replace('.', '')}_l{str(training_scales['lm']).replace('.', '')}_a{str(training_scales['am']).replace('.', '')}" if training_scales else "") + \ (f"_LMorder{LM_order}" if LM_order > 2 else "") + \ + (f"_topK{top_k}" if top_k > 0 else "") + \ ("_wo_hor_pr" if not horizontal_prior else "") + \ ("_wo_blank_pr" if not blank_prior else "") + \ ("_wo_pr_grad" if not prior_gradient else "") @@ -201,6 +197,7 @@ def py(): blank_prior=blank_prior, prior_gradient=prior_gradient, LM_order=LM_order, + top_k=top_k, training_scales=training_scales if use_sum_criterion else None, self_train_subset=self_train_subset, ) @@ -242,6 +239,7 @@ def train_exp( blank_prior: bool = True, prior_gradient: bool = True, LM_order: int = 2, + top_k: int = 0, training_scales: Optional[Dict[str, float]] = None, self_train_subset: Optional[int] = None, ) -> Optional[ModelWithCheckpoints]: @@ -353,6 +351,8 @@ def train_exp( config_self["prior_scale"] = training_scales["prior"] if not prior_gradient: config_self["prior_gradient"] = prior_gradient + if top_k > 0: + config_self["top_k"] = top_k # When testing on a smaller subset we only want one gpu if self_train_subset is not None: @@ -966,12 +966,17 @@ def _calc_log_prior(log_probs: torch.Tensor, lengths: torch.Tensor) -> torch.Ten horizontal_prior = config.bool("horizontal_prior", True) blank_prior = config.bool("blank_prior", True) prior_gradient = config.bool("prior_gradient", True) + top_k = config.int("top_k", 0) use_prior = prior_scale > 0.0 if data.feature_dim and data.feature_dim.dimension == 1: data = rf.squeeze(data, axis=data.feature_dim) assert not data.feature_dim # raw audio + if am_scale == 0.7: + print("Data", data) + print("Batch", data.batch) + with uopen(lm_path, "rb") as f: lm = torch.load(f, map_location=data.device) @@ -1001,6 +1006,7 @@ def _calc_log_prior(log_probs: torch.Tensor, lengths: torch.Tensor) -> torch.Ten log_lm_probs=lm, log_prior=aux_log_prior, input_lengths=enc_spatial_dim.dyn_size_ext.raw_tensor, + top_k=top_k, LM_order=lm_order, am_scale=am_scale, lm_scale=lm_scale, @@ -1035,6 +1041,7 @@ def _calc_log_prior(log_probs: torch.Tensor, lengths: torch.Tensor) -> torch.Ten log_lm_probs=lm, log_prior=log_prior, input_lengths=enc_spatial_dim.dyn_size_ext.raw_tensor, + top_k=top_k, LM_order=lm_order, am_scale=am_scale, lm_scale=lm_scale, diff --git a/users/mueller/experiments/ctc_baseline/sum_criterion.py b/users/mueller/experiments/ctc_baseline/sum_criterion.py index 1c36cead8..820841c23 100644 --- a/users/mueller/experiments/ctc_baseline/sum_criterion.py +++ b/users/mueller/experiments/ctc_baseline/sum_criterion.py @@ -23,6 +23,7 @@ def sum_loss( unk_idx: int = 1, log_zero: float = float("-inf"), device: torch.device = torch.device("cpu"), + print_best_path_for_idx: list[int] = [], ): """ Sum criterion training for CTC, given by @@ -83,7 +84,10 @@ def sum_loss( max_audio_time, batch_size, n_out = log_probs.shape # scaled log am and lm probs log_probs = am_scale * log_probs - log_lm_probs = lm_scale * log_lm_probs + if lm_scale == 0.0: + log_lm_probs = torch.zeros_like(log_lm_probs) + else: + log_lm_probs = lm_scale * log_lm_probs if use_prior: log_prior = prior_scale * log_prior @@ -140,6 +144,12 @@ def sum_loss( log_q = log_q_label if top_k > 0: topk_scores, topk_idx = torch.topk(log_q, top_k, dim=-1, sorted=False) + if print_best_path_for_idx: + with torch.no_grad(): + best_path_print = {} + max_val, max_idx = torch.max(log_q, dim=-1) + for idx in print_best_path_for_idx: + best_path_print[idx] = {"str": f"{max_idx[idx] + 2}", "score": "{:.2f}".format(max_val[idx].tolist()), "AM": log_probs[0][idx].tolist()[max_idx[idx] + 2]} log_lm_probs_wo_eos = log_lm_probs[out_idx_vocab][:, out_idx_vocab].fill_diagonal_(log_zero) for t in range(1, max_audio_time): @@ -201,6 +211,13 @@ def sum_loss( if top_k > 0: topk_scores, topk_idx = torch.topk(log_q, top_k, dim=-1, sorted=False) + if print_best_path_for_idx: + with torch.no_grad(): + max_val, max_idx = torch.max(log_q, dim=-1) + for idx in print_best_path_for_idx: + best_path_print[idx]["str"] += f" {max_idx[idx] + 2}" + best_path_print[idx]["score"] += " {:.2f}".format(max_val[idx].tolist()) # / (t+1) + best_path_print[idx]["AM"] += log_probs[t][idx].tolist()[max_idx[idx] + 2] torch.cuda.empty_cache() @@ -209,6 +226,10 @@ def sum_loss( log_q = topk_scores + log_lm_probs[out_idx_vocab, eos_symbol].unsqueeze(0).expand(batch_size, -1).gather(-1, topk_idx) else: log_q = log_q + log_lm_probs[out_idx_vocab, eos_symbol].unsqueeze(0) + if print_best_path_for_idx: + with torch.no_grad(): + for idx in print_best_path_for_idx: + print(f"Best path for {idx}: {best_path_print[idx]['str']}\nScore: {best_path_print[idx]['score']}\nAM: {-best_path_print[idx]['AM']}") # sum over the vocab dimension sum_score = safe_logsumexp(log_q, dim=-1) @@ -681,45 +702,63 @@ def test(): # prior = _calc_log_prior(am, length) # am = am.permute(1, 0, 2) - am = ag(am, "AM", False) - prior = ag(prior, "prior", False) + # am = ag(am, "AM", False) + # prior = ag(prior, "prior", False) - # loss = sum_loss( - # log_probs=am, - # log_lm_probs=lm, - # log_prior=prior, - # input_lengths=length, - # LM_order=2, - # am_scale=1.0, - # lm_scale=1.9, - # prior_scale=0.2, - # horizontal_prior=True, - # blank_idx=184, - # eos_idx=0, - # ) - loss = sum_loss_k( + loss = sum_loss( log_probs=am, log_lm_probs=lm, log_prior=prior, input_lengths=length, - top_k=1, + top_k = 1, LM_order=2, am_scale=1.0, - lm_scale=1.9, - prior_scale=0.2, + lm_scale=0.0, + prior_scale=0.0, horizontal_prior=True, + blank_prior=True, blank_idx=184, eos_idx=0, + print_best_path_for_idx=[0] ) + print("OUT", loss[0].tolist()) l += (loss / frames).mean() - del loss, am, prior - torch.cuda.empty_cache() - print(time.time() - s) - l.backward(torch.ones_like(l, device=device)) + # del loss, am, prior + # torch.cuda.empty_cache() + # print(time.time() - s) + + # targets = torch.tensor([55, 148, 178, 108, 179, 126, 110, 103, 9, 154, 84, 162, 159, 83, 153, 33, 106, 9, 131, 46, 63, 15, 162, 94, 0, 111, 121, 29, 121, 21, 151, 18, 4, 159, 118, 86, 129, 18, 13, 170, 151, 81, 77, 53, 165, 57, 134, 63, 103, 110, 47, 35, 145, 18, 34, 66, 42, 96, 139, 16, 138, 156, 1, 63, 103, 95, 149, 111, 83, 34, 113, 158, 39, 166, 34, 123, 26, 148, 134, 148, 168, 177, 18, 23, 164, 69, 145, 93, 166, 174, 162, 36, 95, 116, 123, 74, 124, 70]) + # targets = targets + 2 + targets = torch.tensor( + [ 57, 150, 180, 110, 107, 128, 112, 105, 11, 156, 86, 164, 161, 85, + 155, 35, 108, 11, 133, 48, 133, 17, 164, 96, 2, 113, 123, 31, + 123, 23, 153, 20, 6, 161, 120, 88, 131, 20, 15, 99, 153, 58, + 119, 1, 88, 59, 136, 65, 105, 99, 122, 37, 147, 20, 36, 68, + 44, 98, 141, 18, 1, 158, 3, 65, 105, 97, 151, 113, 85, 36, + 115, 160, 83, 168, 36, 125, 28, 150, 136, 90, 170, 179, + 20, 25, 166, 71, 147, 95, 168, 176, 164, 38, 97, 118, 125, 76, + 43, 72] + ) + # greedy_probs, greedy_idx = torch.max(am[:, 0:1], dim=-1) + # print(greedy_idx.squeeze(-1)) + targets = targets.unsqueeze(0) + target_lengths = torch.tensor([targets.size(1)]) + ctc_loss = torch.nn.functional.ctc_loss( + log_probs=am[:, 0:1], + targets=targets, + input_lengths=length[0:1], + target_lengths=target_lengths, + blank=184, + reduction="none" + ) + print(ctc_loss) + + + # l.backward(torch.ones_like(l, device=device)) e1 = time.time() - print(f"Sum loss took {time.strftime('%H:%M:%S', time.gmtime(e1-s1))}: {l}") # 5:00 mins + # print(f"Sum loss took {time.strftime('%H:%M:%S', time.gmtime(e1-s1))}: {l}") # 5:00 mins # s2 = time.time()