Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wenet] nn context biasing #1982

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
52 changes: 40 additions & 12 deletions wenet/bin/recognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,21 +160,37 @@ def get_args():
default=0.0,
help='lm scale for hlg attention rescore decode')

parser.add_argument(
'--context_bias_mode',
type=str,
default='',
help='''Context bias mode, selectable from the following
option: decoding-graph、deep-biasing''')
parser.add_argument('--context_bias_mode',
type=str,
default='',
help='''Context bias mode, selectable from the
following option: context_graph,
deep_biasing''')
parser.add_argument('--context_list_path',
type=str,
default='',
help='Context list path')
parser.add_argument('--context_graph_score',
type=float,
default=0.0,
default=2.0,
help='''The higher the score, the greater the degree of
bias using decoding-graph for biasing''')
bias using context_graph for biasing''')
parser.add_argument('--deep_biasing_score',
type=float,
default=1.0,
help='''The higher the score, the greater the degree of
bias using deep_biasing for biasing''')
parser.add_argument('--context_filtering',
action='store_true',
help='''Reduce the size of the context list through
filtering to enhance the effect of context
biasing''')
parser.add_argument('--context_filtering_threshold',
type=float,
default=-4.0,
help='''The threshold for context filtering, the larger
the value, the closer it is to 0, and the fewer
remaining context phrases are filtered''')

args = parser.parse_args()
print(args)
Expand Down Expand Up @@ -225,6 +241,9 @@ def main():
test_conf['batch_conf']['batch_size'] = args.batch_size
non_lang_syms = read_non_lang_symbols(args.non_lang_syms)

if 'context_conf' in test_conf:
del test_conf['context_conf']

test_dataset = Dataset(args.data_type,
args.test_data,
symbol_table,
Expand Down Expand Up @@ -256,33 +275,42 @@ def main():
paraformer_beam_search = None

context_graph = None
if 'decoding-graph' in args.context_bias_mode:
if args.context_bias_mode != '':
context_graph = ContextGraph(args.context_list_path, symbol_table,
args.bpe_model, args.context_graph_score)
context_graph.context_filtering = args.context_filtering
context_graph.filter_threshold = args.context_filtering_threshold
if 'deep_biasing' in args.context_bias_mode:
context_graph.deep_biasing = True
context_graph.deep_biasing_score = args.deep_biasing_score
if 'context_graph' in args.context_bias_mode:
context_graph.graph_biasing = True

with torch.no_grad(), open(args.result_file, 'w') as fout:
for batch_idx, batch in enumerate(test_data_loader):
keys, feats, target, feats_lengths, target_lengths = batch
keys, feats, target, feats_lengths, target_lengths, _ = batch
feats = feats.to(device)
target = target.to(device)
feats_lengths = feats_lengths.to(device)
target_lengths = target_lengths.to(device)

if args.mode == 'attention':
hyps, _ = model.recognize(
feats,
feats_lengths,
beam_size=args.beam_size,
decoding_chunk_size=args.decoding_chunk_size,
num_decoding_left_chunks=args.num_decoding_left_chunks,
simulate_streaming=args.simulate_streaming)
simulate_streaming=args.simulate_streaming,)
hyps = [hyp.tolist() for hyp in hyps]
elif args.mode == 'ctc_greedy_search':
hyps, _ = model.ctc_greedy_search(
feats,
feats_lengths,
decoding_chunk_size=args.decoding_chunk_size,
num_decoding_left_chunks=args.num_decoding_left_chunks,
simulate_streaming=args.simulate_streaming)
simulate_streaming=args.simulate_streaming,
context_graph=context_graph)
elif args.mode == 'rnnt_greedy_search':
assert (feats.size(0) == 1)
assert 'predictor' in configs
Expand Down
11 changes: 11 additions & 0 deletions wenet/bin/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,17 @@ def main():
num_params = sum(p.numel() for p in model.parameters())
print('the number of model params: {:,d}'.format(num_params)) if local_rank == 0 else None # noqa

if 'context_module_conf' in configs:
# Freeze other parts of the model during training context bias module
for p in model.parameters():
p.requires_grad = False
for p in model.context_module.parameters():
p.requires_grad = True
for p in model.context_module.context_decoder_ctc_linear.parameters():
p.requires_grad = False
# Turn off dynamic chunk because it will affect the training of bias
model.encoder.use_dynamic_chunk = False

# !!!IMPORTANT!!!
# Try to export the model by script, if fails, we should refine
# the code to satisfy the script export requirements
Expand Down
6 changes: 6 additions & 0 deletions wenet/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,5 +189,11 @@ def Dataset(data_type,

batch_conf = conf.get('batch_conf', {})
dataset = Processor(dataset, processor.batch, **batch_conf)

context_conf = conf.get('context_conf', {})
if len(context_conf) != 0:
dataset = Processor(dataset, processor.context_sampling,
symbol_table, **context_conf)

dataset = Processor(dataset, processor.padding)
return dataset
118 changes: 116 additions & 2 deletions wenet/dataset/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,99 @@ def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000):
logging.fatal('Unsupported batch type {}'.format(batch_type))


def context_sampling(data,
symbol_table,
len_min,
len_max,
utt_num_context,
batch_num_context,
):
"""Perform context sampling by randomly selecting context phrases from the
utterance to obtain a context list for the entire batch

Args:
data: Iterable[List[{key, feat, label}]]

Returns:
Iterable[List[{key, feat, label, context_list}]]
"""
rev_symbol_table = {}
for token in symbol_table:
rev_symbol_table[symbol_table[token]] = token
context_list_over_all = []
for sample in data:
batch_label = [sample[i]['label'] for i in range(len(sample))]
context_list = []
for utt_label in batch_label:
st_index_list = []
for i in range(len(utt_label)):
if '▁' not in symbol_table:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我想请问下,这里如果我的建模单元是中文汉字+英文bpe,这里是不是不太适用,需要改下?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是的,我自己训练的时候都是纯中文或者纯英文,英文在热词采样的时候对下划线特殊处理了下保证不会采样出半个词的情况,如果同时有中文和英文这部分最好是改下

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

谢谢~

st_index_list.append(i)
elif rev_symbol_table[utt_label[i]][0] == '▁':
st_index_list.append(i)
st_index_list.append(len(utt_label))

st_select = []
en_select = []
for _ in range(0, utt_num_context):
random_len = random.randint(min(len(st_index_list) - 1, len_min),
min(len(st_index_list) - 1, len_max))
random_index = random.randint(0, len(st_index_list) -
random_len - 1)
st_index = st_index_list[random_index]
en_index = st_index_list[random_index + random_len]
context_label = utt_label[st_index: en_index]
cross_flag = True
for i in range(len(st_select)):
if st_index >= st_select[i] and st_index < en_select[i]:
cross_flag = False
elif en_index > st_select[i] and en_index <= en_select[i]:
cross_flag = False
elif st_index < st_select[i] and en_index > en_select[i]:
cross_flag = False
if cross_flag:
context_list.append(context_label)
st_select.append(st_index)
en_select.append(en_index)

if len(context_list) > batch_num_context:
context_list_over_all = context_list
elif len(context_list) + len(context_list_over_all) > batch_num_context:
context_list_over_all.extend(context_list)
context_list_over_all = context_list_over_all[-batch_num_context:]
else:
context_list_over_all.extend(context_list)
context_list = context_list_over_all.copy()
context_list.insert(0, [0])
for i in range(len(context_list)):
context_list[i] = torch.tensor(context_list[i], dtype=torch.int32)
sample[0]['context_list'] = context_list
yield sample


def context_label_generate(label, context_list):
""" Generate context labels corresponding to the utterances based on
the context list
"""
context_labels = []
for x in label:
cur_len = len(x)
context_label = []
count = 0
for i in range(cur_len):
for j in range(1, len(context_list)):
if i + len(context_list[j]) > cur_len:
continue
if x[i:i + len(context_list[j])].equal(context_list[j]):
count = max(count, len(context_list[j]))
if count > 0:
context_label.append(x[i])
count -= 1
context_label = torch.tensor(context_label, dtype=torch.int64)
context_labels.append(context_label)
return context_labels


def padding(data):
""" Padding the data into training data

Expand Down Expand Up @@ -641,5 +734,26 @@ def padding(data):
batch_first=True,
padding_value=-1)

yield (sorted_keys, padded_feats, padding_labels, feats_lengths,
label_lengths)
if 'context_list' not in sample[0]:
yield (sorted_keys, padded_feats, padding_labels, feats_lengths,
label_lengths, [])
else:
context_lists = sample[0]['context_list']
context_list_lengths = \
torch.tensor([x.size(0) for x in context_lists], dtype=torch.int32)
padding_context_lists = pad_sequence(context_lists,
batch_first=True,
padding_value=-1)

sorted_context_labels = context_label_generate(sorted_labels,
context_lists)
context_label_lengths = \
torch.tensor([x.size(0) for x in sorted_context_labels],
dtype=torch.int32)
padding_context_labels = pad_sequence(sorted_context_labels,
batch_first=True,
padding_value=-1)
yield (sorted_keys, padded_feats, padding_labels,
feats_lengths, label_lengths,
[padding_context_lists, padding_context_labels,
context_list_lengths, context_label_lengths])
34 changes: 32 additions & 2 deletions wenet/transducer/transducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import BiTransformerDecoder, TransformerDecoder
from wenet.transformer.label_smoothing_loss import LabelSmoothingLoss
from wenet.transformer.context_module import ContextModule
from wenet.utils.common import (IGNORE_ID, add_blank, add_sos_eos,
reverse_pad_list)

Expand All @@ -35,6 +36,8 @@ def __init__(
BiTransformerDecoder]] = None,
ctc: Optional[CTC] = None,
ctc_weight: float = 0,
context_module: ContextModule = None,
bias_weight: float = 0,
ignore_id: int = IGNORE_ID,
reverse_weight: float = 0.0,
lsm_weight: float = 0.0,
Expand All @@ -50,8 +53,8 @@ def __init__(
) -> None:
assert attention_weight + ctc_weight + transducer_weight == 1.0
super().__init__(vocab_size, encoder, attention_decoder, ctc,
ctc_weight, ignore_id, reverse_weight, lsm_weight,
length_normalized_loss)
context_module, ctc_weight, ignore_id, reverse_weight,
lsm_weight, length_normalized_loss, '', bias_weight)

self.blank = blank
self.transducer_weight = transducer_weight
Expand Down Expand Up @@ -93,6 +96,7 @@ def forward(
text: torch.Tensor,
text_lengths: torch.Tensor,
steps: int = 0,
context_data: List[torch.Tensor] = None,
) -> Dict[str, Optional[torch.Tensor]]:
"""Frontend + Encoder + predictor + joint + loss

Expand All @@ -101,6 +105,8 @@ def forward(
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
context_data: [context_list, context_label,
context_list_lengths, context_label_lengths]
"""
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
Expand All @@ -112,6 +118,27 @@ def forward(
encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
encoder_out_lens = encoder_mask.squeeze(1).sum(1)

# Context biasing branch
loss_bias: Optional[torch.Tensor] = None
if self.context_module is not None:
assert len(context_data) == 4
context_list = context_data[0]
context_label = context_data[1]
context_list_lengths = context_data[2]
context_label_lengths = context_data[3]
context_emb = self.context_module. \
forward_context_emb(context_list, context_list_lengths)
encoder_out, bias_out = self.context_module(context_emb,
encoder_out)
bias_out = bias_out.transpose(0, 1).log_softmax(2)
loss_bias = self.context_module.bias_loss(bias_out,
context_label,
encoder_out_lens,
context_label_lengths
) / bias_out.size(1)
else:
loss_bias = None

# compute_loss
loss_rnnt = compute_loss(
self,
Expand Down Expand Up @@ -143,12 +170,15 @@ def forward(
loss = loss + self.ctc_weight * loss_ctc.sum()
if loss_att is not None:
loss = loss + self.attention_decoder_weight * loss_att.sum()
if loss_bias is not None:
loss = loss + self.bias_weight * loss_bias.sum()
# NOTE: 'loss' must be in dict
return {
'loss': loss,
'loss_att': loss_att,
'loss_ctc': loss_ctc,
'loss_rnnt': loss_rnnt,
'loss_bias': loss_bias,
}

def init_bs(self):
Expand Down
Loading