
Commit

fix sequence_parallel
micmelesse committed Oct 28, 2024
1 parent 6cd1bcd commit 5d03d58
Showing 3 changed files with 5 additions and 9 deletions.
7 changes: 3 additions & 4 deletions flash_attn/flash_attn_triton_amd/bwd_prefill.py

@@ -450,7 +450,8 @@ def attention_prefill_backward_triton_impl(
     cu_seqlens_k,
     max_seqlen_q: int,
     max_seqlen_k: int,
-    use_exp2: bool
+    use_exp2: bool,
+    sequence_parallel = False,
 ):
     if DEBUG:
         print()
@@ -473,6 +474,7 @@ def attention_prefill_backward_triton_impl(
         print("max_seqlen_q:", max_seqlen_q)
         print("max_seqlen_k:", max_seqlen_k)
         print("use_exp2:", use_exp2)
+        print("sequence_parallel:", sequence_parallel)
 
     # make contigious
     q = q.contiguous()
@@ -502,9 +504,6 @@ def attention_prefill_backward_triton_impl(
     num_stages = 1
     waves_per_eu = 1
 
-    # configs
-    sequence_parallel = False
-
     # divide up the problem
     num_blocks_m = triton.cdiv(max_seqlen_q, BLOCK_M)
     num_blocks_n = triton.cdiv(max_seqlen_k, BLOCK_N)
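The hunk above stops hard-coding sequence_parallel inside the backward implementation and takes it as a keyword argument instead. Below is a minimal illustrative sketch, not code from this commit, of how such a flag typically changes the backward launch grid in a Triton flash-attention wrapper: backward_grid, the BLOCK sizes, and the exact grid shape are assumptions; only triton.cdiv and the block-count computation mirror the "divide up the problem" lines touched above.

import triton

def backward_grid(batch, nheads_q, max_seqlen_q, max_seqlen_k,
                  BLOCK_M=64, BLOCK_N=64, sequence_parallel=False):
    # Number of key/value blocks along the sequence, as computed in the diff.
    num_blocks_n = triton.cdiv(max_seqlen_k, BLOCK_N)
    if sequence_parallel:
        # One program per key/value block and (batch, head) pair: blocks along
        # the sequence run concurrently and write partial dq results that are
        # summed in a separate reduction step.
        return (batch * nheads_q, num_blocks_n)
    # Serial variant: one program per (batch, head) loops over all key/value
    # blocks itself, so no dq reduction pass is needed.
    return (batch * nheads_q, 1)

# Example: batch 4, 16 heads, seqlen 1024 with 64x64 blocks gives a
# (64, 16) grid when sequence_parallel=True and (64, 1) otherwise.
grid = backward_grid(4, 16, 1024, 1024, sequence_parallel=True)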
6 changes: 2 additions & 4 deletions flash_attn/flash_attn_triton_amd/interface_fa.py

@@ -52,9 +52,8 @@ def fwd(q,
     metadata.max_seqlens_q = q.shape[1]
     metadata.max_seqlens_k = k.shape[1]
     metadata.layout = "bshd"
-    metadata.use_exp2 = False
     if return_softmax:
-        metadata.return_encoded_softmax = True
+        metadata.return_scores = True
 
     batch, nheads_q, nheads_k, head_size, _, _ = get_shape_from_layout(q, k, metadata.layout)
 
@@ -279,9 +278,8 @@ def varlen_fwd(
 
     # Setup metadata
     metadata = MetaData(sm_scale=softmax_scale)
-    metadata.use_exp2 = False
     if return_softmax:
-        metadata.return_encoded_softmax = True
+        metadata.return_scores = True
     metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) # set layout to "thd" and other metdata
 
     # get shapes
1 change: 0 additions & 1 deletion flash_attn/flash_attn_triton_amd/utils.py

@@ -26,7 +26,6 @@ class MetaData():
     dropout_p, return_scores= 0.0, False
     # NOTE: scale sm_scale by log_2(e) and use 2^x in the loop as we do not have native e^x support in HW.
     use_exp2 = False
-    return_encoded_softmax = False
 
 
     def __repr__(self) -> str:
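With return_encoded_softmax removed from MetaData, requesting the attention scores goes through the pre-existing return_scores flag, as the fwd() and varlen_fwd() hunks above show. A minimal usage sketch follows; make_fwd_metadata is a hypothetical helper added here for illustration, while the MetaData(sm_scale=...) call and the return_scores field are taken directly from the diffs.

from flash_attn.flash_attn_triton_amd.utils import MetaData

def make_fwd_metadata(softmax_scale, return_softmax):
    # Mirrors the metadata setup now used in fwd()/varlen_fwd().
    metadata = MetaData(sm_scale=softmax_scale)
    if return_softmax:
        metadata.return_scores = True  # previously: metadata.return_encoded_softmax = True
    return metadata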
