diff --git a/wenet/bin/recognize.py b/wenet/bin/recognize.py
index 991bba515..682fdc078 100644
--- a/wenet/bin/recognize.py
+++ b/wenet/bin/recognize.py
@@ -160,21 +160,37 @@ def get_args():
                         default=0.0,
                         help='lm scale for hlg attention rescore decode')
-    parser.add_argument(
-        '--context_bias_mode',
-        type=str,
-        default='',
-        help='''Context bias mode, selectable from the following
-                option: decoding-graph、deep-biasing''')
+    parser.add_argument('--context_bias_mode',
+                        type=str,
+                        default='',
+                        help='''Context bias mode, selectable from the
+                                following options: context_graph,
+                                deep_biasing''')
     parser.add_argument('--context_list_path',
                         type=str,
                         default='',
                         help='Context list path')
     parser.add_argument('--context_graph_score',
                         type=float,
-                        default=0.0,
+                        default=2.0,
                         help='''The higher the score, the greater the degree of
-                                bias using decoding-graph for biasing''')
+                                bias using context_graph for biasing''')
+    parser.add_argument('--deep_biasing_score',
+                        type=float,
+                        default=1.0,
+                        help='''The higher the score, the greater the degree of
+                                bias using deep_biasing for biasing''')
+    parser.add_argument('--context_filtering',
+                        action='store_true',
+                        help='''Reduce the size of the context list through
+                                filtering to enhance the effect of context
+                                biasing''')
+    parser.add_argument('--context_filtering_threshold',
+                        type=float,
+                        default=-4.0,
+                        help='''The threshold for context filtering: the larger
+                                the value (i.e. the closer to 0), the fewer
+                                context phrases remain after filtering''')

     args = parser.parse_args()
     print(args)
@@ -225,6 +241,9 @@ def main():
         test_conf['batch_conf']['batch_size'] = args.batch_size
     non_lang_syms = read_non_lang_symbols(args.non_lang_syms)

+    if 'context_conf' in test_conf:
+        del test_conf['context_conf']
+
     test_dataset = Dataset(args.data_type,
                            args.test_data,
                            symbol_table,
@@ -256,17 +275,25 @@ def main():
         paraformer_beam_search = None

     context_graph = None
-    if 'decoding-graph' in args.context_bias_mode:
+    if args.context_bias_mode != '':
         context_graph = ContextGraph(args.context_list_path, symbol_table,
                                      args.bpe_model, args.context_graph_score)
+        context_graph.context_filtering = args.context_filtering
+        context_graph.filter_threshold = args.context_filtering_threshold
+        if 'deep_biasing' in args.context_bias_mode:
+            context_graph.deep_biasing = True
+            context_graph.deep_biasing_score = args.deep_biasing_score
+        if 'context_graph' in args.context_bias_mode:
+            context_graph.graph_biasing = True

     with torch.no_grad(), open(args.result_file, 'w') as fout:
         for batch_idx, batch in enumerate(test_data_loader):
-            keys, feats, target, feats_lengths, target_lengths = batch
+            keys, feats, target, feats_lengths, target_lengths, _ = batch
             feats = feats.to(device)
             target = target.to(device)
             feats_lengths = feats_lengths.to(device)
             target_lengths = target_lengths.to(device)
+
             if args.mode == 'attention':
                 hyps, _ = model.recognize(
                     feats,
@@ -274,7 +301,7 @@ def main():
                     beam_size=args.beam_size,
                     decoding_chunk_size=args.decoding_chunk_size,
                     num_decoding_left_chunks=args.num_decoding_left_chunks,
-                    simulate_streaming=args.simulate_streaming)
+                    simulate_streaming=args.simulate_streaming,)
                 hyps = [hyp.tolist() for hyp in hyps]
             elif args.mode == 'ctc_greedy_search':
                 hyps, _ = model.ctc_greedy_search(
                     feats,
@@ -282,7 +309,8 @@ def main():
                     feats_lengths,
                     decoding_chunk_size=args.decoding_chunk_size,
                     num_decoding_left_chunks=args.num_decoding_left_chunks,
-                    simulate_streaming=args.simulate_streaming)
+                    simulate_streaming=args.simulate_streaming,
+                    context_graph=context_graph)
             elif args.mode == 'rnnt_greedy_search':
                 assert (feats.size(0) == 1)
                 assert 'predictor' in configs
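For reference, a minimal sketch of how the new recognize.py flags translate into a ContextGraph configured for both biasing modes (this mirrors the code above; it assumes a loaded `model`, `feats`, `feats_lengths` and a `symbol_table` dict, and the context list file name is hypothetical):

    from wenet.utils.context_graph import ContextGraph

    context_graph = ContextGraph('context_list.txt', symbol_table,
                                 bpe_model=None, context_graph_score=2.0)
    context_graph.context_filtering = True    # --context_filtering
    context_graph.filter_threshold = -4.0     # --context_filtering_threshold
    context_graph.deep_biasing = True         # 'deep_biasing' in --context_bias_mode
    context_graph.deep_biasing_score = 1.0    # --deep_biasing_score
    context_graph.graph_biasing = True        # 'context_graph' in --context_bias_mode
    hyps, _ = model.ctc_greedy_search(feats, feats_lengths,
                                      context_graph=context_graph)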
diff --git a/wenet/bin/train.py b/wenet/bin/train.py
index da9a6f6bb..eba63dfb3 100644
--- a/wenet/bin/train.py
+++ b/wenet/bin/train.py
@@ -276,6 +276,17 @@ def main():
     num_params = sum(p.numel() for p in model.parameters())
     print('the number of model params: {:,d}'.format(num_params)) if local_rank == 0 else None  # noqa
+    if 'context_module_conf' in configs:
+        # Freeze the other parts of the model while training the context bias module
+        for p in model.parameters():
+            p.requires_grad = False
+        for p in model.context_module.parameters():
+            p.requires_grad = True
+        for p in model.context_module.context_decoder_ctc_linear.parameters():
+            p.requires_grad = False
+        # Turn off dynamic chunk because it will affect the training of the bias module
+        model.encoder.use_dynamic_chunk = False
+
     # !!!IMPORTANT!!!
     # Try to export the model by script, if fails, we should refine
     # the code to satisfy the script export requirements
diff --git a/wenet/dataset/dataset.py b/wenet/dataset/dataset.py
index 6d799b5b5..5280a2578 100644
--- a/wenet/dataset/dataset.py
+++ b/wenet/dataset/dataset.py
@@ -189,5 +189,11 @@ def Dataset(data_type,
     batch_conf = conf.get('batch_conf', {})
     dataset = Processor(dataset, processor.batch, **batch_conf)
+
+    context_conf = conf.get('context_conf', {})
+    if len(context_conf) != 0:
+        dataset = Processor(dataset, processor.context_sampling,
+                            symbol_table, **context_conf)
+
     dataset = Processor(dataset, processor.padding)
     return dataset
diff --git a/wenet/dataset/processor.py b/wenet/dataset/processor.py
index b69ceca85..c37ac7c10 100644
--- a/wenet/dataset/processor.py
+++ b/wenet/dataset/processor.py
@@ -610,6 +610,99 @@ def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000):
         logging.fatal('Unsupported batch type {}'.format(batch_type))


+def context_sampling(data,
+                     symbol_table,
+                     len_min,
+                     len_max,
+                     utt_num_context,
+                     batch_num_context,
+                     ):
+    """ Perform context sampling by randomly selecting context phrases from the
+        utterance to obtain a context list for the entire batch
+
+        Args:
+            data: Iterable[List[{key, feat, label}]]
+
+        Returns:
+            Iterable[List[{key, feat, label, context_list}]]
+    """
+    rev_symbol_table = {}
+    for token in symbol_table:
+        rev_symbol_table[symbol_table[token]] = token
+    context_list_over_all = []
+    for sample in data:
+        batch_label = [sample[i]['label'] for i in range(len(sample))]
+        context_list = []
+        for utt_label in batch_label:
+            st_index_list = []
+            for i in range(len(utt_label)):
+                if '▁' not in symbol_table:
+                    st_index_list.append(i)
+                elif rev_symbol_table[utt_label[i]][0] == '▁':
+                    st_index_list.append(i)
+            st_index_list.append(len(utt_label))
+
+            st_select = []
+            en_select = []
+            for _ in range(0, utt_num_context):
+                random_len = random.randint(min(len(st_index_list) - 1, len_min),
+                                            min(len(st_index_list) - 1, len_max))
+                random_index = random.randint(0, len(st_index_list) -
+                                              random_len - 1)
+                st_index = st_index_list[random_index]
+                en_index = st_index_list[random_index + random_len]
+                context_label = utt_label[st_index: en_index]
+                cross_flag = True
+                for i in range(len(st_select)):
+                    if st_index >= st_select[i] and st_index < en_select[i]:
+                        cross_flag = False
+                    elif en_index > st_select[i] and en_index <= en_select[i]:
+                        cross_flag = False
+                    elif st_index < st_select[i] and en_index > en_select[i]:
+                        cross_flag = False
+                if cross_flag:
+                    context_list.append(context_label)
+                    st_select.append(st_index)
+                    en_select.append(en_index)
+
+        if len(context_list) > batch_num_context:
+            context_list_over_all = context_list
+        elif len(context_list) + len(context_list_over_all) > batch_num_context:
+            context_list_over_all.extend(context_list)
+            context_list_over_all = context_list_over_all[-batch_num_context:]
+        else:
+            context_list_over_all.extend(context_list)
+        context_list = context_list_over_all.copy()
+        context_list.insert(0, [0])
+        for i in range(len(context_list)):
+            context_list[i] = torch.tensor(context_list[i], dtype=torch.int32)
+        sample[0]['context_list'] = context_list
+        yield sample
+
+
+def context_label_generate(label, context_list):
+    """ Generate context labels corresponding to the utterances based on
+        the context list
+    """
+    context_labels = []
+    for x in label:
+        cur_len = len(x)
+        context_label = []
+        count = 0
+        for i in range(cur_len):
+            for j in range(1, len(context_list)):
+                if i + len(context_list[j]) > cur_len:
+                    continue
+                if x[i:i + len(context_list[j])].equal(context_list[j]):
+                    count = max(count, len(context_list[j]))
+            if count > 0:
+                context_label.append(x[i])
+                count -= 1
+        context_label = torch.tensor(context_label, dtype=torch.int64)
+        context_labels.append(context_label)
+    return context_labels
+
+
 def padding(data):
     """ Padding the data into training data

@@ -641,5 +734,26 @@ def padding(data):
                                        batch_first=True,
                                        padding_value=-1)

-        yield (sorted_keys, padded_feats, padding_labels, feats_lengths,
-               label_lengths)
+        if 'context_list' not in sample[0]:
+            yield (sorted_keys, padded_feats, padding_labels, feats_lengths,
+                   label_lengths, [])
+        else:
+            context_lists = sample[0]['context_list']
+            context_list_lengths = \
+                torch.tensor([x.size(0) for x in context_lists], dtype=torch.int32)
+            padding_context_lists = pad_sequence(context_lists,
+                                                 batch_first=True,
+                                                 padding_value=-1)
+
+            sorted_context_labels = context_label_generate(sorted_labels,
+                                                           context_lists)
+            context_label_lengths = \
+                torch.tensor([x.size(0) for x in sorted_context_labels],
+                             dtype=torch.int32)
+            padding_context_labels = pad_sequence(sorted_context_labels,
+                                                  batch_first=True,
+                                                  padding_value=-1)
+            yield (sorted_keys, padded_feats, padding_labels,
+                   feats_lengths, label_lengths,
+                   [padding_context_lists, padding_context_labels,
+                    context_list_lengths, context_label_lengths])
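The sampling knobs used by context_sampling above come from the dataset config. A minimal sketch of the corresponding `context_conf` block (the values shown are illustrative, not recipe defaults):

    # Passed by Dataset() into processor.context_sampling as keyword arguments
    conf['context_conf'] = dict(
        len_min=2,              # shortest sampled context phrase (words/BPE pieces)
        len_max=8,              # longest sampled context phrase
        utt_num_context=2,      # phrases sampled per utterance
        batch_num_context=100,  # phrases kept for the whole batch
    )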
diff --git a/wenet/transducer/transducer.py b/wenet/transducer/transducer.py
index cf0187b0f..364ef9fdd 100644
--- a/wenet/transducer/transducer.py
+++ b/wenet/transducer/transducer.py
@@ -18,6 +18,7 @@
 from wenet.transformer.ctc import CTC
 from wenet.transformer.decoder import BiTransformerDecoder, TransformerDecoder
 from wenet.transformer.label_smoothing_loss import LabelSmoothingLoss
+from wenet.transformer.context_module import ContextModule
 from wenet.utils.common import (IGNORE_ID, add_blank, add_sos_eos,
                                 reverse_pad_list)
@@ -35,6 +36,8 @@ def __init__(
                                BiTransformerDecoder]] = None,
         ctc: Optional[CTC] = None,
         ctc_weight: float = 0,
+        context_module: ContextModule = None,
+        bias_weight: float = 0,
         ignore_id: int = IGNORE_ID,
         reverse_weight: float = 0.0,
         lsm_weight: float = 0.0,
@@ -50,8 +53,8 @@ def __init__(
     ) -> None:
         assert attention_weight + ctc_weight + transducer_weight == 1.0
         super().__init__(vocab_size, encoder, attention_decoder, ctc,
-                         ctc_weight, ignore_id, reverse_weight, lsm_weight,
-                         length_normalized_loss)
+                         context_module, ctc_weight, ignore_id, reverse_weight,
+                         lsm_weight, length_normalized_loss, '', bias_weight)

         self.blank = blank
         self.transducer_weight = transducer_weight
@@ -93,6 +96,7 @@ def forward(
         speech: torch.Tensor,
         speech_lengths: torch.Tensor,
         text: torch.Tensor,
         text_lengths: torch.Tensor,
         steps: int = 0,
+        context_data: List[torch.Tensor] = None,
     ) -> Dict[str, Optional[torch.Tensor]]:
         """Frontend + Encoder + predictor + joint + loss

@@ -101,6 +105,8 @@ def forward(
             speech_lengths: (Batch, )
             text: (Batch, Length)
             text_lengths: (Batch,)
+            context_data: [context_list, context_label,
+                           context_list_lengths, context_label_lengths]
         """
         assert text_lengths.dim() == 1, text_lengths.shape
         # Check that batch_size is unified
@@ -112,6 +118,27 @@ def forward(
         encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
         encoder_out_lens = encoder_mask.squeeze(1).sum(1)

+        # Context biasing branch
+        loss_bias: Optional[torch.Tensor] = None
+        if self.context_module is not None:
+            assert len(context_data) == 4
+            context_list = context_data[0]
+            context_label = context_data[1]
+            context_list_lengths = context_data[2]
+            context_label_lengths = context_data[3]
+            context_emb = self.context_module. \
+                forward_context_emb(context_list, context_list_lengths)
+            encoder_out, bias_out = self.context_module(context_emb,
+                                                        encoder_out)
+            bias_out = bias_out.transpose(0, 1).log_softmax(2)
+            loss_bias = self.context_module.bias_loss(bias_out,
+                                                      context_label,
+                                                      encoder_out_lens,
+                                                      context_label_lengths
+                                                      ) / bias_out.size(1)
+        else:
+            loss_bias = None
+
         # compute_loss
         loss_rnnt = compute_loss(
             self,
@@ -143,12 +170,15 @@ def forward(
             loss = loss + self.ctc_weight * loss_ctc.sum()
         if loss_att is not None:
             loss = loss + self.attention_decoder_weight * loss_att.sum()
+        if loss_bias is not None:
+            loss = loss + self.bias_weight * loss_bias.sum()

         # NOTE: 'loss' must be in dict
         return {
             'loss': loss,
             'loss_att': loss_att,
             'loss_ctc': loss_ctc,
             'loss_rnnt': loss_rnnt,
+            'loss_bias': loss_bias,
         }

     def init_bs(self):
diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py
index f593d067b..a651f591a 100644
--- a/wenet/transformer/asr_model.py
+++ b/wenet/transformer/asr_model.py
@@ -35,6 +35,7 @@
 from wenet.transformer.decoder import TransformerDecoder
 from wenet.transformer.encoder import TransformerEncoder
 from wenet.transformer.label_smoothing_loss import LabelSmoothingLoss
+from wenet.transformer.context_module import ContextModule
 from wenet.utils.common import (IGNORE_ID, add_sos_eos, log_add,
                                 remove_duplicates_and_blank, th_accuracy,
                                 reverse_pad_list)
@@ -51,12 +52,14 @@ def __init__(
         encoder: TransformerEncoder,
         decoder: TransformerDecoder,
         ctc: CTC,
+        context_module: ContextModule = None,
         ctc_weight: float = 0.5,
         ignore_id: int = IGNORE_ID,
         reverse_weight: float = 0.0,
         lsm_weight: float = 0.0,
         length_normalized_loss: bool = False,
         lfmmi_dir: str = '',
+        bias_weight: float = 0.03,
     ):
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight

@@ -68,10 +71,12 @@ def __init__(
         self.ignore_id = ignore_id
         self.ctc_weight = ctc_weight
         self.reverse_weight = reverse_weight
+        self.bias_weight = bias_weight

         self.encoder = encoder
         self.decoder = decoder
         self.ctc = ctc
+        self.context_module = context_module
         self.criterion_att = LabelSmoothingLoss(
             size=vocab_size,
             padding_idx=ignore_id,
@@ -88,6 +93,7 @@ def forward(
         speech_lengths: torch.Tensor,
         text: torch.Tensor,
         text_lengths: torch.Tensor,
+        context_data: List[torch.Tensor],
     ) -> Dict[str, Optional[torch.Tensor]]:
         """Frontend + Encoder + Decoder + Calc loss

@@ -96,6 +102,8 @@ def forward(
             speech_lengths: (Batch, )
             text: (Batch, Length)
             text_lengths: (Batch,)
+            context_data: [context_list, context_label,
+                           context_list_lengths, context_label_lengths]
         """
         assert text_lengths.dim() == 1, text_lengths.shape

@@ -107,6 +115,26 @@ def forward(
         encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
         encoder_out_lens = encoder_mask.squeeze(1).sum(1)

+        # 1a. Context biasing branch
+        loss_bias: Optional[torch.Tensor] = None
+        if self.context_module is not None:
+            assert len(context_data) == 4
+            context_list = context_data[0]
+            context_label = context_data[1]
+            context_list_lengths = context_data[2]
+            context_label_lengths = context_data[3]
+            context_emb = self.context_module. \
+                forward_context_emb(context_list, context_list_lengths)
+            encoder_out, bias_out = self.context_module(context_emb,
+                                                        encoder_out)
+            bias_out = bias_out.transpose(0, 1).log_softmax(2)
+            loss_bias = self.context_module.bias_loss(bias_out, context_label,
+                                                      encoder_out_lens,
+                                                      context_label_lengths)
+            loss_bias /= bias_out.size(1)
+        else:
+            loss_bias = None
+
         # 2a. Attention-decoder branch
         if self.ctc_weight != 1.0:
             loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
@@ -130,9 +158,13 @@ def forward(
         elif loss_att is None:
             loss = loss_ctc
         else:
-            loss = self.ctc_weight * loss_ctc + (1 -
-                                                 self.ctc_weight) * loss_att
-        return {"loss": loss, "loss_att": loss_att, "loss_ctc": loss_ctc}
+            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * \
+                loss_att
+        if loss is not None and loss_bias is not None:
+            loss = loss + self.bias_weight * loss_bias
+
+        return {"loss": loss, "loss_att": loss_att, "loss_ctc": loss_ctc,
+                "loss_bias": loss_bias}

     def _calc_att_loss(
         self,
@@ -314,6 +346,7 @@ def ctc_greedy_search(
         decoding_chunk_size: int = -1,
         num_decoding_left_chunks: int = -1,
         simulate_streaming: bool = False,
+        context_graph: ContextGraph = None,
     ) -> List[List[int]]:
         """ Apply CTC greedy search

@@ -339,6 +372,11 @@ def ctc_greedy_search(
             speech, speech_lengths, decoding_chunk_size,
             num_decoding_left_chunks,
             simulate_streaming)  # (B, maxlen, encoder_dim)
+
+        if context_graph is not None and context_graph.deep_biasing:
+            encoder_out = context_graph.forward_deep_biasing(
+                encoder_out, self.context_module, self.ctc)
+
         maxlen = encoder_out.size(1)
         encoder_out_lens = encoder_mask.squeeze(1).sum(1)
         ctc_probs = self.ctc.log_softmax(
@@ -392,6 +430,11 @@ def _ctc_prefix_beam_search(
             speech, speech_lengths, decoding_chunk_size,
             num_decoding_left_chunks,
             simulate_streaming)  # (B, maxlen, encoder_dim)
+
+        if context_graph is not None and context_graph.deep_biasing:
+            encoder_out = context_graph.forward_deep_biasing(
+                encoder_out, self.context_module, self.ctc)
+
         maxlen = encoder_out.size(1)
         ctc_probs = self.ctc.log_softmax(
             encoder_out)  # (1, maxlen, vocab_size)
@@ -426,7 +469,7 @@ def _ctc_prefix_beam_search(
                         n_prefix = prefix + (s, )
                         n_pb, n_pnb, _, _ = next_hyps[n_prefix]
                         new_c_state, new_c_score = 0, 0
-                        if context_graph is not None:
+                        if context_graph is not None and context_graph.graph_biasing:
                             new_c_state, new_c_score = context_graph. \
                                 find_next_state(c_state, s)
                         n_pnb = log_add([n_pnb, pb + ps])
@@ -436,7 +479,7 @@ def _ctc_prefix_beam_search(
                         n_prefix = prefix + (s, )
                         n_pb, n_pnb, _, _ = next_hyps[n_prefix]
                         new_c_state, new_c_score = 0, 0
-                        if context_graph is not None:
+                        if context_graph is not None and context_graph.graph_biasing:
                             new_c_state, new_c_score = context_graph. \
                                find_next_state(c_state, s)
                         n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
diff --git a/wenet/transformer/context_module.py b/wenet/transformer/context_module.py
new file mode 100644
index 000000000..d12d9d511
--- /dev/null
+++ b/wenet/transformer/context_module.py
@@ -0,0 +1,144 @@
+# Copyright (c) 2023 ASLP@NWPU (authors: Kaixun Huang)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+
+import torch
+import torch.nn as nn
+from typing import Tuple
+from wenet.transformer.attention import MultiHeadedAttention
+
+
+class BLSTM(torch.nn.Module):
+    """ Context encoder, encoding unequal-length context phrases
+        into equal-length embedding representations.
+    """
+
+    def __init__(self,
+                 vocab_size,
+                 embedding_size,
+                 num_layers,
+                 dropout=0.0):
+        super(BLSTM, self).__init__()
+        self.vocab_size = vocab_size
+        self.embedding_size = embedding_size
+        self.word_embedding = torch.nn.Embedding(
+            self.vocab_size, self.embedding_size)
+
+        self.sen_rnn = torch.nn.LSTM(input_size=self.embedding_size,
+                                     hidden_size=self.embedding_size,
+                                     num_layers=num_layers,
+                                     dropout=dropout,
+                                     batch_first=True,
+                                     bidirectional=True)
+
+    def forward(self, sen_batch, sen_lengths):
+        sen_batch = torch.clamp(sen_batch, 0)
+        sen_batch = self.word_embedding(sen_batch)
+        pack_seq = torch.nn.utils.rnn.pack_padded_sequence(
+            sen_batch, sen_lengths.to('cpu').type(torch.int32),
+            batch_first=True, enforce_sorted=False)
+        _, last_state = self.sen_rnn(pack_seq)
+        laste_h = last_state[0]
+        laste_c = last_state[1]
+        state = torch.cat([laste_h[-1, :, :], laste_h[-2, :, :],
+                           laste_c[-1, :, :], laste_c[-2, :, :]], dim=-1)
+        return state
+
+
+class ContextModule(torch.nn.Module):
+    """ Context module, using context information for deep contextual biasing.
+
+    During training, the original parameters of the ASR model are frozen,
+    and only the parameters of the context module are trained.
+
+    Args:
+        vocab_size (int): vocabulary size
+        embedding_size (int): number of ASR encoder projection units
+        encoder_layers (int): number of context encoder layers
+        attention_heads (int): number of heads in the biasing layer
+    """
+    def __init__(
+        self,
+        vocab_size: int,
+        embedding_size: int,
+        encoder_layers: int = 2,
+        attention_heads: int = 4,
+        dropout_rate: float = 0.0,
+    ):
+        super().__init__()
+        self.embedding_size = embedding_size
+        self.encoder_layers = encoder_layers
+        self.vocab_size = vocab_size
+        self.attention_heads = attention_heads
+        self.dropout_rate = dropout_rate
+
+        self.context_extractor = BLSTM(self.vocab_size, self.embedding_size,
+                                       self.encoder_layers)
+        self.context_encoder = nn.Sequential(
+            nn.Linear(self.embedding_size * 4, self.embedding_size),
+            nn.LayerNorm(self.embedding_size)
+        )
+
+        self.biasing_layer = MultiHeadedAttention(
+            n_head=self.attention_heads,
+            n_feat=self.embedding_size,
+            dropout_rate=self.dropout_rate
+        )
+
+        self.combiner = nn.Linear(self.embedding_size, self.embedding_size)
+        self.norm_aft_combiner = nn.LayerNorm(self.embedding_size)
+
+        self.context_decoder = nn.Sequential(
+            nn.Linear(self.embedding_size, self.embedding_size),
+            nn.LayerNorm(self.embedding_size),
+            nn.ReLU(inplace=True),
+        )
+        self.context_decoder_ctc_linear = nn.Linear(self.embedding_size,
+                                                    self.vocab_size)
+
+        self.bias_loss = torch.nn.CTCLoss(reduction="sum", zero_infinity=True)
+
+    def forward_context_emb(self,
+                            context_list: torch.Tensor,
+                            context_lengths: torch.Tensor
+                            ) -> torch.Tensor:
+        """ Extracting context embeddings
+        """
+        context_emb = self.context_extractor(context_list, context_lengths)
+        context_emb = self.context_encoder(context_emb.unsqueeze(0))
+        return context_emb
+
+    def forward(self,
+                context_emb: torch.Tensor,
+                encoder_out: torch.Tensor,
+                biasing_score: float = 1.0,
+                recognize: bool = False) \
+            -> Tuple[torch.Tensor, torch.Tensor]:
+        """ Using context embeddings for deep biasing.
+
+        Args:
+            biasing_score (float): degree of context biasing
+            recognize (bool): no context decoder computation if True
+        """
+        context_emb = context_emb.expand(encoder_out.shape[0], -1, -1)
+        context_emb, _ = self.biasing_layer(encoder_out, context_emb,
+                                            context_emb)
+        encoder_bias_out = \
+            self.norm_aft_combiner(encoder_out +
+                                   self.combiner(context_emb) * biasing_score)
+        if recognize:
+            return encoder_bias_out, torch.tensor(0.0)
+        bias_out = self.context_decoder(context_emb)
+        bias_out = self.context_decoder_ctc_linear(bias_out)
+        return encoder_bias_out, bias_out
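A minimal shape sketch of ContextModule used in isolation (dummy tensors, illustrative sizes; `0` is the no-bias entry and `-1` is padding, matching the dataset code earlier in this patch):

    import torch
    from wenet.transformer.context_module import ContextModule

    module = ContextModule(vocab_size=5000, embedding_size=256)
    context_list = torch.tensor([[0, -1, -1], [11, 22, 33]], dtype=torch.int32)
    context_lengths = torch.tensor([1, 3], dtype=torch.int32)
    context_emb = module.forward_context_emb(context_list, context_lengths)
    encoder_out = torch.randn(4, 100, 256)   # (batch, frames, embedding_size)
    biased_out, bias_out = module(context_emb, encoder_out)
    # biased_out: (4, 100, 256) biased encoder output
    # bias_out:   (4, 100, 5000) logits for the auxiliary bias CTC loss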
""" if token in self.graph[now_state]: - return self.graph[now_state][token], self.context_score + return self.graph[now_state][token], self.context_graph_score back_score = self.back_score[now_state] now_state = 0 if token in self.graph[now_state]: - return self.graph[now_state][ - token], back_score + self.context_score + return self.graph[now_state][token], \ + back_score + self.context_graph_score return 0, back_score + + def get_context_list_tensor(self, context_list: List[List[int]]): + """Add 0 as no-bias in the context list and obtain the tensor + form of the context list + """ + context_list_tensor = [torch.tensor([0], dtype=torch.int32)] + for context_token in context_list: + context_list_tensor.append(torch.tensor(context_token, dtype=torch.int32)) + context_list_lengths = torch.tensor([x.size(0) for x in context_list_tensor], + dtype=torch.int32) + context_list_tensor = pad_sequence(context_list_tensor, + batch_first=True, + padding_value=-1) + return context_list_tensor, context_list_lengths + + def forward_deep_biasing(self, + encoder_out: torch.Tensor, + context_module: ContextModule, + ctc: CTC): + """Apply deep biasing based on encoder output and context list + """ + if self.context_filtering: + ctc_probs = ctc.log_softmax(encoder_out).squeeze(0) + filtered_context_list = self.two_stage_filtering( + self.context_list, ctc_probs) + context_list, context_list_lengths = self. \ + get_context_list_tensor(filtered_context_list) + else: + context_list, context_list_lengths = self. \ + get_context_list_tensor(self.context_list) + context_list = context_list.to(encoder_out.device) + context_emb = context_module. \ + forward_context_emb(context_list, context_list_lengths) + encoder_out, _ = \ + context_module(context_emb, encoder_out, + self.deep_biasing_score, True) + return encoder_out + + def two_stage_filtering(self, + context_list: List[List[int]], + ctc_posterior: torch.Tensor, + filter_window_size: int = 64): + """Calculate PSC and SOC for context phrase filtering, + refer to: https://arxiv.org/abs/2301.06735 + """ + if len(context_list) == 0: + return context_list + + SOC_score = {} + for t in range(1, ctc_posterior.shape[0]): + if t % (filter_window_size // 2) != 0 and t != ctc_posterior.shape[0] - 1: + continue + # calculate PSC + PSC_score = {} + max_posterior, _ = torch.max(ctc_posterior[max(0, + t - filter_window_size):t, :], + dim=0, keepdim=False) + max_posterior = max_posterior.tolist() + for i in range(len(context_list)): + score = sum(max_posterior[j] for j in context_list[i]) \ + / len(context_list[i]) + PSC_score[i] = max(SOC_score.get(i, -float('inf')), score) + PSC_filtered_index = [] + for i in PSC_score: + if PSC_score[i] > self.filter_threshold: + PSC_filtered_index.append(i) + if len(PSC_filtered_index) == 0: + continue + filtered_context_list = [] + for i in PSC_filtered_index: + filtered_context_list.append(context_list[i]) + + # calculate SOC + win_posterior = ctc_posterior[max(0, t - filter_window_size):t, :] + win_posterior = win_posterior.unsqueeze(0) \ + .expand(len(filtered_context_list), -1, -1) + select_win_posterior = [] + for i in range(len(filtered_context_list)): + select_win_posterior.append(torch.index_select( + win_posterior[0], 1, + torch.tensor(filtered_context_list[i], + device=ctc_posterior.device)).transpose(0, 1)) + select_win_posterior = \ + pad_sequence(select_win_posterior, + batch_first=True).transpose(1, 2).contiguous() + dp = torch.full((select_win_posterior.shape[0], + select_win_posterior.shape[2]), + -10000.0, 
+            dp = torch.full((select_win_posterior.shape[0],
+                             select_win_posterior.shape[2]),
+                            -10000.0, dtype=torch.float32,
+                            device=select_win_posterior.device)
+            dp[:, 0] = select_win_posterior[:, 0, 0]
+            for win_t in range(1, select_win_posterior.shape[1]):
+                temp = dp[:, :-1] + select_win_posterior[:, win_t, 1:]
+                idx = torch.where(temp > dp[:, 1:])
+                idx_ = (idx[0], idx[1] + 1)
+                dp[idx_] = temp[idx]
+                dp[:, 0] = \
+                    torch.where(select_win_posterior[:, win_t, 0] > dp[:, 0],
+                                select_win_posterior[:, win_t, 0], dp[:, 0])
+            for i in range(len(filtered_context_list)):
+                SOC_score[PSC_filtered_index[i]] = \
+                    max(SOC_score.get(PSC_filtered_index[i], -float('inf')),
+                        dp[i][len(filtered_context_list[i]) - 1]
+                        / len(filtered_context_list[i]))
+
+        filtered_context_list = []
+        for i in range(len(context_list)):
+            if SOC_score.get(i, -float('inf')) > self.filter_threshold:
+                filtered_context_list.append(context_list[i])
+        return filtered_context_list
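At decode time forward_deep_biasing ties these pieces together; a minimal sketch of the filtering step on its own (single-utterance `encoder_out`, assuming a built `context_graph` and a loaded `model`):

    ctc_probs = model.ctc.log_softmax(encoder_out).squeeze(0)   # (frames, vocab)
    kept = context_graph.two_stage_filtering(context_graph.context_list,
                                             ctc_probs)
    context_list, context_list_lengths = \
        context_graph.get_context_list_tensor(kept)
    # only phrases whose best windowed score exceeds filter_threshold survive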
diff --git a/wenet/utils/executor.py b/wenet/utils/executor.py
index a128f6d6c..ad03d1498 100644
--- a/wenet/utils/executor.py
+++ b/wenet/utils/executor.py
@@ -60,11 +60,14 @@ def train(self, model, optimizer, scheduler, data_loader, device, writer,
         num_seen_utts = 0
         with model_context():
             for batch_idx, batch in enumerate(data_loader):
-                key, feats, target, feats_lengths, target_lengths = batch
+                key, feats, target, feats_lengths, target_lengths, \
+                    context_data = batch
                 feats = feats.to(device)
                 target = target.to(device)
                 feats_lengths = feats_lengths.to(device)
                 target_lengths = target_lengths.to(device)
+                for i in range(len(context_data)):
+                    context_data[i] = context_data[i].to(device)
                 num_utts = target_lengths.size(0)
                 if num_utts == 0:
                     continue
@@ -85,7 +88,7 @@ def train(self, model, optimizer, scheduler, data_loader, device, writer,
                         dtype=ds_dtype, cache_enabled=False
                     ):
                         loss_dict = model(feats, feats_lengths, target,
-                                          target_lengths)
+                                          target_lengths, context_data)
                         loss = loss_dict['loss']
                     # NOTE(xcsong): Zeroing the gradients is handled automatically by DeepSpeed after the weights # noqa
                     # have been updated using a mini-batch. DeepSpeed also performs gradient averaging automatically # noqa
@@ -99,7 +102,7 @@ def train(self, model, optimizer, scheduler, data_loader, device, writer,
                     # https://pytorch.org/docs/stable/notes/amp_examples.html
                     with torch.cuda.amp.autocast(scaler is not None):
                         loss_dict = model(feats, feats_lengths, target,
-                                          target_lengths)
+                                          target_lengths, context_data)
                         loss = loss_dict['loss'] / accum_grad
                     if use_amp:
                         scaler.scale(loss).backward()
@@ -170,11 +173,14 @@ def cv(self, model, data_loader, device, args):
         total_loss = 0.0
         with torch.no_grad():
             for batch_idx, batch in enumerate(data_loader):
-                key, feats, target, feats_lengths, target_lengths = batch
+                key, feats, target, feats_lengths, target_lengths, \
+                    context_data = batch
                 feats = feats.to(device)
                 target = target.to(device)
                 feats_lengths = feats_lengths.to(device)
                 target_lengths = target_lengths.to(device)
+                for i in range(len(context_data)):
+                    context_data[i] = context_data[i].to(device)
                 num_utts = target_lengths.size(0)
                 if num_utts == 0:
                     continue
@@ -183,10 +189,11 @@ def cv(self, model, data_loader, device, args):
                         enabled=ds_dtype is not None,
                         dtype=ds_dtype, cache_enabled=False
                     ):
-                        loss_dict = model(feats, feats_lengths,
-                                          target, target_lengths)
+                        loss_dict = model(feats, feats_lengths, target,
+                                          target_lengths, context_data)
                 else:
-                    loss_dict = model(feats, feats_lengths, target, target_lengths)
+                    loss_dict = model(feats, feats_lengths, target,
+                                      target_lengths, context_data)
                 loss = loss_dict['loss']
                 if torch.isfinite(loss):
                     num_seen_utts += num_utts
diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py
index f01e03469..7ffa101e8 100644
--- a/wenet/utils/init_model.py
+++ b/wenet/utils/init_model.py
@@ -22,6 +22,7 @@
 from wenet.transformer.ctc import CTC
 from wenet.transformer.decoder import BiTransformerDecoder, TransformerDecoder
 from wenet.transformer.encoder import ConformerEncoder, TransformerEncoder
+from wenet.transformer.context_module import ContextModule
 from wenet.branchformer.encoder import BranchformerEncoder
 from wenet.squeezeformer.encoder import SqueezeformerEncoder
 from wenet.efficient_conformer.encoder import EfficientConformerEncoder
@@ -79,6 +80,13 @@ def init_model(configs):
                                     **configs['decoder_conf'])

     ctc = CTC(vocab_size, encoder.output_size())
+
+    context_module_type = configs.get('context_module', '')
+    if context_module_type == 'cppn':
+        context_module = ContextModule(vocab_size,
+                                       **configs['context_module_conf'])
+    else:
+        context_module = None
+
     # Init joint CTC/Attention or Transducer model
     if 'predictor' in configs:
         predictor_type = configs.get('predictor', 'rnn')
@@ -108,6 +116,7 @@ def init_model(configs):
                            attention_decoder=decoder,
                            joint=joint,
                            ctc=ctc,
+                           context_module=context_module,
                            **configs['model_conf'])
     elif 'paraformer' in configs:
         predictor = Predictor(**configs['cif_predictor_conf'])
@@ -122,6 +131,7 @@ def init_model(configs):
                        encoder=encoder,
                        decoder=decoder,
                        ctc=ctc,
+                       context_module=context_module,
                        lfmmi_dir=configs.get('lfmmi_dir', ''),
                        **configs['model_conf'])
     return model
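On the training side, init_model.py looks for a `context_module` entry in the config; a minimal sketch of the keys involved (the values shown are illustrative, not recipe defaults):

    configs['context_module'] = 'cppn'
    configs['context_module_conf'] = dict(
        embedding_size=256,     # should match the encoder output size
        encoder_layers=2,
        attention_heads=4,
        dropout_rate=0.0,
    )
    configs['model_conf']['bias_weight'] = 0.03  # weight of the auxiliary bias CTC loss
    # The training data config additionally needs a 'context_conf' block so that
    # context_sampling/padding emit context_data for the model forward pass.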