diff --git a/wenet/LLM/__init__.py b/wenet/LLM/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/wenet/LLM/causallm_model.py b/wenet/LLM/causallm_model.py
deleted file mode 100644
index adf37d71d..000000000
--- a/wenet/LLM/causallm_model.py
+++ /dev/null
@@ -1,207 +0,0 @@
-from typing import Dict, List, Optional, Union
-import torch
-from wenet.LLM.decoder import DecoderOnly
-from wenet.LLM.sampler import sampler
-from wenet.utils.common import IGNORE_ID, th_accuracy
-from wenet.utils.mask import make_pad_mask, subsequent_mask
-
-
-class CausalLM(torch.nn.Module):
-
-    def __init__(
-        self,
-        vocab_size: int,
-        decoder: DecoderOnly,
-        special_tokens: dict,
-        tie_word_embedding: bool = False,
-        linear_bias: bool = False,
-        ignore_id: int = IGNORE_ID,
-        lsm_weight: float = 0.0,
-        reduction: str = 'mean',
-    ) -> None:
-        super().__init__()
-        del special_tokens
-
-        self.embed = torch.nn.Embedding(vocab_size, decoder.hidden_size)
-        self.out = torch.nn.Linear(decoder.hidden_size,
-                                   vocab_size,
-                                   bias=linear_bias)
-
-        self.decoder = decoder
-        self.vocab_size = vocab_size
-        self.criterion_att = torch.nn.CrossEntropyLoss(
-            ignore_index=ignore_id,
-            label_smoothing=lsm_weight,
-            reduction=reduction,
-        )
-        self.tie_word_embedding = tie_word_embedding
-        self.ignore_id = ignore_id
-
-    @torch.jit.unused
-    def forward(
-        self,
-        batch: dict,
-        device: torch.device,
-    ) -> Dict[str, Optional[torch.Tensor]]:
-        """ Forward for training
-        """
-        text = batch['feats'].to(device)
-        target = batch['target'].to(device)
-        text_length = batch['feats_lengths'].to(device)
-
-        mask = ~make_pad_mask(text_length, max_len=text.size(1)).unsqueeze(
-            1)  # (B,1,L)
-        causal_mask = subsequent_mask(
-            mask.size(-1), device=mask.device).unsqueeze(0)  # (1,L,L)
-        att_mask = causal_mask & mask  # (B, L, L)
-
-        embeding = self.embed(text)
-        decoder_out = self.out(self.decoder(embeding,
-                                            att_mask)[0])  # (B, L, vocab_size)
-        loss = self.criterion_att(decoder_out.view(-1, self.vocab_size),
-                                  target.view(-1))
-        acc = th_accuracy(decoder_out.view(-1, self.vocab_size),
-                          target,
-                          ignore_label=self.ignore_id)
-
-        return {
-            "loss": loss,
-            "ppl": torch.exp(loss.detach()),
-            "th_accuracy": acc
-        }
-
-    def tie_or_clone_weights(self, jit_mode: bool):
-        if not self.tie_word_embedding:
-            return
-        if jit_mode:
-            self.out.weight = torch.nn.Parameter(self.embed.weight.clone())
-        else:
-            self.out.weight = self.embed.weight
-        # TODO(Mddct): whether to deal bias for other llm model
-
-    @torch.jit.unused
-    @torch.inference_mode()
-    def generate(
-        self,
-        prompts_tokens: List[List[int]],
-        device: torch.device,
-        stop_tokens: List[int],
-        dtype: torch.dtype = torch.float32,
-        output_len: int = 100,
-        temperature: Union[float, None] = 0.95,
-        top_p: float = 1.0,
-        top_k: int = 100,
-    ) -> List[List[int]]:
-        """Generates responses for given prompts using Gemma model."""
-        # If a single prompt is provided, treat it as a batch of 1.
-        batch_size = len(prompts_tokens)
-        min_prompt_len = min(len(p) for p in prompts_tokens)
-        max_prompt_len = max(len(p) for p in prompts_tokens)
-        max_seq_len = max_prompt_len + output_len
-        assert max_seq_len <= self.decoder.pos_enc.max_len
-
-        # build KV caches
-        kv_caches = []
-        for _ in range(len(self.decoder.decoders)):
-            size = (batch_size, 0, self.decoder.n_kv_head,
-                    self.decoder.head_dim)
-            k_cache = torch.zeros(size=size, dtype=dtype, device=device)
-            v_cache = torch.zeros(size=size, dtype=dtype, device=device)
-            kv_caches.append((k_cache, v_cache))
-
-        # prepare inputs
-        token_ids_tensor = torch.full((batch_size, max_seq_len),
-                                      IGNORE_ID,
-                                      dtype=torch.int64,
-                                      device=device)
-        input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
-                                            IGNORE_ID,
-                                            dtype=torch.int64,
-                                            device=device)
-        # right padding
-        for i, p in enumerate(prompts_tokens):
-            token_ids_tensor[i, :len(p)] = torch.tensor(p)
-            input_token_ids_tensor[i, :min_prompt_len] = torch.tensor(
-                p[:min_prompt_len])
-
-        prompt_mask_tensor = token_ids_tensor != IGNORE_ID
-        input_positions_tensor = torch.arange(0,
-                                              min_prompt_len,
-                                              dtype=torch.int64).to(device)
-        mask_tensor = torch.ones((1, 1, max_seq_len, max_seq_len),
-                                 dtype=torch.bool)
-        mask_tensor = torch.tril(mask_tensor).to(device)
-        curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
-        att_mask = curr_mask_tensor.squeeze(
-            1)[:, :min_prompt_len, :min_prompt_len]
-        output_positions_tensor = torch.LongTensor([min_prompt_len - 1
-                                                    ]).to(device)
-        temperatures_tensor = None if not temperature else torch.FloatTensor(
-            [temperature] * batch_size).to(device)
-        top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
-        top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
-        output_index = torch.tensor(min_prompt_len,
-                                    dtype=torch.int64).to(device)
-
-        input_token_embeding = self.embed(input_token_ids_tensor)
-        offset = torch.tensor([0] * len(prompts_tokens)).to(device)
-        input_offset = offset
-
-        stop_tokens_tensor = torch.tensor(stop_tokens, device=device)
-        # Prefill up to min_prompt_len tokens, then treat other prefill as
-        # decode and ignore output.
-        for i in range(max_seq_len - min_prompt_len):
-            decoder_out, kv_caches, = self.decoder(
-                input_token_embeding,
-                att_mask,
-                input_offset,
-                kv_caches,
-            )
-            decoder_out = self.out(decoder_out)
-            decoder_out = decoder_out.index_select(1, output_positions_tensor)
-            next_token_ids = sampler(
-                decoder_out,
-                temperatures_tensor,
-                top_ps_tensor,
-                top_ks_tensor,
-            )
-            curr_prompt_mask = prompt_mask_tensor.index_select(
-                1, output_index).squeeze(dim=1)
-            curr_token_ids = token_ids_tensor.index_select(
-                1, output_index).squeeze(dim=1)
-            output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
-                                           next_token_ids).unsqueeze(dim=1)
-            token_ids_tensor.index_copy_(1, output_index, output_token_ids)
-
-            input_token_ids_tensor = output_token_ids
-            input_token_embeding = self.embed(input_token_ids_tensor)
-
-            input_positions_tensor = output_index.unsqueeze(dim=-1)
-            curr_mask_tensor = mask_tensor.index_select(
-                2, input_positions_tensor)
-            att_mask = curr_mask_tensor.squeeze(1)[:, :output_index +
-                                                   1, :output_index + 1]
-
-            output_positions_tensor = torch.tensor(
-                0, dtype=torch.int64).to(device)
-            input_offset = offset + output_index.unsqueeze(-1)
-            output_index = output_index + 1
-
-            if all(torch.isin(next_token_ids, stop_tokens_tensor)):
-                break
-
-        token_ids = token_ids_tensor.tolist()
-        results = []
-        for i, tokens in enumerate(token_ids):
-            trimmed_output = tokens[len(prompts_tokens[i]
-                                        ):len(prompts_tokens[i]) + output_len]
-            for stop_token in stop_tokens:
-                try:
-                    eos_index = trimmed_output.index(stop_token)
-                    trimmed_output = trimmed_output[:eos_index]
-                    break
-                except Exception:
-                    continue
-            results.append(trimmed_output)
-
-        return results
diff --git a/wenet/LLM/decoder.py b/wenet/LLM/decoder.py
deleted file mode 100644
index b25ee75dd..000000000
--- a/wenet/LLM/decoder.py
+++ /dev/null
@@ -1,161 +0,0 @@
-from functools import partial
-from typing import List, Optional, Tuple, Union
-import torch
-import torch.utils.checkpoint as ckpt
-from wenet.transformer.attention import T_CACHE
-
-from wenet.transformer.encoder_layer import TransformerEncoderLayer
-from wenet.utils.class_utils import (WENET_ACTIVATION_CLASSES,
-                                     WENET_ATTENTION_CLASSES,
-                                     WENET_EMB_CLASSES, WENET_MLP_CLASSES,
-                                     WENET_NORM_CLASSES)
-from wenet.utils.common import mask_to_bias
-
-
-class DecoderOnly(torch.nn.Module):
-
-    def __init__(
-        self,
-        n_kv_head: int,
-        head_dim: int,
-        hidden_size: int,
-        attention_heads: int = 4,
-        linear_units: int = 2048,
-        num_blocks: int = 6,
-        dropout_rate: float = 0.1,
-        positional_dropout_rate: float = 0.1,
-        attention_dropout_rate: float = 0.0,
-        normalize_before: bool = True,
-        query_bias: bool = False,
-        key_bias: bool = False,
-        value_bias: bool = False,
-        mlp_bias: bool = False,
-        activation_type: str = "gelu",
-        gelu_approximate: Union[str, None] = None,
-        max_position_embeding: int = 8192,
-        mlp_type: str = 'gated',
-        layer_norm_type: str = 'rms_norm',
-        norm_eps: float = 1e-5,
-        rms_norm_offset: bool = True,
-        selfattention_layer_type: str = "rope_abs_selfattn",
-        use_sdpa: bool = False,
-        gradient_checkpointing: bool = False,
-        rope_theta: float = 10000.0,
-        rope_style: str = 'google',
-        scale_embed: bool = True,
-    ) -> None:
-        super().__init__()
-
-        assert selfattention_layer_type in ['rope_abs_selfattn']
-        self.pos_enc = WENET_EMB_CLASSES["rope_pos"](
-            hidden_size,
-            head_dim,
-            max_len=max_position_embeding,
-            dropout_rate=positional_dropout_rate,
-            rope_theta=rope_theta,
-            scale=scale_embed)
-        if activation_type == "gelu" and gelu_approximate is not None:
-            activation = WENET_ACTIVATION_CLASSES['gelu'](
-                approximate=gelu_approximate)
-        else:
-            activation = WENET_ACTIVATION_CLASSES[activation_type]()
-
-        mlp_class = WENET_MLP_CLASSES[mlp_type]
-        self.num_blocks = num_blocks
-        # TODO: support lora & refactor lora
-        self.decoders = torch.nn.ModuleList([
-            TransformerEncoderLayer(
-                hidden_size,
-                WENET_ATTENTION_CLASSES[selfattention_layer_type](
-                    attention_heads,
-                    hidden_size,
-                    attention_dropout_rate,
-                    query_bias,
-                    key_bias,
-                    value_bias,
-                    use_sdpa,
-                    n_kv_head,
-                    head_dim,
-                    style=rope_style),
-                mlp_class(hidden_size, linear_units, dropout_rate, activation,
-                          mlp_bias),
-                dropout_rate,
-                normalize_before,
-                layer_norm_type=layer_norm_type,
-                norm_eps=norm_eps,
-                rms_norm_offset=rms_norm_offset,
-            ) for _ in range(self.num_blocks)
-        ])
-        self.pre_norm = normalize_before
-        self.final_norm: Optional[torch.nn.Module] = None
-        if self.pre_norm:
-            norm_class = WENET_NORM_CLASSES[layer_norm_type]
-            if layer_norm_type == "rms_norm":
-                norm_class = partial(
-                    norm_class,
-                    add_unit_offset=rms_norm_offset,
-                )
-            self.final_norm = norm_class(hidden_size, eps=norm_eps)
-
-        self.n_kv_head = n_kv_head
-        self.head_dim = head_dim
-        self._hidden_size = hidden_size
-        self.use_sdpa = use_sdpa
-        self.gradient_checkpointing = gradient_checkpointing
-
-    def forward(
-        self,
-        input: torch.Tensor,
-        att_mask: torch.Tensor,
-        input_position: Union[int, torch.Tensor] = 0,
-        kv_caches: Optional[List[T_CACHE]] = None,
-    ) -> Tuple[torch.Tensor, Union[List[T_CACHE], None]]:
-        xs, pos_emb = self.pos_enc(input, offset=input_position)
-        if self.use_sdpa:
-            att_mask = mask_to_bias(att_mask, xs.dtype)
-
-        if self.gradient_checkpointing and self.training:
-            xs = self.forward_layers_checkpointed(xs, att_mask, pos_emb)
-        else:
-            xs, kv_caches = self.forward_layers(xs, att_mask, pos_emb,
-                                                kv_caches)
-        if self.pre_norm and self.final_norm is not None:
-            xs = self.final_norm(xs)
-        return xs, kv_caches
-
-    def forward_layers(
-        self,
-        xs: torch.Tensor,
-        att_mask: torch.Tensor,
-        pos_emb: torch.Tensor,
-        kv_caches: Optional[List[T_CACHE]] = None,
-    ) -> Tuple[torch.Tensor, Union[List[T_CACHE], None]]:
-        if self.training:
-            for (i, layer) in enumerate(self.decoders):
-                xs, _, _, _ = layer(xs, att_mask, pos_emb)
-            new_kv_caches = kv_caches
-        else:
-            assert kv_caches is not None
-            new_kv_caches = []
-            for (i, layer) in enumerate(self.decoders):
-                xs, _, new_kv_cache, _ = layer(xs,
-                                               att_mask,
-                                               pos_emb,
-                                               att_cache=(kv_caches[i][0],
-                                                          kv_caches[i][1]))
-                new_kv_caches.append(new_kv_cache)
-
-        return xs, new_kv_caches
-
-    @torch.jit.ignore(drop=True)
-    def forward_layers_checkpointed(self, xs: torch.Tensor,
-                                    att_mask: torch.Tensor,
-                                    pos_emb: torch.Tensor) -> torch.Tensor:
-        for layer in self.decoders:
-            xs, _, _, _ = ckpt.checkpoint(layer.__call__, xs, att_mask,
-                                          pos_emb)
-        return xs
-
-    @property
-    def hidden_size(self):
-        return self._hidden_size
diff --git a/wenet/LLM/sampler.py b/wenet/LLM/sampler.py
deleted file mode 100644
index 19f0d5cda..000000000
--- a/wenet/LLM/sampler.py
+++ /dev/null
@@ -1,43 +0,0 @@
-from typing import Union
-import torch
-
-
-# modified from https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L26
-@torch.no_grad()
-def sampler(
-    logits: torch.Tensor,
-    temperatures: Union[torch.Tensor, None],
-    top_ps: torch.Tensor,
-    top_ks: torch.Tensor,
-) -> torch.Tensor:
-    assert logits.size(1) == 1
-    logits = logits.squeeze(1)  # (batch_size, vocab_size)
-    if temperatures is None:
-        return torch.argmax(logits, dim=-1).squeeze(dim=-1)
-
-    # Apply temperature scaling.
-    logits.div_(temperatures.unsqueeze(dim=1))
-
-    # Calculate probabilities with softmax.
-    probs = torch.softmax(logits, dim=-1, dtype=torch.float)
-    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
-
-    # Apply top-p, top-k.
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1)
-    probs_sort = torch.where(top_ps_mask, 0, probs_sort)
-
-    top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
-    top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
-    top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1)
-    probs_sort = torch.where(top_ks_mask, 0, probs_sort)
-
-    # Re-normalization.
-    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
-    probs = torch.gather(probs_sort,
-                         dim=-1,
-                         index=torch.argsort(probs_idx, dim=-1))
-
-    next_token_ids = torch.multinomial(probs, num_samples=1,
-                                       replacement=True).squeeze(dim=-1)
-    return next_token_ids
diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py
index cbcf2f528..6ef0f4b84 100644
--- a/wenet/utils/init_model.py
+++ b/wenet/utils/init_model.py
@@ -21,8 +21,6 @@
 from wenet.paraformer.cif import Cif
 from wenet.paraformer.layers import SanmDecoder, SanmEncoder
 from wenet.paraformer.paraformer import Paraformer, Predictor
-from wenet.LLM.causallm_model import CausalLM
-from wenet.LLM.decoder import DecoderOnly
 from wenet.ssl.init_model import WENET_SSL_MODEL_CLASS
 from wenet.transducer.joint import TransducerJoint
 from wenet.transducer.predictor import (ConvPredictor, EmbeddingPredictor,
@@ -43,7 +41,6 @@
 from wenet.utils.cmvn import load_cmvn
 from wenet.utils.checkpoint import load_checkpoint, load_trained_modules
-

 WENET_ENCODER_CLASSES = {
     "transformer": TransformerEncoder,
     "conformer": ConformerEncoder,
@@ -85,7 +82,6 @@
     "k2_model": K2Model,
     "transducer": Transducer,
     'paraformer': Paraformer,
-    'causal_llm': CausalLM,
 }


@@ -172,30 +168,11 @@ def init_speech_model(args, configs):
     return model, configs

-
-def init_causal_llm(configs):
-    vocab_size = configs['output_dim']
-    assert configs['decoder'] == 'decoder_only'
-    assert configs['model'] == 'causal_lm'
-    decoder_only = DecoderOnly(**configs['decoder_conf'])
-
-    model = CausalLM(
-        vocab_size,
-        decoder_only,
-        **configs['model_conf'],
-        special_tokens=configs.get('tokenizer_conf',
-                                   {}).get('special_tokens', None),
-    )
-    return model, configs
-
-
 def init_model(args, configs):
     model_type = configs.get('model', 'asr_model')
     configs['model'] = model_type
-    if model_type == 'causal_lm':
-        model, configs = init_causal_llm(configs)
-    else:
-        model, configs = init_speech_model(args, configs)
+    model, configs = init_speech_model(args, configs)

     if hasattr(args, 'use_lora') and args.use_lora:
         inject_lora_to_model(model, configs['lora_conf'])