Added Support for Rotary Positional Embeddings (both non-fused and fused kernel) #99

Open

wants to merge 7 commits into base: main_perf
272 changes: 253 additions & 19 deletions flash_attn/flash_attn_triton_amd/fwd_decode.py
@@ -3,6 +3,153 @@
import triton.language as tl
from .utils import _strides, get_padded_headsize

@triton.jit
def rotary_kernel_splitk(
# Dimensions of X
X, # tensor being rotated. Has shape (batch (z), seqlen (s), group (g), head (h), head_dim (d))
seqlen_x, # seqlen of the x dim. shape is (batch (z), )
head_dim,
rotary_dim, # size of embedding space we end up rotating

# COS/SIN and Offsetting Into It
COS, # tensor of shape (seqlen (m), ro_dim // 2)
SIN, # tensor of shape (seqlen (m), ro_dim // 2)
SEQLEN_OFFSET, # we use this as an offset into COS and SIN to apply the correct rotation
SEQLEN_OFFSET_IS_TENSOR: tl.constexpr, # if seqlen_offset is a tensor it has shape (num_batch, )
Collaborator:
why do we have two versions of SEQLEN_OFFSET? It seems it can be either an int or a tensor that we then load from.

# PID Offsets
batch_pid: tl.constexpr, # pid for batch
start_m: tl.constexpr, # the token idx the current M_BLOCK starts at.
group_pid: tl.constexpr, # pid for group
head_pid: tl.constexpr, # pid to access head

# Strides
stride_batch: tl.constexpr,
stride_m: tl.constexpr,
stride_group: tl.constexpr,
stride_head: tl.constexpr,
stride_headdim: tl.constexpr,

# Misc
INTERLEAVED: tl.constexpr,
CONJUGATE: tl.constexpr,
TRANSPOSE: tl.constexpr,
Collaborator:
I don't think we should do the transpose inside the rotary function. It probably makes sense for the caller to do that if it makes sense for them.

# Meta-parameters
BLOCK_M: tl.constexpr, # block size to access chunks of tokens (# of tokens simultaneously)
BLOCK_K: tl.constexpr, # block size to access chunks of headdim (# of dimensions processed)
):
"""
Note:
- for K in splitk let BLOCK_M = BLOCK_N, and start_m=start_n
"""
# pdb.set_trace()
range_m = start_m + tl.arange(0, BLOCK_M)
range_d = tl.arange(0, BLOCK_K)

x_ptr = X + (batch_pid * stride_batch) + (group_pid * stride_group) + (head_pid * stride_head) # pointer to x block
x_mask = (range_m < seqlen_x)[:, None] & (range_d < rotary_dim)[None, :]

ro_dim_half = rotary_dim // 2 # length of cos/sin

if SEQLEN_OFFSET_IS_TENSOR:
Collaborator:
why do we have 2 versions?
seqlen_offset = tl.load(SEQLEN_OFFSET + batch_pid) # a tensor
else:
seqlen_offset = SEQLEN_OFFSET # an int

# load full x (puts values in cache)
x_range = range_m[:, None]*stride_m + range_d[None, :]
x_mask = (range_m < seqlen_x)[:, None] & (range_d < head_dim)[None, :]
x = tl.load(x_ptr + x_range, mask=x_mask)


if not INTERLEAVED:
range_d_half_duplicate = range_d % (rotary_dim // 2)

x0_range = range_m[:, None]*stride_m + range_d_half_duplicate[None, :]*stride_headdim # BLOCK_M x 1st half of headdim (fast to load)
x1_range = range_m[:, None]*stride_m + range_d_half_duplicate[None, :]*stride_headdim + ro_dim_half # BLOCK_M x 2nd half of headdim (fast to load)

x0_mask = (range_m < seqlen_x)[:, None] & (range_d_half_duplicate < rotary_dim)[None, :] # Mask for the first half
x1_mask = (range_m < seqlen_x)[:, None] & (range_d_half_duplicate + ro_dim_half < rotary_dim)[None, :] # Mask for the second half

range_m_cos_sin = range_m + seqlen_offset # offsets cos and sin based on current m position range and seqlen offset
COS = COS + (range_m_cos_sin[:, None] * ro_dim_half + range_d_half_duplicate[None, :])
SIN = SIN + (range_m_cos_sin[:, None] * ro_dim_half + range_d_half_duplicate[None, :])
cos = tl.load(
COS, mask=(range_m[:, None] < seqlen_x) & (range_d_half_duplicate[None, :] < ro_dim_half), other=1.0
).to(tl.float32)
sin = tl.load(
SIN, mask=(range_m[:, None] < seqlen_x + seqlen_offset) & (range_d_half_duplicate[None, :] < ro_dim_half), other=0.0
).to(tl.float32)
if CONJUGATE:
sin = -sin

x0 = tl.load(x_ptr + x0_range, mask=x0_mask).to(tl.float32)
x1 = tl.load(x_ptr + x1_range, mask=x1_mask).to(tl.float32)

# Rotate corresponding elements in each half
o0 = x0 * cos - x1 * sin
o1 = x0 * sin + x1 * cos

out = tl.where(range_d[None, :] // ro_dim_half == 0, o0, o1)

# for all dim not in rotary_dim, leave untouched
out = tl.where(range_d[None, :] < rotary_dim, out, x)

# transpose the rotated vector
if TRANSPOSE:
out = tl.trans(out)

return out

else:
# Interleaved is slow due to x1 load
range_d_swap = range_d + ((range_d + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...

# X Range
x0_range = range_m[:, None]*stride_m + range_d[None, :] # 0, 1, 2, 3, 4, 5, ... (fast to load)
x1_range = range_m[:, None]*stride_m + range_d_swap[None, :] # 1, 0, 3, 2, 5, 4, ... (slow to load)

# X Masks
x0_mask = (range_m < seqlen_x)[:, None] & (range_d < rotary_dim)[None, :] # Mask for the first half
x1_mask = (range_m < seqlen_x)[:, None] & (range_d_swap < rotary_dim)[None, :] # Mask for the second half

# Load COS/SIN
range_d_repeat = tl.arange(0, BLOCK_K) // 2 # 0, 0, 1, 1, 2, 2, ...

range_m_cos_sin = range_m + seqlen_offset
COS = COS + (range_m_cos_sin[:, None] * ro_dim_half + range_d_repeat[None, :])
SIN = SIN + (range_m_cos_sin[:, None] * ro_dim_half + range_d_repeat[None, :])
cos = tl.load(
COS,
mask=(range_m[:, None] < seqlen_x) & (range_d_repeat[None, :] < ro_dim_half),
other=1.0,
).to(tl.float32)
sin = tl.load(
SIN,
mask=(range_m[:, None] < seqlen_x) & (range_d_repeat[None, :] < ro_dim_half),
other=0.0,
).to(tl.float32)
if CONJUGATE:
sin = -sin

x0 = tl.load(x_ptr + x0_range, mask=x0_mask)
x1 = tl.load(x_ptr + x1_range, mask=x1_mask)

x0_cos = x0 * cos
x1_sin = x1 * sin

out = tl.where(range_d[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)

# for all dim not in rotary_dim, leave untouched
out = tl.where(range_d[None, :] < rotary_dim, out, x)

# transpose the rotated vector
if TRANSPOSE:
out = tl.trans(out)

return out
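
For reference, below is a minimal PyTorch sketch (not part of this PR) of the rotation rotary_kernel_splitk computes, covering the non-interleaved (NeoX-style) and interleaved (GPT-J-style) layouts as well as the int vs. per-batch tensor seqlen offset that the review comment above asks about. The helper name apply_rotary_ref and the dtype handling are assumptions for illustration only.

import torch

def apply_rotary_ref(x, cos, sin, seqlen_offset=0, interleaved=False, conjugate=False):
    """Reference rotation. x: (batch, seqlen, group, head, head_dim);
    cos/sin: (max_position, rotary_dim // 2); seqlen_offset: int or (batch,) tensor.
    Assumes max_position covers seqlen_offset + seqlen."""
    batch, seqlen = x.shape[0], x.shape[1]
    rotary_dim = 2 * cos.shape[-1]
    if isinstance(seqlen_offset, torch.Tensor):
        pos = seqlen_offset.long()[:, None] + torch.arange(seqlen, device=x.device)  # (batch, seqlen)
    else:
        pos = (seqlen_offset + torch.arange(seqlen, device=x.device))[None, :].expand(batch, -1)
    c = cos[pos][:, :, None, None, :].float()  # broadcast over group and head dims
    s = sin[pos][:, :, None, None, :].float()
    if conjugate:
        s = -s
    out = x.clone()
    xr = x[..., :rotary_dim].float()
    if interleaved:
        x0, x1 = xr[..., 0::2], xr[..., 1::2]                          # pair (d, d+1)
    else:
        x0, x1 = xr[..., :rotary_dim // 2], xr[..., rotary_dim // 2:]  # pair (d, d + rotary_dim/2)
    o0 = (x0 * c - x1 * s).to(x.dtype)
    o1 = (x0 * s + x1 * c).to(x.dtype)
    if interleaved:
        out[..., 0:rotary_dim:2], out[..., 1:rotary_dim:2] = o0, o1
    else:
        out[..., :rotary_dim // 2], out[..., rotary_dim // 2:rotary_dim] = o0, o1
    return out  # dims >= rotary_dim pass through unchanged

Under these assumptions the non-interleaved branch mirrors the kernel's x0/x1 ranges (contiguous halves, cos index d % (rotary_dim // 2)), while the interleaved branch pairs neighbouring dimensions, which is why the kernel comment calls the interleaved x1 load slow.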

@triton.jit
def _fwd_kernel_splitK(
Q,
@@ -16,6 +163,15 @@ def _fwd_kernel_splitK(
Cache_seqlens,
Cache_batch_idx,
Alibi_slopes,
# Rotary
Rotary_cos,
Rotary_sin,
Rotary_dim,
Rotary_interleaved: tl.constexpr,
Rotary_conjugate: tl.constexpr,
IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
IS_VARLEN: tl.constexpr,
# Strides
stride_qz,
stride_qm,
stride_qg,
@@ -64,12 +220,13 @@ def _fwd_kernel_splitK(
ACTUAL_BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
BOUNDS_CHECKS_N: tl.constexpr,
- USE_CACHE_SEQLENs: tl.constexpr,
USE_CACHE_SEQLENS: tl.constexpr,
USE_CACHE_BATCH_IDX: tl.constexpr,
NEW_KV: tl.constexpr,
IS_GQA: tl.constexpr,
IS_CAUSAL: tl.constexpr,
USE_ALIBI: tl.constexpr,
USE_ROTARY: tl.constexpr,
):
# Padding
PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL)
Expand Down Expand Up @@ -97,7 +254,7 @@ def _fwd_kernel_splitK(
alibi_slope = None

lo = splitk_idx * BLOCK_N_PER_SPLIT
- if USE_CACHE_SEQLENs:
if USE_CACHE_SEQLENS:
cache_seqlen_last_idx = tl.load(Cache_seqlens + off_z)
if NEW_KV:
kv_len = cache_seqlen_last_idx + N_CTX_NEW
@@ -124,22 +281,55 @@
knew_base = K_new + k_head_idx * stride_kn_h + off_z * stride_kn_z + off_g_q * stride_kn_g

# Determine the starting position for new data in the cache
- if USE_CACHE_SEQLENs:
if USE_CACHE_SEQLENS:
start_idx = tl.load(Cache_seqlens + off_z)
else:
start_idx = N_CTX_K - N_CTX_NEW

# Copy new Keys
for i in range(0, N_CTX_NEW, BLOCK_N):
# Load from K_new
k_new_block = tl.load(
knew_base +
tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kn_d +
(tl.arange(0, BLOCK_N) + i)[None, :] * stride_kn_n,
mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) &
(tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL),
other=0
)

# Load from K_new and apply rotary to k
if USE_ROTARY:
k_new_block = rotary_kernel_splitk(
X=K_new,
seqlen_x=N_CTX_NEW,
head_dim=BLOCK_DMODEL,
rotary_dim=Rotary_dim,

COS=Rotary_cos,
SIN=Rotary_sin,
SEQLEN_OFFSET=Cache_seqlens,
SEQLEN_OFFSET_IS_TENSOR=IS_SEQLEN_OFFSETS_TENSOR,

batch_pid=off_z,
start_m=i, # current block of tokens in new_k
group_pid=off_g_q,
head_pid=k_head_idx,

stride_batch=stride_kn_z, # batch_strides if not varlen else 0
stride_m=stride_kn_n,
stride_group=stride_kn_g,
stride_head=stride_kn_h,
stride_headdim=stride_kn_d,

INTERLEAVED=Rotary_interleaved,
CONJUGATE=Rotary_conjugate,
TRANSPOSE=True,

BLOCK_M=BLOCK_N,
BLOCK_K=BLOCK_DMODEL
)
else:
# Load from K_new
k_new_block = tl.load(
knew_base +
tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kn_d +
(tl.arange(0, BLOCK_N) + i)[None, :] * stride_kn_n,
mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) &
(tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL),
other=0
)

# Store to K
tl.store(
@@ -213,9 +403,41 @@ def _fwd_kernel_splitK(
# 2^x instead of exp in the loop because CSE and LICM
# don't work as expected with `exp` in the loop
qk_scale = sm_scale * 1.44269504
- # load q: it will stay in SRAM throughout
- q = tl.load( # noqa: F821
- tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0, ))
# load q: decide whether to apply rotary after the load
if USE_ROTARY:
q = rotary_kernel_splitk(
X=Q,
seqlen_x=N_CTX_Q,
head_dim=BLOCK_DMODEL,
rotary_dim=Rotary_dim,

COS=Rotary_cos,
SIN=Rotary_sin,
SEQLEN_OFFSET=Cache_seqlens,
SEQLEN_OFFSET_IS_TENSOR=IS_SEQLEN_OFFSETS_TENSOR,

batch_pid=off_z,
start_m=start_m*BLOCK_M,
group_pid=off_g_q,
head_pid=off_h_q,

stride_batch=(stride_qz if not IS_VARLEN else 0), # batch_strides if not varlen else 0
stride_m=stride_qm,
stride_group=stride_qg,
stride_head=stride_qh,
stride_headdim=stride_qd,

INTERLEAVED=Rotary_interleaved,
CONJUGATE=Rotary_conjugate,
TRANSPOSE=False,

BLOCK_M=BLOCK_M,
BLOCK_K=BLOCK_DMODEL
)
else:
# load q: it will stay in SRAM throughout
q = tl.load( # noqa: F821
tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0, ))
q = (q * qk_scale).to(q.dtype)
if PADDED_HEAD:
q = tl.where(d_mask[None, :], q, 0.0)
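
For intuition, here is a small illustration (made-up numbers, not from the PR) of what passing Cache_seqlens as SEQLEN_OFFSET does for the query during decode: each new query token is rotated at its absolute position, i.e. the cached length plus its index within the new block.

import torch

cache_seqlens = torch.tensor([5, 9])   # tokens already in the KV cache, per batch entry
new_q_len = 2                          # new query tokens being decoded this step

# Absolute rotary positions of the new query tokens, per batch entry;
# this is what SEQLEN_OFFSET=Cache_seqlens achieves inside the kernel.
positions = cache_seqlens[:, None] + torch.arange(new_q_len)
print(positions)  # tensor([[ 5,  6],
                  #         [ 9, 10]])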
@@ -339,8 +561,8 @@ def load_k_v_group(
V_block_ptr = tl.advance(V_block_ptr, (0, ACTUAL_BLOCK_DMODEL * group_id))

# -- load k, v --
- k = tl.load(K_block_ptr, boundary_check=(1, ) if BOUNDS_CHECKS_N else ())
- v = tl.load(V_block_ptr, boundary_check=(0, ) if BOUNDS_CHECKS_N else ())
k = tl.load(K_block_ptr, boundary_check=(1, ) if BOUNDS_CHECKS_N else ()).to(tl.float32)
v = tl.load(V_block_ptr, boundary_check=(0, ) if BOUNDS_CHECKS_N else ()).to(tl.float32)

return k, v

@@ -540,7 +762,11 @@ def get_split_k(B: int, G: int, H: int, Mk: int) -> int:
split_k = max(split_k, 1)
return split_k

- def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes, layout, cache_seqlens, cache_batch_idx, new_kv, k_new, v_new):
def attention_decode_forward_triton_impl(q, k, v,
sm_scale, causal, alibi_slopes,
layout, cache_seqlens, cache_batch_idx,
new_kv, k_new, v_new,
rotary_cos, rotary_sin, rotary_dim, rotary_interleaved, rotary_conjugate):
# kernel config
BLOCK_M = 16
BLOCK_N = 64
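
For orientation, a hedged sketch of calling the extended entry point with the new rotary arguments. The module path follows the file in this diff, but the layout string, tensor shapes, and return handling are illustrative assumptions, not values taken from the PR or its tests.

import torch
from flash_attn.flash_attn_triton_amd.fwd_decode import attention_decode_forward_triton_impl

# Illustrative sizes; assumes a "bshd" (batch, seqlen, heads, head_dim) layout.
batch, heads, head_dim, cache_len, rotary_dim = 2, 8, 128, 256, 64
q = torch.randn(batch, 1, heads, head_dim, device="cuda", dtype=torch.float16)
k_cache = torch.randn(batch, cache_len, heads, head_dim, device="cuda", dtype=torch.float16)
v_cache = torch.randn_like(k_cache)
k_new = torch.randn(batch, 1, heads, head_dim, device="cuda", dtype=torch.float16)
v_new = torch.randn_like(k_new)
cache_seqlens = torch.tensor([37, 121], dtype=torch.int32, device="cuda")

# cos/sin tables shaped (max_position, rotary_dim // 2), matching the kernel's COS/SIN comment.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2, device="cuda", dtype=torch.float32) / rotary_dim))
positions = torch.arange(cache_len, device="cuda", dtype=torch.float32)
freqs = torch.outer(positions, inv_freq)
rotary_cos, rotary_sin = freqs.cos(), freqs.sin()

result = attention_decode_forward_triton_impl(
    q, k_cache, v_cache,
    sm_scale=head_dim ** -0.5, causal=True, alibi_slopes=None,
    layout="bshd", cache_seqlens=cache_seqlens, cache_batch_idx=None,
    new_kv=True, k_new=k_new, v_new=v_new,
    rotary_cos=rotary_cos, rotary_sin=rotary_sin, rotary_dim=rotary_dim,
    rotary_interleaved=False, rotary_conjugate=False,
)  # return value(s) are not shown in this hunk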
Expand Down Expand Up @@ -620,6 +846,13 @@ def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes
Cache_seqlens=cache_seqlens,
Cache_batch_idx=cache_batch_idx,
Alibi_slopes=alibi_slopes,
Rotary_cos=rotary_cos,
Rotary_sin=rotary_sin,
Rotary_dim=rotary_dim,
Rotary_interleaved=rotary_interleaved,
Rotary_conjugate=rotary_conjugate,
IS_SEQLEN_OFFSETS_TENSOR=isinstance(cache_seqlens, torch.Tensor),
IS_VARLEN=False,
Collaborator:
why is IS_VARLEN manually set to False?

Author (@alexkranias-amd, Nov 15, 2024):
Because with rotary you have the option to use varlen, but decode only uses batched inputs. We don't have a varlen parameter to pass in from the decode tests.

**_strides(q, "qz", "qm", "qg", "qh", "qd"),
**_strides(k, "kz", "kn", "kg", "kh", "kd"),
**_strides(v, "vz", "vn", "vg", "vh", "vd"),
@@ -641,12 +874,13 @@ def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes
BLOCK_DMODEL=dim_padded,
ACTUAL_BLOCK_DMODEL=dim_k,
BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_cache_seqlens,
- USE_CACHE_SEQLENs=use_cache_seqlens,
USE_CACHE_SEQLENS=use_cache_seqlens,
USE_CACHE_BATCH_IDX=cache_batch_idx is not None,
NEW_KV=new_kv,
IS_GQA=is_gqa,
IS_CAUSAL=causal,
USE_ALIBI=False if alibi_slopes is None else True,
USE_ROTARY=False if rotary_cos is None or rotary_sin is None else True,
num_warps=num_warps,
num_stages=1,
)
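
Since the PR title advertises both a non-fused and a fused path, a natural sanity check (a sketch only, under the same illustrative assumptions as the call sketch above, and reusing the hypothetical apply_rotary_ref helper from earlier) is to rotate q and k_new on the host, run the decode path with rotary disabled, and compare against the fused call.

# Hypothetical equivalence check: fused in-kernel rotary vs. host-side rotation.
q_rot = apply_rotary_ref(q.float().unsqueeze(2), rotary_cos, rotary_sin,
                         seqlen_offset=cache_seqlens, interleaved=False).squeeze(2).to(q.dtype)
k_new_rot = apply_rotary_ref(k_new.float().unsqueeze(2), rotary_cos, rotary_sin,
                             seqlen_offset=cache_seqlens, interleaved=False).squeeze(2).to(k_new.dtype)

out_fused = attention_decode_forward_triton_impl(
    q, k_cache.clone(), v_cache.clone(), sm_scale=head_dim ** -0.5, causal=True,
    alibi_slopes=None, layout="bshd", cache_seqlens=cache_seqlens, cache_batch_idx=None,
    new_kv=True, k_new=k_new, v_new=v_new, rotary_cos=rotary_cos, rotary_sin=rotary_sin,
    rotary_dim=rotary_dim, rotary_interleaved=False, rotary_conjugate=False)
out_ref = attention_decode_forward_triton_impl(
    q_rot, k_cache.clone(), v_cache.clone(), sm_scale=head_dim ** -0.5, causal=True,
    alibi_slopes=None, layout="bshd", cache_seqlens=cache_seqlens, cache_batch_idx=None,
    new_kv=True, k_new=k_new_rot, v_new=v_new, rotary_cos=None, rotary_sin=None,
    rotary_dim=0, rotary_interleaved=False, rotary_conjugate=False)
torch.testing.assert_close(out_fused, out_ref, rtol=2e-2, atol=2e-2)

The clones are there because new_kv=True stores k_new/v_new into the cache tensors in place, so each call should get its own copy of the cache.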