
Commit

fix flash att in generate
Mddct committed Apr 28, 2024
1 parent 79bafa3 commit a9a7f7b
Showing 1 changed file with 21 additions and 12 deletions.
wenet/LLM/causal_model.py (33 changes: 21 additions & 12 deletions)
@@ -80,8 +80,8 @@ def tie_or_clone_weights(self, jit_mode: bool):
         self.out.weight = self.embed.weight
         # TODO(Mddct): whether to deal bias for other llm model
 
-    @torch.jit.ignore(drop=True)
-    @torch.no_grad()
+    @torch.jit.unused
+    @torch.inference_mode()
     def generate(
         self,
         prompts_tokens: List[List[int]],
@@ -113,10 +113,12 @@ def generate(
         # prepare inputs
         token_ids_tensor = torch.full((batch_size, max_seq_len),
                                       IGNORE_ID,
-                                      dtype=torch.int64)
+                                      dtype=torch.int64,
+                                      device=device)
         input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
                                             IGNORE_ID,
-                                            dtype=torch.int64)
+                                            dtype=torch.int64,
+                                            device=device)
         # right padding
         for i, p in enumerate(prompts_tokens):
             token_ids_tensor[i, :len(p)] = torch.tensor(p)
@@ -131,7 +133,8 @@
                                  dtype=torch.bool)
         mask_tensor = torch.tril(mask_tensor).to(device)
         curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
-        att_mask = curr_mask_tensor.squeeze(1)
+        att_mask = curr_mask_tensor.squeeze(
+            1)[:, :min_prompt_len, :min_prompt_len]
         output_positions_tensor = torch.LongTensor([min_prompt_len - 1
                                                     ]).to(device)
         temperatures_tensor = None if not temperature else torch.FloatTensor(
@@ -142,15 +145,17 @@
                                     dtype=torch.int64).to(device)
 
         input_token_embeding = self.embed(input_token_ids_tensor)
-        input_positions_tensor = torch.tensor([0] *
-                                              len(prompts_tokens)).to(device)
+        offset = torch.tensor([0] * len(prompts_tokens)).to(device)
+        input_offset = offset
         # Prefill up to min_prompt_len tokens, then treat other prefill as
         # decode and ignore output.
         for i in range(max_seq_len - min_prompt_len):
-            decoder_out, kv_caches, = self.decoder(input_token_embeding,
-                                                   att_mask,
-                                                   input_positions_tensor,
-                                                   kv_caches)
+            decoder_out, kv_caches, = self.decoder(
+                input_token_embeding,
+                att_mask,
+                input_offset,
+                kv_caches,
+            )
             decoder_out = self.out(decoder_out)
             decoder_out = decoder_out.index_select(1, output_positions_tensor)
             next_token_ids = sampler(
@@ -169,12 +174,16 @@
 
             input_token_ids_tensor = output_token_ids
             input_token_embeding = self.embed(input_token_ids_tensor)
+
             input_positions_tensor = output_index.unsqueeze(dim=-1)
             curr_mask_tensor = mask_tensor.index_select(
                 2, input_positions_tensor)
-            att_mask = curr_mask_tensor.squeeze(1)
+            att_mask = curr_mask_tensor.squeeze(1)[:, :output_index +
+                                                   1, :output_index + 1]
+
             output_positions_tensor = torch.tensor(
                 0, dtype=torch.int64).to(device)
+            input_offset = offset + output_index.unsqueeze(-1)
             output_index = output_index + 1
 
         token_ids = token_ids_tensor.tolist()
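Note on the change: the commit slices the causal mask so that its last two dimensions match the actual query and key lengths at each step (min_prompt_len rows during prefill, a single row against output_index + 1 cached keys during decode), allocates the token-id tensors directly on the target device, and passes a per-step input_offset to the decoder instead of the old input_positions_tensor. The short sketch below is not part of the commit; batch size, sequence lengths, and variable names are illustrative assumptions, shown only to make the resulting mask shapes concrete.

import torch

# Illustrative sketch only (not wenet code); sizes and names are made up.
batch_size, max_seq_len, min_prompt_len = 2, 8, 3
device = torch.device("cpu")

# Full lower-triangular (causal) mask over the maximum sequence length,
# mirroring how generate() builds mask_tensor.
mask = torch.tril(
    torch.ones((batch_size, 1, max_seq_len, max_seq_len),
               dtype=torch.bool)).to(device)

# Prefill: min_prompt_len query rows attend to min_prompt_len keys, so the
# mask handed to the decoder is sliced down to (batch, q_len, kv_len).
prefill_positions = torch.arange(0, min_prompt_len, device=device)
att_mask = mask.index_select(2, prefill_positions).squeeze(
    1)[:, :min_prompt_len, :min_prompt_len]
print(att_mask.shape)  # torch.Size([2, 3, 3])

# One decode step at position t: a single query row attends to the
# t + 1 keys held in the KV cache, so the slice becomes (batch, 1, t + 1).
t = min_prompt_len
step_position = torch.tensor([t], device=device)
att_mask = mask.index_select(2, step_position).squeeze(1)[:, :t + 1, :t + 1]
print(att_mask.shape)  # torch.Size([2, 1, 4])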
