diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py
index b37308be4..6ffa1d987 100644
--- a/flash_attn/flash_attn_triton_amd/fwd_decode.py
+++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py
@@ -3,6 +3,153 @@
 import triton.language as tl
 from .utils import _strides, get_padded_headsize
 
+@triton.jit
+def rotary_kernel_splitk(
+    # Dimensions of X
+    X,              # tensor being rotated. Has shape (batch (z), seqlen (s), group (g), head (h), head_dim (d))
+    seqlen_x,       # seqlen of the x dim
+    head_dim,
+    rotary_dim,     # size of the embedding space being rotated
+
+    # COS/SIN and offsetting into it
+    COS,            # tensor of shape (seqlen (m), rotary_dim // 2)
+    SIN,            # tensor of shape (seqlen (m), rotary_dim // 2)
+    SEQLEN_OFFSET,  # offset into COS and SIN so the correct rotation is applied
+    SEQLEN_OFFSET_IS_TENSOR: tl.constexpr,  # if seqlen_offset is a tensor it has shape (num_batch, )
+
+    # PID offsets
+    batch_pid: tl.constexpr,  # pid for batch
+    start_m: tl.constexpr,    # the token idx the current M_BLOCK starts at
+    group_pid: tl.constexpr,  # pid for group
+    head_pid: tl.constexpr,   # pid to access head
+
+    # Strides
+    stride_batch: tl.constexpr,
+    stride_m: tl.constexpr,
+    stride_group: tl.constexpr,
+    stride_head: tl.constexpr,
+    stride_headdim: tl.constexpr,
+
+    # Misc
+    INTERLEAVED: tl.constexpr,
+    CONJUGATE: tl.constexpr,
+    TRANSPOSE: tl.constexpr,
+
+    # Meta-parameters
+    BLOCK_M: tl.constexpr,  # block size for chunks of tokens (# of tokens processed simultaneously)
+    BLOCK_K: tl.constexpr,  # block size for chunks of headdim (# of dimensions processed)
+):
+    """
+    Apply rotary embedding to a (BLOCK_M, BLOCK_K) block of X and return it.
+
+    Note:
+        - when rotating K in the split-k kernel, let BLOCK_M = BLOCK_N and start_m = start_n
+    """
+    range_m = start_m + tl.arange(0, BLOCK_M)
+    range_d = tl.arange(0, BLOCK_K)
+
+    # pointer to the (batch, group, head) slice of x
+    x_ptr = X + (batch_pid * stride_batch) + (group_pid * stride_group) + (head_pid * stride_head)
+
+    ro_dim_half = rotary_dim // 2  # length of the cos/sin rows
+
+    if SEQLEN_OFFSET_IS_TENSOR:
+        seqlen_offset = tl.load(SEQLEN_OFFSET + batch_pid)  # per-batch offset
+    else:
+        seqlen_offset = SEQLEN_OFFSET  # scalar offset
+
+    # load the full x block up front; dims beyond rotary_dim pass through untouched
+    x_range = range_m[:, None] * stride_m + range_d[None, :] * stride_headdim
+    x_mask = (range_m < seqlen_x)[:, None] & (range_d < head_dim)[None, :]
+    x = tl.load(x_ptr + x_range, mask=x_mask)
+
+    if not INTERLEAVED:
+        range_d_half_duplicate = range_d % (rotary_dim // 2)
+
+        x0_range = range_m[:, None] * stride_m + range_d_half_duplicate[None, :] * stride_headdim                  # BLOCK_M x 1st half of headdim (fast to load)
+        x1_range = range_m[:, None] * stride_m + (range_d_half_duplicate[None, :] + ro_dim_half) * stride_headdim  # BLOCK_M x 2nd half of headdim (fast to load)
+
+        x0_mask = (range_m < seqlen_x)[:, None] & (range_d_half_duplicate < rotary_dim)[None, :]                # mask for the first half
+        x1_mask = (range_m < seqlen_x)[:, None] & (range_d_half_duplicate + ro_dim_half < rotary_dim)[None, :]  # mask for the second half
+
+        # offset cos and sin by the current m position range plus the seqlen offset
+        range_m_cos_sin = range_m + seqlen_offset
+        COS = COS + (range_m_cos_sin[:, None] * ro_dim_half + range_d_half_duplicate[None, :])
+        SIN = SIN + (range_m_cos_sin[:, None] * ro_dim_half + range_d_half_duplicate[None, :])
+        cos = tl.load(
+            COS, mask=(range_m[:, None] < seqlen_x) & (range_d_half_duplicate[None, :] < ro_dim_half), other=1.0
+        ).to(tl.float32)
+        sin = tl.load(
+            SIN, mask=(range_m[:, None] < seqlen_x) & (range_d_half_duplicate[None, :] < ro_dim_half), other=0.0
+        ).to(tl.float32)
+        if CONJUGATE:
+            sin = -sin
+
+        x0 = tl.load(x_ptr + x0_range, mask=x0_mask).to(tl.float32)
+        x1 = tl.load(x_ptr + x1_range, mask=x1_mask).to(tl.float32)
+
+        # Rotate corresponding elements in each half
+        o0 = x0 * cos - x1 * sin
+        o1 = x0 * sin + x1 * cos
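+        # (o0, o1) implements the 2D rotation [cos, -sin; sin, cos] applied to
+        # the pair (x[..., i], x[..., i + ro_dim_half]), i.e. the non-interleaved
+        # (GPT-NeoX style) rotary layout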
+        out = tl.where(range_d[None, :] // ro_dim_half == 0, o0, o1)
+
+        # for all dims not in rotary_dim, leave untouched
+        out = tl.where(range_d[None, :] < rotary_dim, out, x)
+
+        # transpose the rotated vector
+        if TRANSPOSE:
+            out = tl.trans(out)
+
+        return out
+
+    else:
+        # Interleaved is slow due to the strided x1 load
+        range_d_swap = range_d + ((range_d + 1) % 2) * 2 - 1  # 1, 0, 3, 2, 5, 4, ...
+
+        # X Range
+        x0_range = range_m[:, None] * stride_m + range_d[None, :] * stride_headdim       # 0, 1, 2, 3, 4, 5, ... (fast to load)
+        x1_range = range_m[:, None] * stride_m + range_d_swap[None, :] * stride_headdim  # 1, 0, 3, 2, 5, 4, ... (slow to load)
+
+        # X Masks
+        x0_mask = (range_m < seqlen_x)[:, None] & (range_d < rotary_dim)[None, :]       # mask for the in-order lanes
+        x1_mask = (range_m < seqlen_x)[:, None] & (range_d_swap < rotary_dim)[None, :]  # mask for the swapped lanes
+
+        # Load COS/SIN
+        range_d_repeat = tl.arange(0, BLOCK_K) // 2  # 0, 0, 1, 1, 2, 2, ...
+
+        range_m_cos_sin = range_m + seqlen_offset
+        COS = COS + (range_m_cos_sin[:, None] * ro_dim_half + range_d_repeat[None, :])
+        SIN = SIN + (range_m_cos_sin[:, None] * ro_dim_half + range_d_repeat[None, :])
+        cos = tl.load(
+            COS,
+            mask=(range_m[:, None] < seqlen_x) & (range_d_repeat[None, :] < ro_dim_half),
+            other=1.0,
+        ).to(tl.float32)
+        sin = tl.load(
+            SIN,
+            mask=(range_m[:, None] < seqlen_x) & (range_d_repeat[None, :] < ro_dim_half),
+            other=0.0,
+        ).to(tl.float32)
+        if CONJUGATE:
+            sin = -sin
+
+        x0 = tl.load(x_ptr + x0_range, mask=x0_mask).to(tl.float32)
+        x1 = tl.load(x_ptr + x1_range, mask=x1_mask).to(tl.float32)
+
+        x0_cos = x0 * cos
+        x1_sin = x1 * sin
+
+        out = tl.where(range_d[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
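+        # even lanes produce x0*cos - x1*sin and odd lanes x0*cos + x1*sin,
+        # rotating each (even, odd) pair in place (GPT-J style rotary layout)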
else: start_idx = N_CTX_K - N_CTX_NEW # Copy new Keys for i in range(0, N_CTX_NEW, BLOCK_N): - # Load from K_new - k_new_block = tl.load( - knew_base + - tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kn_d + - (tl.arange(0, BLOCK_N) + i)[None, :] * stride_kn_n, - mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) & - (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL), - other=0 - ) + + # Load from K_new and apply rotary to k + if USE_ROTARY: + k_new_block = rotary_kernel_splitk( + X=K_new, + seqlen_x=N_CTX_NEW, + head_dim=BLOCK_DMODEL, + rotary_dim=Rotary_dim, + + COS=Rotary_cos, + SIN=Rotary_sin, + SEQLEN_OFFSET=Cache_seqlens, + SEQLEN_OFFSET_IS_TENSOR=IS_SEQLEN_OFFSETS_TENSOR, + + batch_pid=off_z, + start_m=i, # current block of tokens in new_k + group_pid=off_g_q, + head_pid=k_head_idx, + + stride_batch=stride_kn_z, # batch_strides if not varlen else 0 + stride_m=stride_kn_n, + stride_group=stride_kn_g, + stride_head=stride_kn_h, + stride_headdim=stride_kn_d, + + INTERLEAVED=Rotary_interleaved, + CONJUGATE=Rotary_conjugate, + TRANSPOSE=True, + + BLOCK_M=BLOCK_N, + BLOCK_K=BLOCK_DMODEL + ) + else: + # Load from K_new + k_new_block = tl.load( + knew_base + + tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kn_d + + (tl.arange(0, BLOCK_N) + i)[None, :] * stride_kn_n, + mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) & + (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL), + other=0 + ) # Store to K tl.store( @@ -213,9 +403,41 @@ def _fwd_kernel_splitK( # 2^x instead of exp in the loop because CSE and LICM # don't work as expected with `exp` in the loop qk_scale = sm_scale * 1.44269504 - # load q: it will stay in SRAM throughout - q = tl.load( # noqa: F821 - tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0, )) + # load q: decide if should apply rotary after load + if USE_ROTARY: + q = rotary_kernel_splitk( + X=Q, + seqlen_x=N_CTX_Q, + head_dim=BLOCK_DMODEL, + rotary_dim=Rotary_dim, + + COS=Rotary_cos, + SIN=Rotary_sin, + SEQLEN_OFFSET=Cache_seqlens, + SEQLEN_OFFSET_IS_TENSOR=IS_SEQLEN_OFFSETS_TENSOR, + + batch_pid=off_z, + start_m=start_m*BLOCK_M, + group_pid=off_g_q, + head_pid=off_h_q, + + stride_batch= (stride_qz if not IS_VARLEN else 0), # batch_strides if not varlen else 0 + stride_m=stride_qm, + stride_group=stride_qg, + stride_head=stride_qh, + stride_headdim=stride_qd, + + INTERLEAVED=Rotary_interleaved, + CONJUGATE=Rotary_conjugate, + TRANSPOSE=False, + + BLOCK_M=BLOCK_M, + BLOCK_K=BLOCK_DMODEL + ) + else: + # load q: it will stay in SRAM throughout + q = tl.load( # noqa: F821 + tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0, )) q = (q * qk_scale).to(q.dtype) if PADDED_HEAD: q = tl.where(d_mask[None, :], q, 0.0) @@ -339,8 +561,8 @@ def load_k_v_group( V_block_ptr = tl.advance(V_block_ptr, (0, ACTUAL_BLOCK_DMODEL * group_id)) # -- load k, v -- - k = tl.load(K_block_ptr, boundary_check=(1, ) if BOUNDS_CHECKS_N else ()) - v = tl.load(V_block_ptr, boundary_check=(0, ) if BOUNDS_CHECKS_N else ()) + k = tl.load(K_block_ptr, boundary_check=(1, ) if BOUNDS_CHECKS_N else ()).to(tl.float32) + v = tl.load(V_block_ptr, boundary_check=(0, ) if BOUNDS_CHECKS_N else ()).to(tl.float32) return k, v @@ -540,7 +762,11 @@ def get_split_k(B: int, G: int, H: int, Mk: int) -> int: split_k = max(split_k, 1) return split_k -def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes, layout, cache_seqlens, cache_batch_idx, new_kv, k_new, v_new): +def attention_decode_forward_triton_impl(q, k, v, + sm_scale, causal, alibi_slopes, + layout, cache_seqlens, 
+        else:
+            # Load from K_new
+            k_new_block = tl.load(
+                knew_base +
+                tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kn_d +
+                (tl.arange(0, BLOCK_N) + i)[None, :] * stride_kn_n,
+                mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) &
+                (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL),
+                other=0
+            )
 
         # Store to K
         tl.store(
@@ -213,9 +403,41 @@ def _fwd_kernel_splitK(
     # 2^x instead of exp in the loop because CSE and LICM
     # don't work as expected with `exp` in the loop
     qk_scale = sm_scale * 1.44269504
-    # load q: it will stay in SRAM throughout
-    q = tl.load( # noqa: F821
-        tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0, ))
+    # load q, applying rotary embedding on the fly if requested
+    if USE_ROTARY:
+        q = rotary_kernel_splitk(
+            X=Q,
+            seqlen_x=N_CTX_Q,
+            head_dim=BLOCK_DMODEL,
+            rotary_dim=Rotary_dim,
+
+            COS=Rotary_cos,
+            SIN=Rotary_sin,
+            SEQLEN_OFFSET=Cache_seqlens,
+            SEQLEN_OFFSET_IS_TENSOR=IS_SEQLEN_OFFSETS_TENSOR,
+
+            batch_pid=off_z,
+            start_m=start_m * BLOCK_M,
+            group_pid=off_g_q,
+            head_pid=off_h_q,
+
+            stride_batch=(stride_qz if not IS_VARLEN else 0),  # batch strides if not varlen else 0
+            stride_m=stride_qm,
+            stride_group=stride_qg,
+            stride_head=stride_qh,
+            stride_headdim=stride_qd,
+
+            INTERLEAVED=Rotary_interleaved,
+            CONJUGATE=Rotary_conjugate,
+            TRANSPOSE=False,
+
+            BLOCK_M=BLOCK_M,
+            BLOCK_K=BLOCK_DMODEL,
+        )
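+        # q is rotated with the same cache_seqlens offset as k_new so that
+        # query positions line up with their absolute positions in the kv cache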
+    else:
+        # load q: it will stay in SRAM throughout
+        q = tl.load( # noqa: F821
+            tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0, ))
     q = (q * qk_scale).to(q.dtype)
     if PADDED_HEAD:
         q = tl.where(d_mask[None, :], q, 0.0)
@@ -339,8 +561,8 @@ def load_k_v_group(
     V_block_ptr = tl.advance(V_block_ptr, (0, ACTUAL_BLOCK_DMODEL * group_id))
 
     # -- load k, v --
-    k = tl.load(K_block_ptr, boundary_check=(1, ) if BOUNDS_CHECKS_N else ())
-    v = tl.load(V_block_ptr, boundary_check=(0, ) if BOUNDS_CHECKS_N else ())
+    k = tl.load(K_block_ptr, boundary_check=(1, ) if BOUNDS_CHECKS_N else ()).to(tl.float32)
+    v = tl.load(V_block_ptr, boundary_check=(0, ) if BOUNDS_CHECKS_N else ()).to(tl.float32)
 
     return k, v
 
@@ -540,7 +762,11 @@ def get_split_k(B: int, G: int, H: int, Mk: int) -> int:
     split_k = max(split_k, 1)
     return split_k
 
-def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes, layout, cache_seqlens, cache_batch_idx, new_kv, k_new, v_new):
+def attention_decode_forward_triton_impl(q, k, v,
+                                         sm_scale, causal, alibi_slopes,
+                                         layout, cache_seqlens, cache_batch_idx,
+                                         new_kv, k_new, v_new,
+                                         rotary_cos, rotary_sin, rotary_dim, rotary_interleaved, rotary_conjugate):
     # kernel config
     BLOCK_M = 16
     BLOCK_N = 64
@@ -620,6 +846,13 @@ def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes
         Cache_seqlens=cache_seqlens,
         Cache_batch_idx=cache_batch_idx,
         Alibi_slopes=alibi_slopes,
+        Rotary_cos=rotary_cos,
+        Rotary_sin=rotary_sin,
+        Rotary_dim=rotary_dim,
+        Rotary_interleaved=rotary_interleaved,
+        Rotary_conjugate=rotary_conjugate,
+        IS_SEQLEN_OFFSETS_TENSOR=isinstance(cache_seqlens, torch.Tensor),
+        IS_VARLEN=False,
         **_strides(q, "qz", "qm", "qg", "qh", "qd"),
         **_strides(k, "kz", "kn", "kg", "kh", "kd"),
         **_strides(v, "vz", "vn", "vg", "vh", "vd"),
@@ -641,12 +874,13 @@ def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes
         BLOCK_DMODEL=dim_padded,
         ACTUAL_BLOCK_DMODEL=dim_k,
         BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_cache_seqlens,
-        USE_CACHE_SEQLENs=use_cache_seqlens,
+        USE_CACHE_SEQLENS=use_cache_seqlens,
         USE_CACHE_BATCH_IDX=cache_batch_idx is not None,
         NEW_KV=new_kv,
         IS_GQA=is_gqa,
         IS_CAUSAL=causal,
         USE_ALIBI=False if alibi_slopes is None else True,
+        USE_ROTARY=rotary_cos is not None and rotary_sin is not None,
         num_warps=num_warps,
         num_stages=1,
     )
diff --git a/flash_attn/flash_attn_triton_amd/interface_fa.py b/flash_attn/flash_attn_triton_amd/interface_fa.py
index 59a306d5d..35e9aaa23 100644
--- a/flash_attn/flash_attn_triton_amd/interface_fa.py
+++ b/flash_attn/flash_attn_triton_amd/interface_fa.py
@@ -6,8 +6,11 @@
 from .fwd_ref import attention_forward_pytorch_ref_impl
 from .bwd_ref import attention_backward_pytorch_ref_impl
 from .utils import MetaData, get_shape_from_layout, DEBUG
+from einops import rearrange
+from flash_attn.layers.rotary import apply_rotary_emb
 
 USE_REF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_REF', '0').lower() in ('1', 'true', 'yes')
+ENABLE_FUSED_ROTARY = os.environ.get('FLASH_ATTENTION_TRITON_AMD_ENABLE_FUSED_ROTARY', '0').lower() in ('1', 'true', 'yes')
 
 def fwd(q,
         k,
@@ -516,6 +519,49 @@ def fwd_kvcache(
         batch, _ , nheads_q, _= q.shape
         metadata.need_alibi(alibi_slopes, batch, nheads_q)
 
+    # decide whether rotary embedding should be applied
+    apply_rotary = torch.is_tensor(rotary_cos) and torch.is_tensor(rotary_sin)
+    if apply_rotary:
+        _, dim = rotary_cos.shape
+        rotary_dim = dim * 2
+        metadata.need_rotary(rotary_dim, rotary_sin, rotary_cos, rotary_interleaved)
+
+    if not ENABLE_FUSED_ROTARY:
+        # non-fused rotary path: rotate q and k_new up front, before the decode kernel
+        if apply_rotary:
+            if metadata.causal:  # NOTE: add `or metadata.local` once local attention is supported
+                q_ro = apply_rotary_emb(
+                    q,
+                    metadata.rotary_cos,
+                    metadata.rotary_sin,
+                    seqlen_offsets=metadata.cache_seqlens,
+                    interleaved=metadata.rotary_interleaved,
+                )
+            else:
+                q_ro = rearrange(
+                    apply_rotary_emb(
+                        rearrange(q, "b s h d -> b 1 (s h) d"),
+                        metadata.rotary_cos,
+                        metadata.rotary_sin,
+                        seqlen_offsets=metadata.cache_seqlens,
+                        interleaved=metadata.rotary_interleaved,
+                    ),
+                    "b 1 (s h) d -> b s h d",
+                    s=metadata.max_seqlens_q,
+                )
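+            # without causal masking every query token shares the same rotary
+            # position (the cache seqlen offset), hence the (s h) flattening above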
yet") @@ -1924,7 +1922,7 @@ def test_flash_attn_kvcache( device = "cuda" # set seed torch.random.manual_seed(0) - batch_size = 2 + batch_size = 4 batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 nheads = 6 # rotary_dim must be a multiple of 16, and must be <= d