Skip to content

Commit

Permalink
fix minor error (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
sallyjunjun authored Aug 5, 2024
1 parent b6eb915 commit 8a7fcea
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions huggingface_model/internlm/internlm_7b/modeling_internlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,10 +494,10 @@ def forward(
cumulative_len=cu_seqlens,
max_seqlen=max_seqlen,
dropout_p=0.0,
).unsqueeze(0)
)
else:
attn_output = hf_q_k_v_without_cu_seqlens(
query_states, key_states, value_states, dropout_p=0, softmax_scale=None, causal=True,
query_states, key_states, value_states, dropout_p=0.0, softmax_scale=None, causal=True,
)

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
Expand Down Expand Up @@ -541,7 +541,7 @@ def _flash_attention_forward(
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

attn_output_unpad = varlen_flash_attn(
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
Expand All @@ -556,7 +556,7 @@ def _flash_attention_forward(

attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_wo_mask(
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)

Expand Down

0 comments on commit 8a7fcea

Please sign in to comment.