Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Queued PR] Port fixes from 0.7.2b #56

Draft
wants to merge 4 commits into
base: xinyazhang/meff-nonsquare_causal
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/aotriton_flash.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def cast_dtype(dtype):

def mk_aotensor(q, if_empty_then_like=None):
rank = len(q.shape) if q is not None else len(if_empty_then_like.shape)
if q is not None and q.numel() == 1:
if q is not None and len(q.shape) == 1 and q.numel() == 1:
if PASS_PHILOX_AS_TENSOR:
return T0(q.data_ptr(), cast_dtype(q.dtype))
else:
Expand Down
22 changes: 22 additions & 0 deletions test/test_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,28 @@ def test_op_bwd_with_matrix_bias(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, sm_
'''
_do_test_op_bwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type)

def test_large_bf16_nan_values():
    """Check that `attention` produces no NaNs for large-magnitude bf16 inputs.

    q, k and v are (1, 1, 1, 16) bfloat16 CUDA tensors filled with 133120.0.
    The same inputs are also run through PyTorch's math SDPA backend; that
    result is only printed for manual inspection and is never asserted on.
    """
    shape = (1, 1, 1, 16)
    tensor_opts = {"dtype": torch.bfloat16, "device": "cuda"}
    q = torch.full(shape, 133120.0, **tensor_opts)
    k = torch.full(shape, 133120.0, **tensor_opts)
    v = torch.full(shape, 133120.0, **tensor_opts)
    b = None

    from torch.nn.functional import scaled_dot_product_attention
    from torch.nn.attention import sdpa_kernel, SDPBackend

    # Reference output from the math backend: printed, not compared.
    with sdpa_kernel(SDPBackend.MATH):
        out = scaled_dot_product_attention(q, k, v)
    print(out)

    causal = False
    sm_scale = 0.125
    dropout_p = 0
    ext = AttentionExtraArgs(
        return_encoded_softmax=causal,
        autotune=False,
        return_autotune=False,
    )
    # Only the first returned value (the attention output) is inspected.
    tri_out, _encoded_softmax, _ = attention(
        q, k, v, b, causal, sm_scale, dropout_p, ext
    )

    print(tri_out)
    has_nan = torch.isnan(tri_out).any()
    assert not has_nan, "Output should not contain NaNs!"

def main_npz():
SKIP_DK_DV = False
SKIP_DQ = False
Expand Down
2 changes: 1 addition & 1 deletion tritonsrc/fwd_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def attn_fwd(
else:
alibi_slope = None

off_zh = batch_index * num_head_q + off_h_q
off_zh = off_z * num_head_q + off_h_q
if ENABLE_DROPOUT:
batch_philox_offset = philox_offset_base + off_zh * max_seqlen_q * max_seqlen_k
else:
Expand Down
14 changes: 9 additions & 5 deletions tritonsrc/fwd_kernel_inner.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,18 +88,22 @@ def attn_fwd_inner(
global_n_positions)
qk += alibi_block * bias_scale

# softmax
m_ij = tl.maximum(m_i, qk_scale * tl.max(qk, 1))
# FIXME: when sm_scale = 0.0 and MASK_STEPS/CAUSAL = True
# qk * qk_scale = nan
p = tl.math.exp2(qk * qk_scale - m_ij[:, None])
# This softmax approach has numerical errors for large inputs:
# See: https://github.com/ROCm/aotriton/issues/54
# m_ij = tl.maximum(m_i, qk_scale * tl.max(qk, 1))
# p = tl.math.exp2(qk * qk_scale - m_ij[:, None])
m_ij = tl.maximum(m_i, tl.max(qk, 1))
p = tl.math.exp2(qk_scale * (qk - m_ij[:, None]))

# When sm_scale = 0.0 and MASK_STEPS/CAUSAL = True
# qk * qk_scale = -inf * 0.0 = nan
if MASK_STEPS or CAUSAL:
if qk_scale == 0.0:
p = tl.where(libdevice.isnan(p), 0.0, p)

# CAVEAT: Must update l_ij before applying dropout
l_ij = tl.sum(p, 1)
m_ij = m_ij * qk_scale
if ENABLE_DROPOUT:
philox_offset = batch_philox_offset + start_m * BLOCK_M * max_seqlen_k + start_n
keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, max_seqlen_k)
Expand Down
22 changes: 22 additions & 0 deletions tritonsrc/test_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,28 @@ def test_gqa(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropo
bias_type = None
_do_test_op_bwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type)

def test_large_bf16_nan_values():
    """Check that `attention` produces no NaNs for large-magnitude bf16 inputs.

    q, k and v are (1, 1, 1, 16) bfloat16 CUDA tensors filled with 133120.0.
    The same inputs are also run through PyTorch's math SDPA backend; that
    result is only printed for manual inspection and is never asserted on.
    """
    shape = (1, 1, 1, 16)
    tensor_opts = {"dtype": torch.bfloat16, "device": "cuda"}
    q = torch.full(shape, 133120.0, **tensor_opts)
    k = torch.full(shape, 133120.0, **tensor_opts)
    v = torch.full(shape, 133120.0, **tensor_opts)
    b = None

    from torch.nn.functional import scaled_dot_product_attention
    from torch.nn.attention import sdpa_kernel, SDPBackend

    # Reference output from the math backend: printed, not compared.
    with sdpa_kernel(SDPBackend.MATH):
        out = scaled_dot_product_attention(q, k, v)
    print(out)

    causal = False
    sm_scale = 0.125
    dropout_p = 0
    ext = AttentionExtraArgs(
        return_encoded_softmax=causal,
        autotune=False,
        return_autotune=False,
    )
    # Only the first returned value (the attention output) is inspected.
    tri_out, _encoded_softmax, _ = attention(
        q, k, v, b, causal, sm_scale, dropout_p, ext
    )

    print(tri_out)
    has_nan = torch.isnan(tri_out).any()
    assert not has_nan, "Output should not contain NaNs!"

def main_npz():
SKIP_DK_DV = False
SKIP_DQ = False
Expand Down
2 changes: 1 addition & 1 deletion v2src/flash/attn_bwd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ attn_bwd_compact_varlen(T4 q, // 1 x num_heads x total_q x head_size,
v,
cu_seqlens_q,
cu_seqlens_k,
cu_seqlens_q.size(0),
cu_seqlens_q.size(0) - 1,
max_seqlen_q,
max_seqlen_k,
b,
Expand Down