From 12499282bb0476b9a63589d08aecb234680d9bb2 Mon Sep 17 00:00:00 2001 From: Mddct Date: Mon, 12 Aug 2024 14:42:30 +0800 Subject: [PATCH 1/4] [sensevoice] support sensevoice small arch --- wenet/sensevoice/sensevoice_small_model.py | 258 +++++++++++++++++++++ 1 file changed, 258 insertions(+) create mode 100644 wenet/sensevoice/sensevoice_small_model.py diff --git a/wenet/sensevoice/sensevoice_small_model.py b/wenet/sensevoice/sensevoice_small_model.py new file mode 100644 index 000000000..2b12c69c8 --- /dev/null +++ b/wenet/sensevoice/sensevoice_small_model.py @@ -0,0 +1,258 @@ +from typing import Dict, List, Optional, Tuple + +import torch +import torch.utils.checkpoint as ckpt +from wenet.paraformer.attention import MultiHeadedAttentionSANM +from wenet.paraformer.layers import LFR, AliParaformerEncoderLayer, SanmEncoder +from wenet.transformer.asr_model import ASRModel +from wenet.transformer.ctc import CTC +from wenet.transformer.decoder import TransformerDecoder +from wenet.transformer.label_smoothing_loss import LabelSmoothingLoss +from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward +from wenet.transformer.search import DecodeResult +from wenet.utils.common import IGNORE_ID, mask_to_bias +from wenet.utils.context_graph import ContextGraph +from wenet.utils.mask import add_optional_chunk_mask, make_pad_mask + + +class SanmEncoderWithTp(SanmEncoder): + + def __init__(self, + tp_num_blocks: int, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0, + input_layer: str = "conv2d", + pos_enc_layer_type: str = "abs_pos", + normalize_before: bool = True, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + kernel_size: int = 11, + sanm_shfit: int = 0, + gradient_checkpointing: bool = False): + super().__init__(input_size, output_size, attention_heads, + linear_units, num_blocks, dropout_rate, + positional_dropout_rate, attention_dropout_rate, + input_layer, pos_enc_layer_type, normalize_before, + static_chunk_size, use_dynamic_chunk, global_cmvn, + use_dynamic_left_chunk, kernel_size, sanm_shfit, + gradient_checkpointing) + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + output_size, + attention_dropout_rate, + kernel_size, + sanm_shfit, + ) + encoder_selfattn_layer = MultiHeadedAttentionSANM + self.tp_encoders = torch.nn.ModuleList([ + AliParaformerEncoderLayer( + output_size, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + PositionwiseFeedForward( + output_size, + linear_units, + dropout_rate, + ), + dropout_rate, + normalize_before, + in_size=output_size) for _ in range(tp_num_blocks - 1) + ]) + self.tp_norm = torch.nn.LayerNorm(output_size) + + def forward( + self, + xs: torch.Tensor, + xs_lens: torch.Tensor, + decoding_chunk_size: int = 0, + num_decoding_left_chunks: int = -1 + ) -> Tuple[torch.Tensor, torch.Tensor]: + T = xs.size(1) + masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + xs, pos_emb, masks = self.embed(xs, masks) + mask_pad = masks # (B, 1, T/subsample_rate) + chunk_masks = add_optional_chunk_mask( + xs, + masks, + self.use_dynamic_chunk, + self.use_dynamic_left_chunk, + decoding_chunk_size, + self.static_chunk_size, + num_decoding_left_chunks, + # Since we allow up to 1s(100 
frames) delay, the maximum + # chunk_size is 100 / 4 = 25. + max_chunk_size=int(100.0 / self.embed.subsampling_rate)) + if self.use_sdpa: + chunk_masks = mask_to_bias(chunk_masks, xs.dtype) + if self.gradient_checkpointing and self.training: + xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb, + mask_pad) + else: + xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad) + if self.normalize_before: + xs = self.after_norm(xs) + # Here we assume the mask is not changed in encoder layers, so just + # return the masks before encoder layers, and the masks will be used + # for cross attention with decoder later + + # sensevoice tp encoders: + if self.gradient_checkpointing and self.training: + xs = self.forward_tp_layers_checkpointed(xs, chunk_masks, pos_emb, + mask_pad) + else: + xs = self.forward_tp_layers(xs, chunk_masks, pos_emb, mask_pad) + xs = self.tp_norm(xs) + return xs, masks + + @torch.jit.unused + def forward_tp_layers_checkpointed(self, xs: torch.Tensor, + chunk_masks: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor) -> torch.Tensor: + for layer in self.tp_encoders: + xs, _, _, _, _ = ckpt.checkpoint( + layer.__call__, + xs, + chunk_masks, + pos_emb, + mask_pad, + ) + return xs + + def forward_tp_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor) -> torch.Tensor: + for layer in self.tp_encoders: + xs, _, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) + return xs + + +class SenseVoiceSmall(ASRModel): + + def __init__(self, + input_dim: int, + vocab_size: int, + encoder: SanmEncoderWithTp, + decoder: TransformerDecoder, + ctc: CTC, + ctc_weight: float = 0.5, + ignore_id: int = IGNORE_ID, + reverse_weight: float = 0, + lsm_weight: float = 0, + length_normalized_loss: bool = False, + special_tokens: Optional[dict] = None, + apply_non_blank_embedding: bool = False): + super().__init__(vocab_size, encoder, decoder, ctc, ctc_weight, + ignore_id, reverse_weight, lsm_weight, + length_normalized_loss, special_tokens, + apply_non_blank_embedding) + + assert ctc_weight != 0.0 + assert special_tokens is not None + self.encoder = encoder + self.decoder = decoder + self.lfr = LFR() + + self.sos = special_tokens['sos'] + self.eos = special_tokens['eos'] + + self.lid_dict = special_tokens['lid'] + self.itn_dict = special_tokens['itn'] + self.emo_dict = special_tokens['emo'] + self.embed = torch.nn.Embedding( + 7 + len(self.lid_dict) + len(self.emo_dict), input_dim) + + assert self.encoder.global_cmvn is not None + self.global_cmvn = self.encoder.global_cmvn + self.encoder.global_cmvn = None + + self.criterion_context = LabelSmoothingLoss( + size=vocab_size, + padding_idx=ignore_id, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + + @torch.jit.unused + def forward(self, batch: dict, + device: torch.device) -> Dict[str, Optional[torch.Tensor]]: + speech = batch['feats'].to(device) + speech_lengths = batch['feats_lengths'].to(device) + text = batch['target'].to(device) + text_lengths = batch['target_lengths'].to(device) + + speech, speech_lengths = self.lfr(speech, speech_lengths) + speech = self.global_cmvn = self.global_cmvn(speech) + + # context pattern: + # lid emo event tn speech + # TODO: move to dataset + lid = batch['lid'].to(device).unsqueeze(1) # [B,1] + itn = batch['itn'].to(device).unsqueeze(1) # [B,1] + emo = batch['emo'].to(device).unsqueeze(1) # [B,1] + event = batch['enent'].to(device).unsqueeze(1) # [B,1] + context = torch.stack([lid, itn, event, emo], dim=1) + + 
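+        # NOTE: the four context ids above (language id, ITN flag, event,
+        # emotion) are embedded and prepended to the acoustic features just
+        # below, so the first four encoder outputs act as prompt positions:
+        # they are scored with the label-smoothing context loss, while the
+        # CTC speech loss is computed only on the remaining positions (hence
+        # the offsets of four applied to lengths and targets further down).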
context_embed = self.embed(context) # [B,3,D] + speech = torch.cat((context_embed, speech), dim=1) + speech_lengths = speech_lengths + 3 + 1 + + encoder_out, encoder_mask = self.encoder(speech, speech_lengths) + encoder_out_lens = encoder_mask.sum(-1).squeeze() + loss_ctc_speech = self.ctc(encoder_out[:4:, :, :], + encoder_out_lens - 4, text[:, 4:], + text_lengths - 4) + + context_logits = self.ctc.ctc_lo(encoder_out[:, :4, :]) + loss_context = self.criterion_context(context_logits, text[:, :4]) + + loss_att, acc_att = None, 0 + if self.ctc_weight != 1.0: + loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask, + text, text_lengths) + + loss_ctc = loss_ctc_speech + loss_context + loss = loss_ctc + if loss_att is not None: + loss = self.ctc_weight * loss_ctc + (1 - + self.ctc_weight) * loss_att + + # TODO: log context acc + return { + "loss": loss, + "loss_att": loss_att, + "loss_ctc": loss_ctc, + "loss_ctc_speech": loss_ctc_speech, + "loss_context": loss_context, + "th_accuracy": acc_att, + } + + def decode( + self, + methods: List[str], + speech: torch.Tensor, + speech_lengths: torch.Tensor, + beam_size: int, + decoding_chunk_size: int = -1, + num_decoding_left_chunks: int = -1, + ctc_weight: float = 0, + simulate_streaming: bool = False, + reverse_weight: float = 0, + context_graph: ContextGraph = None, + blank_id: int = 0, + blank_penalty: float = 0, + length_penalty: float = 0, + infos: Dict[str, + List[str]] = None) -> Dict[str, List[DecodeResult]]: + + pass From 48a30eb460698bdcc03ee6f53f299605fe2e902c Mon Sep 17 00:00:00 2001 From: Mddct Date: Mon, 12 Aug 2024 17:26:15 +0800 Subject: [PATCH 2/4] support pure sentencepice tokenizer --- wenet/sensevoice/sensevoice_small_model.py | 2 +- wenet/text/sentencepiece_tokenizer.py | 55 ++++++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) create mode 100644 wenet/text/sentencepiece_tokenizer.py diff --git a/wenet/sensevoice/sensevoice_small_model.py b/wenet/sensevoice/sensevoice_small_model.py index 2b12c69c8..69411c115 100644 --- a/wenet/sensevoice/sensevoice_small_model.py +++ b/wenet/sensevoice/sensevoice_small_model.py @@ -203,7 +203,7 @@ def forward(self, batch: dict, event = batch['enent'].to(device).unsqueeze(1) # [B,1] context = torch.stack([lid, itn, event, emo], dim=1) - context_embed = self.embed(context) # [B,3,D] + context_embed = self.embed(context) # [B,4,D] speech = torch.cat((context_embed, speech), dim=1) speech_lengths = speech_lengths + 3 + 1 diff --git a/wenet/text/sentencepiece_tokenizer.py b/wenet/text/sentencepiece_tokenizer.py new file mode 100644 index 000000000..c4e8a548a --- /dev/null +++ b/wenet/text/sentencepiece_tokenizer.py @@ -0,0 +1,55 @@ +from os import PathLike +from typing import Dict, List, Union +from wenet.text.base_tokenizer import (BaseTokenizer, T) + + +class SentencepieceTokenizer(BaseTokenizer): + """ Sentencepiece Tokenizer + """ + + def __init__( + self, + model_path: Union[PathLike, str], + ) -> None: + super().__init__() + + self.model_path = model_path + self.model = None + self._vocab_size = None + self._symbol_table = None + + def _build_sp(self): + if self.model is None: + import sentencepiece as spm + self.model = spm.SentencePieceProcessor() + self.model.load(self.model_path) + self._symbol_table = { + _id: self.model.id_to_piece(_id) + for _id in range(self.model.get_piece_size()) + } + self.vocab_size = len(self._symbol_table) + + def text2tokens(self, line: str) -> List[T]: + self._build_sp() + return self.model.encode_as_pieces(line) + + def 
tokens2ids(self, tokens: List[T]) -> List[int]: + self._build_sp() + return self.model.piece_to_id(tokens) + + def ids2tokens(self, ids: List[int]) -> List[T]: + self._build_sp() + return self.model.id_to_piece(ids) + + def tokens2text(self, tokens: List[T]) -> str: + self._build_sp() + return self.model.decode(tokens) + + @property + def symbol_table(self) -> Dict[T, int]: + self._build_sp() + return self._symbol_table + + def vocab_size(self) -> int: + self._build_sp() + return self.vocab_size From 53c52adf8e5cb49933ef08d2e57161551e6fec4d Mon Sep 17 00:00:00 2001 From: Mddct Date: Mon, 12 Aug 2024 20:34:49 +0800 Subject: [PATCH 3/4] convert script works --- ...nsevoice_small_to_wenet_config_and_ckpt.py | 170 ++++++++++++++++++ wenet/text/sentencepiece_tokenizer.py | 3 +- wenet/utils/init_tokenizer.py | 4 + 3 files changed, 176 insertions(+), 1 deletion(-) create mode 100644 wenet/sensevoice/convert_sensevoice_small_to_wenet_config_and_ckpt.py diff --git a/wenet/sensevoice/convert_sensevoice_small_to_wenet_config_and_ckpt.py b/wenet/sensevoice/convert_sensevoice_small_to_wenet_config_and_ckpt.py new file mode 100644 index 000000000..c6e6d54df --- /dev/null +++ b/wenet/sensevoice/convert_sensevoice_small_to_wenet_config_and_ckpt.py @@ -0,0 +1,170 @@ +# NOTE(Mddct): This file is to convert paraformer config to wenet's train.yaml config + +import argparse +import os +from typing import Dict +import torch +import copy +from wenet.paraformer.convert_paraformer_to_wenet_config_and_ckpt import ( + _filter_dict_fields) + +import yaml + +from wenet.paraformer.convert_paraformer_to_wenet_config_and_ckpt import ( + convert_to_wenet_json_cmvn) +from wenet.text.sentencepiece_tokenizer import SentencepieceTokenizer + + +def convert_to_wenet_yaml(configs, wenet_yaml_path: str, unit_path: str, + tokenizer: SentencepieceTokenizer, + tokenizer_path) -> Dict: + configs = copy.deepcopy(configs) + configs['encoder'] = 'sanm_encoder_with_tp' + configs['encoder_conf']['input_layer'] = 'paraformer_dummy' + configs['lfr_conf'] = {'lfr_m': 7, 'lfr_n': 6} + + configs['input_dim'] = configs['lfr_conf']['lfr_m'] * 80 + # This type not use + del configs['encoder_conf']['selfattention_layer_type'], configs[ + 'encoder_conf']['pos_enc_class'] + configs['encoder_conf']['pos_enc_layer_type'] = 'abs_pos_paraformer' + + configs['ctc_conf'] = {} + configs['ctc_conf']['ctc_blank_id'] = 0 + + configs['tokenizer'] = 'tokenizer' + configs['tokenizer_conf'] = {} + configs['tokenizer_conf']['model_path'] = tokenizer_path + configs['tokenizer_conf']['special_tokens'] = {} + + with open(unit_path, 'w') as f: + for token, i in tokenizer.symbol_table.items(): + f.write("{} {}\n".format(token, i)) + + configs['tokenizer_conf']['special_tokens'][''] = 2 + configs['tokenizer_conf']['special_tokens'][''] = 1 + configs['tokenizer_conf']['special_tokens'][''] = 0 + configs['tokenizer_conf']['special_tokens'][''] = 0 + + configs['dataset'] = 'asr_dataset' + configs['dataset_conf'] = {} + configs['dataset_conf']['filter_conf'] = {} + configs['dataset_conf']['filter_conf']['max_length'] = 20000 + configs['dataset_conf']['filter_conf']['min_length'] = 0 + configs['dataset_conf']['filter_conf']['token_max_length'] = 200 + configs['dataset_conf']['filter_conf']['token_min_length'] = 1 + configs['dataset_conf']['resample_conf'] = {} + configs['dataset_conf']['resample_conf']['resample_rate'] = 16000 + configs['dataset_conf']['speed_perturb'] = True + configs['dataset_conf']['spec_aug'] = True + configs['dataset_conf']['spec_aug_conf'] = {} + 
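+    # SpecAugment settings for the generated train.yaml: counts of time and
+    # frequency masks (num_t_mask / num_f_mask) and their maximum widths in
+    # frames / mel bins (max_t / max_f). The values below appear to mirror
+    # common wenet recipe defaults rather than anything SenseVoice-specific.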
configs['dataset_conf']['spec_aug_conf']['num_t_mask'] = 2 + configs['dataset_conf']['spec_aug_conf']['num_f_mask'] = 2 + configs['dataset_conf']['spec_aug_conf']['max_t'] = 50 + configs['dataset_conf']['spec_aug_conf']['max_f'] = 10 + configs['dataset_conf']['fbank_conf'] = {} + configs['dataset_conf']['fbank_conf']['num_mel_bins'] = 80 + configs['dataset_conf']['fbank_conf']['frame_shift'] = 10 + configs['dataset_conf']['fbank_conf']['frame_length'] = 25 + configs['dataset_conf']['fbank_conf']['dither'] = 0.1 + configs['dataset_conf']['fbank_conf']['window_type'] = 'hamming' + configs['dataset_conf']['spec_sub'] = False + configs['dataset_conf']['spec_trim'] = False + configs['dataset_conf']['shuffle'] = True + configs['dataset_conf']['shuffle_conf'] = {} + configs['dataset_conf']['shuffle_conf']['shuffle_size'] = 1500 + configs['dataset_conf']['sort'] = True + configs['dataset_conf']['sort_conf'] = {} + configs['dataset_conf']['sort_conf']['sort_size'] = 500 + configs['dataset_conf']['batch_conf'] = {} + configs['dataset_conf']['batch_conf']['batch_type'] = 'dynamic' + configs['dataset_conf']['batch_conf']['batch_size'] = 26 + configs['dataset_conf']['batch_conf']['max_frames_in_batch'] = 12000 + + configs['grad_clip'] = 5 + configs['accum_grad'] = 1 + configs['max_epoch'] = 100 + configs['log_interval'] = 100 + + configs['model_conf'] = {} + configs['model_conf']['length_normalized_loss'] = False + configs['model_conf']['ctc_weight'] = 1.0 + configs['model_conf']['lsm_weight'] = 0.1 + + with open(wenet_yaml_path, '+w') as f: + f.write(yaml.dump(configs)) + f.flush() + return configs + + +def convert_to_wenet_state_dict(args, wenet_model_path): + checkpoint = torch.load(args.sensevoice_model, map_location='cpu') + torch.save(checkpoint, wenet_model_path) + + +def get_args(): + parser = argparse.ArgumentParser(description='load ali-sensevoice') + parser.add_argument('--sensevoice_config', + default=None, + help='ali released SenseVoice model\'s config') + parser.add_argument('--sensevoice_cmvn', + default=None, + help='ali released SenseVoice model\'s cmvn') + parser.add_argument( + '--sensevoice_spm', + default=None, + help='ali released sentencepiece tokenizer\'s model path') + parser.add_argument('--sensevoice_model', + default=None, + help='ali released sentencepiece model path') + + parser.add_argument('--output_dir', + default='.', + help="output file:\ + global_cmvn, units.txt, train.yaml, wenet_sensevoice_small.pt") + args = parser.parse_args() + return args + + +def main(): + + args = get_args() + assert os.path.exists(args.output_dir) + with open(args.sensevoice_config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + filter_to_keep = { + "encoder", + "encoder_conf", + } + configs = _filter_dict_fields(configs, filter_to_keep) + + json_cmvn_path = os.path.join(args.output_dir, 'global_cmvn') + convert_to_wenet_json_cmvn(args.sensevoice_cmvn, json_cmvn_path) + + wenet_units = os.path.join(args.output_dir, 'units.txt') + tokenizer = SentencepieceTokenizer(args.sensevoice_spm) + + vocab_size = tokenizer.vocab_size() + configs['output_dim'] = vocab_size + configs['model'] = 'sensevoice_small' + configs['cmvn'] = "global_cmvn" + configs['cmvn_conf'] = {} + configs['cmvn_conf']['is_json_cmvn'] = True + configs['cmvn_conf']['cmvn_file'] = json_cmvn_path + wenet_train_yaml = os.path.join(args.output_dir, "train.yaml") + convert_to_wenet_yaml(configs, wenet_train_yaml, wenet_units, tokenizer, + args.sensevoice_spm) + wenet_model_path = os.path.join(args.output_dir, + 
"wenet_sensevoice_small.pt") + convert_to_wenet_state_dict(args, wenet_model_path) + + print("Please check {} {} {} {} in {}".format(json_cmvn_path, + wenet_train_yaml, + wenet_model_path, + wenet_units, + args.output_dir)) + + +if __name__ == "__main__": + + main() diff --git a/wenet/text/sentencepiece_tokenizer.py b/wenet/text/sentencepiece_tokenizer.py index c4e8a548a..9a6f12c66 100644 --- a/wenet/text/sentencepiece_tokenizer.py +++ b/wenet/text/sentencepiece_tokenizer.py @@ -10,6 +10,7 @@ class SentencepieceTokenizer(BaseTokenizer): def __init__( self, model_path: Union[PathLike, str], + **kwargs, ) -> None: super().__init__() @@ -24,7 +25,7 @@ def _build_sp(self): self.model = spm.SentencePieceProcessor() self.model.load(self.model_path) self._symbol_table = { - _id: self.model.id_to_piece(_id) + self.model.id_to_piece(_id): _id for _id in range(self.model.get_piece_size()) } self.vocab_size = len(self._symbol_table) diff --git a/wenet/utils/init_tokenizer.py b/wenet/utils/init_tokenizer.py index c0c2ce7d7..e1e347fb7 100644 --- a/wenet/utils/init_tokenizer.py +++ b/wenet/utils/init_tokenizer.py @@ -19,6 +19,7 @@ from wenet.text.bpe_tokenizer import BpeTokenizer from wenet.text.char_tokenizer import CharTokenizer from wenet.text.paraformer_tokenizer import ParaformerTokenizer +from wenet.text.sentencepiece_tokenizer import SentencepieceTokenizer from wenet.text.whisper_tokenizer import WhisperTokenizer @@ -47,6 +48,9 @@ def init_tokenizer(configs) -> BaseTokenizer: tokenizer = ParaformerTokenizer( symbol_table=configs['tokenizer_conf']['symbol_table_path'], seg_dict=configs['tokenizer_conf']['seg_dict_path']) + elif tokenizer_type == 'sentencepiece': + tokenizer = SentencepieceTokenizer( + model_path=configs['tokenizer_conf']['model_path']) else: raise NotImplementedError logging.info("use {} tokenizer".format(configs["tokenizer"])) From 405e2576dd9ada047ff0937cc2d21798d1e9d8f0 Mon Sep 17 00:00:00 2001 From: Mddct Date: Mon, 12 Aug 2024 21:28:29 +0800 Subject: [PATCH 4/4] init model works --- ...nsevoice_small_to_wenet_config_and_ckpt.py | 2 ++ wenet/sensevoice/sensevoice_small_model.py | 20 +++++++++---------- wenet/utils/init_model.py | 10 +++++++--- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/wenet/sensevoice/convert_sensevoice_small_to_wenet_config_and_ckpt.py b/wenet/sensevoice/convert_sensevoice_small_to_wenet_config_and_ckpt.py index c6e6d54df..62bec27cf 100644 --- a/wenet/sensevoice/convert_sensevoice_small_to_wenet_config_and_ckpt.py +++ b/wenet/sensevoice/convert_sensevoice_small_to_wenet_config_and_ckpt.py @@ -23,6 +23,8 @@ def convert_to_wenet_yaml(configs, wenet_yaml_path: str, unit_path: str, configs['encoder_conf']['input_layer'] = 'paraformer_dummy' configs['lfr_conf'] = {'lfr_m': 7, 'lfr_n': 6} + configs['decoder'] = None + configs['input_dim'] = configs['lfr_conf']['lfr_m'] * 80 # This type not use del configs['encoder_conf']['selfattention_layer_type'], configs[ diff --git a/wenet/sensevoice/sensevoice_small_model.py b/wenet/sensevoice/sensevoice_small_model.py index 69411c115..b4996b782 100644 --- a/wenet/sensevoice/sensevoice_small_model.py +++ b/wenet/sensevoice/sensevoice_small_model.py @@ -18,8 +18,8 @@ class SanmEncoderWithTp(SanmEncoder): def __init__(self, - tp_num_blocks: int, input_size: int, + tp_blocks: int, output_size: int = 256, attention_heads: int = 4, linear_units: int = 2048, @@ -64,7 +64,7 @@ def __init__(self, ), dropout_rate, normalize_before, - in_size=output_size) for _ in range(tp_num_blocks - 1) + in_size=output_size) for 
_ in range(tp_blocks) ]) self.tp_norm = torch.nn.LayerNorm(output_size) @@ -140,7 +140,6 @@ def forward_tp_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor, class SenseVoiceSmall(ASRModel): def __init__(self, - input_dim: int, vocab_size: int, encoder: SanmEncoderWithTp, decoder: TransformerDecoder, @@ -163,14 +162,11 @@ def __init__(self, self.decoder = decoder self.lfr = LFR() - self.sos = special_tokens['sos'] - self.eos = special_tokens['eos'] + self.sos = special_tokens[''] + self.eos = special_tokens[''] - self.lid_dict = special_tokens['lid'] - self.itn_dict = special_tokens['itn'] - self.emo_dict = special_tokens['emo'] - self.embed = torch.nn.Embedding( - 7 + len(self.lid_dict) + len(self.emo_dict), input_dim) + # hard code for sensevoice small + self.embed = torch.nn.Embedding(7 + 7 + 2, 560) assert self.encoder.global_cmvn is not None self.global_cmvn = self.encoder.global_cmvn @@ -183,6 +179,10 @@ def __init__(self, normalize_length=length_normalized_loss, ) + @torch.jit.unused + def tie_or_clone_weights(self, jit_mode: bool = True): + pass + @torch.jit.unused def forward(self, batch: dict, device: torch.device) -> Dict[str, Optional[torch.Tensor]]: diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index 3c2e36abf..2495c60ce 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -23,6 +23,7 @@ from wenet.paraformer.paraformer import Paraformer, Predictor from wenet.LLM.causallm_model import CausalLM from wenet.LLM.decoder import DecoderOnly +from wenet.sensevoice.sensevoice_small_model import SanmEncoderWithTp, SenseVoiceSmall from wenet.ssl.init_model import WENET_SSL_MODEL_CLASS from wenet.transducer.joint import TransducerJoint from wenet.transducer.predictor import (ConvPredictor, EmbeddingPredictor, @@ -53,6 +54,7 @@ "dual_transformer": DualTransformerEncoder, "dual_conformer": DualConformerEncoder, 'sanm_encoder': SanmEncoder, + 'sanm_encoder_with_tp': SanmEncoderWithTp, } WENET_DECODER_CLASSES = { @@ -84,6 +86,7 @@ "k2_model": K2Model, "transducer": Transducer, 'paraformer': Paraformer, + "sensevoice_small": SenseVoiceSmall, 'causal_llm': CausalLM, } @@ -113,9 +116,10 @@ def init_speech_model(args, configs): **configs['encoder_conf']['efficient_conf'] if 'efficient_conf' in configs['encoder_conf'] else {}) - decoder = WENET_DECODER_CLASSES[decoder_type](vocab_size, - encoder.output_size(), - **configs['decoder_conf']) + decoder = None + if decoder_type is not None: + decoder = WENET_DECODER_CLASSES[decoder_type]( + vocab_size, encoder.output_size(), **configs['decoder_conf']) ctc = WENET_CTC_CLASSES[ctc_type]( vocab_size,