feat: added ability to switch between non-fused (passing) and fused (failing) rotary
alexkranias-amd committed Nov 14, 2024
1 parent e02ceee commit 6ead05a
Showing 4 changed files with 300 additions and 52 deletions.
272 changes: 253 additions & 19 deletions flash_attn/flash_attn_triton_amd/fwd_decode.py
@@ -3,6 +3,153 @@
import triton.language as tl
from .utils import _strides, get_padded_headsize

@triton.jit
def rotary_kernel_splitk(
# Dimensions of X
X, # tensor being rotated. Has shape (batch (z), seqlen (s), group (g), head (h), head_dim (d))
seqlen_x, # seqlen of the x dim. shape is (batch (z), )
head_dim,
rotary_dim, # size of embedding space we end up rotating

# COS/SIN and Offsetting Into It
COS, # tensor of shape (seqlen (m), ro_dim // 2)
SIN, # tensor of shape (seqlen (m), ro_dim // 2)
SEQLEN_OFFSET, # we use this as an offset into COS and SIN to apply the correct rotation
SEQLEN_OFFSET_IS_TENSOR: tl.constexpr, # if seqlen_offset is a tensor it has shape (num_batch, )

# PID Offsets
batch_pid: tl.constexpr, # pid for batch
start_m: tl.constexpr, # the token idx the current M_BLOCK starts at.
group_pid: tl.constexpr, # pid for group
head_pid: tl.constexpr, # pid to access head

# Strides
stride_batch: tl.constexpr,
stride_m: tl.constexpr,
stride_group: tl.constexpr,
stride_head: tl.constexpr,
stride_headdim: tl.constexpr,

# Misc
INTERLEAVED: tl.constexpr,
CONJUGATE: tl.constexpr,
TRANSPOSE: tl.constexpr,

# Meta-parameters
BLOCK_M: tl.constexpr, # block size to access chunks of tokens (# of tokens simultaneously)
BLOCK_K: tl.constexpr, # block size to access chunks of headdim (# of dimensions processed)
):
"""
Note:
- for K in splitk let BLOCK_M = BLOCK_N, and start_m=start_n
"""
# pdb.set_trace()
range_m = start_m + tl.arange(0, BLOCK_M)
range_d = tl.arange(0, BLOCK_K)

x_ptr = X + (batch_pid * stride_batch) + (group_pid * stride_group) + (head_pid * stride_head) # pointer to x block

ro_dim_half = rotary_dim // 2 # length of cos/sin

if SEQLEN_OFFSET_IS_TENSOR:
seqlen_offset = tl.load(SEQLEN_OFFSET + batch_pid) # a tensor
else:
seqlen_offset = SEQLEN_OFFSET # an int

# load full x (puts values in cache)
x_range = range_m[:, None]*stride_m + range_d[None, :]*stride_headdim
x_mask = (range_m < seqlen_x)[:, None] & (range_d < head_dim)[None, :]
x = tl.load(x_ptr + x_range, mask=x_mask)


if not INTERLEAVED:
range_d_half_duplicate = range_d % (rotary_dim // 2)

x0_range = range_m[:, None]*stride_m + range_d_half_duplicate[None, :]*stride_headdim # BLOCK_M x 1st half of headdim (fast to load)
x1_range = range_m[:, None]*stride_m + range_d_half_duplicate[None, :]*stride_headdim + ro_dim_half # BLOCK_M x 2nd half of headdim (fast to load)

x0_mask = (range_m < seqlen_x)[:, None] & (range_d_half_duplicate < rotary_dim)[None, :] # Mask for the first half
x1_mask = (range_m < seqlen_x)[:, None] & (range_d_half_duplicate + ro_dim_half < rotary_dim)[None, :] # Mask for the second half

range_m_cos_sin = range_m + seqlen_offset # offsets cos and sin based on current m position range and seqlen offset
COS = COS + (range_m_cos_sin[:, None] * ro_dim_half + range_d_half_duplicate[None, :])
SIN = SIN + (range_m_cos_sin[:, None] * ro_dim_half + range_d_half_duplicate[None, :])
cos = tl.load(
COS, mask=(range_m[:, None] < seqlen_x) & (range_d_half_duplicate[None, :] < ro_dim_half), other=1.0
).to(tl.float32)
sin = tl.load(
    SIN, mask=(range_m[:, None] < seqlen_x) & (range_d_half_duplicate[None, :] < ro_dim_half), other=0.0
).to(tl.float32)
if CONJUGATE:
sin = -sin

x0 = tl.load(x_ptr + x0_range, mask=x0_mask).to(tl.float32)
x1 = tl.load(x_ptr + x1_range, mask=x1_mask).to(tl.float32)

# Rotate corresponding elements in each half
o0 = x0 * cos - x1 * sin
o1 = x0 * sin + x1 * cos

out = tl.where(range_d[None, :] // ro_dim_half == 0, o0, o1)

# for all dim not in rotary_dim, leave untouched
out = tl.where(range_d[None, :] < rotary_dim, out, x)

# transpose the rotated vector
if TRANSPOSE:
out = tl.trans(out)

return out

else:
# Interleaved is slow due to x1 load
range_d_swap = range_d + ((range_d + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...

# X Range
x0_range = range_m[:, None]*stride_m + range_d[None, :]*stride_headdim      # 0, 1, 2, 3, 4, 5, ... (fast to load)
x1_range = range_m[:, None]*stride_m + range_d_swap[None, :]*stride_headdim # 1, 0, 3, 2, 5, 4, ... (slow to load)

# X Masks
x0_mask = (range_m < seqlen_x)[:, None] & (range_d < rotary_dim)[None, :] # Mask for the first half
x1_mask = (range_m < seqlen_x)[:, None] & (range_d_swap < rotary_dim)[None, :] # Mask for the second half

# Load COS/SIN
range_d_repeat = tl.arange(0, BLOCK_K) // 2 # 0, 0, 1, 1, 2, 2, ...

range_m_cos_sin = range_m + seqlen_offset
COS = COS + (range_m_cos_sin[:, None] * ro_dim_half + range_d_repeat[None, :])
SIN = SIN + (range_m_cos_sin[:, None] * ro_dim_half + range_d_repeat[None, :])
cos = tl.load(
COS,
mask=(range_m[:, None] < seqlen_x) & (range_d_repeat[None, :] < ro_dim_half),
other=1.0,
).to(tl.float32)
sin = tl.load(
SIN,
mask=(range_m[:, None] < seqlen_x) & (range_d_repeat[None, :] < ro_dim_half),
other=0.0,
).to(tl.float32)
if CONJUGATE:
sin = -sin

x0 = tl.load(x_ptr + x0_range, mask=x0_mask)
x1 = tl.load(x_ptr + x1_range, mask=x1_mask)

x0_cos = x0 * cos
x1_sin = x1 * sin

out = tl.where(range_d[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)

# for all dim not in rotary_dim, leave untouched
out = tl.where(range_d[None, :] < rotary_dim, out, x)

# transpose the rotated vector
if TRANSPOSE:
out = tl.trans(out)

return out
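For reference, the rotation above follows the standard rotary-embedding math. Below is a minimal PyTorch sketch of what rotary_kernel_splitk computes for a single (seqlen, head_dim) slice; it is an illustration only, not part of this commit, and apply_rotary_ref is a hypothetical helper name.

import torch

def apply_rotary_ref(x, cos, sin, interleaved=False, conjugate=False):
    # x: (seqlen, head_dim); cos/sin: (seqlen, rotary_dim // 2), assumed already
    # sliced to x's positions (i.e. seqlen_offset applied by the caller)
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    if conjugate:
        sin = -sin
    x_rot, x_pass = x[..., :ro_dim], x[..., ro_dim:]
    if not interleaved:
        # non-interleaved: dim d pairs with dim d + ro_dim // 2
        x0, x1 = x_rot.chunk(2, dim=-1)
        out = torch.cat([x0 * cos - x1 * sin, x0 * sin + x1 * cos], dim=-1)
    else:
        # interleaved: each even dim pairs with the odd dim that follows it
        x0, x1 = x_rot[..., 0::2], x_rot[..., 1::2]
        out = torch.stack([x0 * cos - x1 * sin, x0 * sin + x1 * cos], dim=-1).flatten(-2)
    # dims beyond rotary_dim pass through untouched
    return torch.cat([out, x_pass], dim=-1)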

@triton.jit
def _fwd_kernel_splitK(
Q,
@@ -16,6 +163,15 @@ def _fwd_kernel_splitK(
Cache_seqlens,
Cache_batch_idx,
Alibi_slopes,
# Rotary
Rotary_cos,
Rotary_sin,
Rotary_dim,
Rotary_interleaved: tl.constexpr,
Rotary_conjugate: tl.constexpr,
IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
IS_VARLEN: tl.constexpr,
# Strides
stride_qz,
stride_qm,
stride_qg,
@@ -64,12 +220,13 @@ def _fwd_kernel_splitK(
ACTUAL_BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
BOUNDS_CHECKS_N: tl.constexpr,
USE_CACHE_SEQLENS: tl.constexpr,
USE_CACHE_BATCH_IDX: tl.constexpr,
NEW_KV: tl.constexpr,
IS_GQA: tl.constexpr,
IS_CAUSAL: tl.constexpr,
USE_ALIBI: tl.constexpr,
USE_ROTARY: tl.constexpr,
):
# Padding
PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL)
@@ -97,7 +254,7 @@ def _fwd_kernel_splitK(
alibi_slope = None

lo = splitk_idx * BLOCK_N_PER_SPLIT
if USE_CACHE_SEQLENS:
cache_seqlen_last_idx = tl.load(Cache_seqlens + off_z)
if NEW_KV:
kv_len = cache_seqlen_last_idx + N_CTX_NEW
@@ -124,22 +281,55 @@
knew_base = K_new + k_head_idx * stride_kn_h + off_z * stride_kn_z + off_g_q * stride_kn_g

# Determine the starting position for new data in the cache
if USE_CACHE_SEQLENS:
start_idx = tl.load(Cache_seqlens + off_z)
else:
start_idx = N_CTX_K - N_CTX_NEW

# Copy new Keys
for i in range(0, N_CTX_NEW, BLOCK_N):
# Load from K_new and apply rotary to k
if USE_ROTARY:
k_new_block = rotary_kernel_splitk(
X=K_new,
seqlen_x=N_CTX_NEW,
head_dim=BLOCK_DMODEL,
rotary_dim=Rotary_dim,

COS=Rotary_cos,
SIN=Rotary_sin,
SEQLEN_OFFSET=Cache_seqlens,
SEQLEN_OFFSET_IS_TENSOR=IS_SEQLEN_OFFSETS_TENSOR,

batch_pid=off_z,
start_m=i, # current block of tokens in new_k
group_pid=off_g_q,
head_pid=k_head_idx, # K_new is indexed by the KV head, not the query head

stride_batch=stride_kn_z if not IS_VARLEN else 0, # strides of K_new (the tensor being rotated); batch stride is 0 for varlen
stride_m=stride_kn_n,
stride_group=stride_kn_g,
stride_head=stride_kn_h,
stride_headdim=stride_kn_d,

INTERLEAVED=Rotary_interleaved,
CONJUGATE=Rotary_conjugate,
TRANSPOSE=True,

BLOCK_M=BLOCK_N,
BLOCK_K=BLOCK_DMODEL
)
else:
# Load from K_new
k_new_block = tl.load(
knew_base +
tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kn_d +
(tl.arange(0, BLOCK_N) + i)[None, :] * stride_kn_n,
mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) &
(tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL),
other=0
)

# Store to K
tl.store(
@@ -213,9 +403,41 @@ def _fwd_kernel_splitK(
# 2^x instead of exp in the loop because CSE and LICM
# don't work as expected with `exp` in the loop
qk_scale = sm_scale * 1.44269504
# load q: apply rotary on load if enabled; q will stay in SRAM throughout
if USE_ROTARY:
q = rotary_kernel_splitk(
X=Q,
seqlen_x=N_CTX_Q,
head_dim=BLOCK_DMODEL,
rotary_dim=Rotary_dim,

COS=Rotary_cos,
SIN=Rotary_sin,
SEQLEN_OFFSET=Cache_seqlens,
SEQLEN_OFFSET_IS_TENSOR=IS_SEQLEN_OFFSETS_TENSOR,

batch_pid=off_z,
start_m=start_m*BLOCK_M,
group_pid=off_g_q,
head_pid=off_h_q,

stride_batch=stride_qz if not IS_VARLEN else 0, # strides of Q (the tensor being rotated); batch stride is 0 for varlen
stride_m=stride_qm,
stride_group=stride_qg,
stride_head=stride_qh,
stride_headdim=stride_qd,

INTERLEAVED=Rotary_interleaved,
CONJUGATE=Rotary_conjugate,
TRANSPOSE=False,

BLOCK_M=BLOCK_M,
BLOCK_K=BLOCK_DMODEL
)
else:
# load q: it will stay in SRAM throughout
q = tl.load( # noqa: F821
tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0, ))
q = (q * qk_scale).to(q.dtype)
if PADDED_HEAD:
q = tl.where(d_mask[None, :], q, 0.0)
@@ -339,8 +561,8 @@ def load_k_v_group(
V_block_ptr = tl.advance(V_block_ptr, (0, ACTUAL_BLOCK_DMODEL * group_id))

# -- load k, v --
k = tl.load(K_block_ptr, boundary_check=(1, ) if BOUNDS_CHECKS_N else ()).to(tl.float32)
v = tl.load(V_block_ptr, boundary_check=(0, ) if BOUNDS_CHECKS_N else ()).to(tl.float32)

return k, v

Expand Down Expand Up @@ -540,7 +762,11 @@ def get_split_k(B: int, G: int, H: int, Mk: int) -> int:
split_k = max(split_k, 1)
return split_k

def attention_decode_forward_triton_impl(q, k, v,
sm_scale, causal, alibi_slopes,
layout, cache_seqlens, cache_batch_idx,
new_kv, k_new, v_new,
rotary_cos, rotary_sin, rotary_dim, rotary_interleaved, rotary_conjugate):
# kernel config
BLOCK_M = 16
BLOCK_N = 64
@@ -620,6 +846,13 @@ def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes
Cache_seqlens=cache_seqlens,
Cache_batch_idx=cache_batch_idx,
Alibi_slopes=alibi_slopes,
Rotary_cos=rotary_cos,
Rotary_sin=rotary_sin,
Rotary_dim=rotary_dim,
Rotary_interleaved=rotary_interleaved,
Rotary_conjugate=rotary_conjugate,
IS_SEQLEN_OFFSETS_TENSOR=isinstance(cache_seqlens, torch.Tensor),
IS_VARLEN=False,
**_strides(q, "qz", "qm", "qg", "qh", "qd"),
**_strides(k, "kz", "kn", "kg", "kh", "kd"),
**_strides(v, "vz", "vn", "vg", "vh", "vd"),
@@ -641,12 +874,13 @@ def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes
BLOCK_DMODEL=dim_padded,
ACTUAL_BLOCK_DMODEL=dim_k,
BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_cache_seqlens,
USE_CACHE_SEQLENS=use_cache_seqlens,
USE_CACHE_BATCH_IDX=cache_batch_idx is not None,
NEW_KV=new_kv,
IS_GQA=is_gqa,
IS_CAUSAL=causal,
USE_ALIBI=alibi_slopes is not None,
USE_ROTARY=rotary_cos is not None and rotary_sin is not None,
num_warps=num_warps,
num_stages=1,
)
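For context, a hedged sketch of how the fused path might be invoked. Shapes follow the (batch, seqlen, group, head, head_dim) layout noted in rotary_kernel_splitk; the layout string, dtypes, and return value are assumptions, since they are not shown in this diff.

import torch

B, G, H, D = 2, 1, 8, 128                       # batch, query groups, heads, head_dim
q = torch.randn(B, 1, G, H, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, 1024, G, H, D, device="cuda", dtype=torch.float16)
v = torch.randn_like(k)
k_new = torch.randn(B, 1, G, H, D, device="cuda", dtype=torch.float16)
v_new = torch.randn_like(k_new)
cache_seqlens = torch.full((B,), 512, dtype=torch.int32, device="cuda")

# cos/sin tables of shape (max_position, rotary_dim // 2)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, D, 2, device="cuda").float() / D))
angles = torch.arange(1024, device="cuda").float()[:, None] * inv_freq[None, :]
rotary_cos, rotary_sin = angles.cos(), angles.sin()

out = attention_decode_forward_triton_impl(
    q, k, v,
    sm_scale=D ** -0.5, causal=True, alibi_slopes=None,
    layout="bsghd",                             # assumed layout string; accepted values are not shown in this diff
    cache_seqlens=cache_seqlens, cache_batch_idx=None,
    new_kv=True, k_new=k_new, v_new=v_new,
    rotary_cos=rotary_cos, rotary_sin=rotary_sin, rotary_dim=D,
    rotary_interleaved=False, rotary_conjugate=False,
)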