Added Support for Rotary Positional Embeddings (both non-fused and fused kernel) #99
base: main_perf
Changes from commits dc1271a, e02ceee, 6ead05a, 7ed3dc1, bdf5c42, 9394dc8, 1bf5ff1
@@ -3,6 +3,153 @@
import triton.language as tl
from .utils import _strides, get_padded_headsize


@triton.jit
def rotary_kernel_splitk(
    # Dimensions of X
    X,  # tensor being rotated. Has shape (batch (z), seqlen (s), group (g), head (h), head_dim (d))
    seqlen_x,  # seqlen of the x dim. shape is (batch (z), )
    head_dim,
    rotary_dim,  # size of the embedding space we end up rotating

    # COS/SIN and offsetting into it
    COS,  # tensor of shape (seqlen (m), ro_dim // 2)
    SIN,  # tensor of shape (seqlen (m), ro_dim // 2)
    SEQLEN_OFFSET,  # we use this as an offset into COS and SIN to apply the correct rotation
    SEQLEN_OFFSET_IS_TENSOR: tl.constexpr,  # if seqlen_offset is a tensor it has shape (num_batch, )

    # PID Offsets
    batch_pid: tl.constexpr,  # pid for batch
    start_m: tl.constexpr,  # token idx the current M_BLOCK starts at
    group_pid: tl.constexpr,  # pid for group
    head_pid: tl.constexpr,  # pid to access head

    # Strides
    stride_batch: tl.constexpr,
    stride_m: tl.constexpr,
    stride_group: tl.constexpr,
    stride_head: tl.constexpr,
    stride_headdim: tl.constexpr,

    # Misc
    INTERLEAVED: tl.constexpr,
    CONJUGATE: tl.constexpr,
    TRANSPOSE: tl.constexpr,
Review comment: I don't think we should do the transpose inside the rotary function. It probably makes sense for the caller to do that if it makes sense for them.

    # Meta-parameters
    BLOCK_M: tl.constexpr,  # block size to access chunks of tokens (# of tokens simultaneously)
    BLOCK_K: tl.constexpr,  # block size to access chunks of headdim (# of dimensions processed)
):
    """
    Note:
        - for K in splitk let BLOCK_M = BLOCK_N, and start_m = start_n
    """
    range_m = start_m + tl.arange(0, BLOCK_M)
    range_d = tl.arange(0, BLOCK_K)

    # pointer to this block's batch/group/head slice of x
    x_ptr = X + (batch_pid * stride_batch) + (group_pid * stride_group) + (head_pid * stride_head)

    ro_dim_half = rotary_dim // 2  # length of cos/sin

    if SEQLEN_OFFSET_IS_TENSOR:
Review comment: why do we have 2 versions.
        seqlen_offset = tl.load(SEQLEN_OFFSET + batch_pid)  # per-batch offset loaded from a tensor
    else:
        seqlen_offset = SEQLEN_OFFSET  # a single int shared by the whole batch
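    # NOTE: SEQLEN_OFFSET is a plain int when every sequence in the batch shares
    # the same cache length, and a (batch,)-shaped tensor when each sequence has
    # its own offset (the decode path below passes Cache_seqlens here), which is
    # why both branches exist.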

    # load full x (puts values in cache)
    x_range = range_m[:, None] * stride_m + range_d[None, :]
    x_mask = (range_m < seqlen_x)[:, None] & (range_d < head_dim)[None, :]
    x = tl.load(x_ptr + x_range, mask=x_mask)

    if not INTERLEAVED:
        range_d_half_duplicate = range_d % (rotary_dim // 2)

        x0_range = range_m[:, None] * stride_m + range_d_half_duplicate[None, :] * stride_headdim                # BLOCK_M x 1st half of headdim (fast to load)
        x1_range = range_m[:, None] * stride_m + range_d_half_duplicate[None, :] * stride_headdim + ro_dim_half  # BLOCK_M x 2nd half of headdim (fast to load)

        x0_mask = (range_m < seqlen_x)[:, None] & (range_d_half_duplicate < rotary_dim)[None, :]                # mask for the first half
        x1_mask = (range_m < seqlen_x)[:, None] & (range_d_half_duplicate + ro_dim_half < rotary_dim)[None, :]  # mask for the second half

        range_m_cos_sin = range_m + seqlen_offset  # offset cos and sin by the current m position range and the seqlen offset
        COS = COS + (range_m_cos_sin[:, None] * ro_dim_half + range_d_half_duplicate[None, :])
        SIN = SIN + (range_m_cos_sin[:, None] * ro_dim_half + range_d_half_duplicate[None, :])
        cos = tl.load(
            COS, mask=(range_m[:, None] < seqlen_x) & (range_d_half_duplicate[None, :] < ro_dim_half), other=1.0
        ).to(tl.float32)
        sin = tl.load(
            SIN, mask=(range_m[:, None] < seqlen_x) & (range_d_half_duplicate[None, :] < ro_dim_half), other=0.0
        ).to(tl.float32)
        if CONJUGATE:
            sin = -sin

        x0 = tl.load(x_ptr + x0_range, mask=x0_mask).to(tl.float32)
        x1 = tl.load(x_ptr + x1_range, mask=x1_mask).to(tl.float32)

        # rotate corresponding elements in each half
        o0 = x0 * cos - x1 * sin
        o1 = x0 * sin + x1 * cos

        out = tl.where(range_d[None, :] // ro_dim_half == 0, o0, o1)

        # for all dims not in rotary_dim, leave untouched
        out = tl.where(range_d[None, :] < rotary_dim, out, x)

        # transpose the rotated vector
        if TRANSPOSE:
            out = tl.trans(out)

        return out

    else:
        # interleaved is slow due to the x1 load
        range_d_swap = range_d + ((range_d + 1) % 2) * 2 - 1  # 1, 0, 3, 2, 5, 4, ...
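        # NOTE: in the interleaved (GPT-J-style) layout the rotation pairs are
        # adjacent lanes (0, 1), (2, 3), ...; range_d_swap gives each lane the
        # index of its partner, so the x1 gather below is strided and slow.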

        # X ranges
        x0_range = range_m[:, None] * stride_m + range_d[None, :]       # 0, 1, 2, 3, 4, 5, ... (fast to load)
        x1_range = range_m[:, None] * stride_m + range_d_swap[None, :]  # 1, 0, 3, 2, 5, 4, ... (slow to load)

        # X masks
        x0_mask = (range_m < seqlen_x)[:, None] & (range_d < rotary_dim)[None, :]       # mask for the even lanes
        x1_mask = (range_m < seqlen_x)[:, None] & (range_d_swap < rotary_dim)[None, :]  # mask for the odd lanes

        # load COS/SIN
        range_d_repeat = tl.arange(0, BLOCK_K) // 2  # 0, 0, 1, 1, 2, 2, ...

        range_m_cos_sin = range_m + seqlen_offset
        COS = COS + (range_m_cos_sin[:, None] * ro_dim_half + range_d_repeat[None, :])
        SIN = SIN + (range_m_cos_sin[:, None] * ro_dim_half + range_d_repeat[None, :])
        cos = tl.load(
            COS,
            mask=(range_m[:, None] < seqlen_x) & (range_d_repeat[None, :] < ro_dim_half),
            other=1.0,
        ).to(tl.float32)
        sin = tl.load(
            SIN,
            mask=(range_m[:, None] < seqlen_x) & (range_d_repeat[None, :] < ro_dim_half),
            other=0.0,
        ).to(tl.float32)
        if CONJUGATE:
            sin = -sin

        x0 = tl.load(x_ptr + x0_range, mask=x0_mask)
        x1 = tl.load(x_ptr + x1_range, mask=x1_mask)

        x0_cos = x0 * cos
        x1_sin = x1 * sin

        out = tl.where(range_d[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)

        # for all dims not in rotary_dim, leave untouched
        out = tl.where(range_d[None, :] < rotary_dim, out, x)

        # transpose the rotated vector
        if TRANSPOSE:
            out = tl.trans(out)

        return out

@triton.jit
def _fwd_kernel_splitK(
    Q,
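
For reference, here is a minimal host-side sketch of the rotation the kernel computes, handy for unit-testing rotary_kernel_splitk. It assumes PyTorch; rotary_ref and every name in it are hypothetical helpers, not part of this PR.

import torch

def rotary_ref(x, cos, sin, interleaved=False, conjugate=False):
    # Rotates x[..., :ro_dim] where ro_dim = 2 * cos.shape[-1]; the tail passes through.
    ro_dim = 2 * cos.shape[-1]
    x_ro, x_pass = x[..., :ro_dim], x[..., ro_dim:]
    if conjugate:
        sin = -sin
    if not interleaved:
        # NeoX-style: element d pairs with element d + ro_dim // 2
        x0, x1 = x_ro[..., :ro_dim // 2], x_ro[..., ro_dim // 2:]
        out = torch.cat([x0 * cos - x1 * sin, x0 * sin + x1 * cos], dim=-1)
    else:
        # GPT-J-style: adjacent elements (0, 1), (2, 3), ... form the pairs
        x0, x1 = x_ro[..., 0::2], x_ro[..., 1::2]
        out = torch.stack([x0 * cos - x1 * sin, x0 * sin + x1 * cos], dim=-1).flatten(-2)
    return torch.cat([out, x_pass], dim=-1)

For a (seqlen, head_dim) block starting at token offset off, the matching call is rotary_ref(x, cos[off:off + seqlen], sin[off:off + seqlen]), mirroring how the kernel indexes COS/SIN with range_m + seqlen_offset.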
@@ -16,6 +163,15 @@ def _fwd_kernel_splitK(
    Cache_seqlens,
    Cache_batch_idx,
    Alibi_slopes,
    # Rotary
    Rotary_cos,
    Rotary_sin,
    Rotary_dim,
    Rotary_interleaved: tl.constexpr,
    Rotary_conjugate: tl.constexpr,
    IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    # Strides
    stride_qz,
    stride_qm,
    stride_qg,
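
Rotary_cos and Rotary_sin are the usual RoPE tables of shape (seqlen, rotary_dim // 2), as the comments in rotary_kernel_splitk state. For completeness, a sketch of how a caller might build them (assuming the common base of 10000; this helper is hypothetical and not part of the diff):

import torch

def build_rotary_tables(max_seqlen, rotary_dim, base=10000.0, device=None):
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, device=device).float() / rotary_dim))
    t = torch.arange(max_seqlen, device=device, dtype=torch.float32)
    angles = torch.outer(t, inv_freq)  # (max_seqlen, rotary_dim // 2)
    return angles.cos(), angles.sin()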
@@ -64,12 +220,13 @@ def _fwd_kernel_splitK(
    ACTUAL_BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BOUNDS_CHECKS_N: tl.constexpr,
-   USE_CACHE_SEQLENs: tl.constexpr,
    USE_CACHE_SEQLENS: tl.constexpr,
    USE_CACHE_BATCH_IDX: tl.constexpr,
    NEW_KV: tl.constexpr,
    IS_GQA: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    USE_ALIBI: tl.constexpr,
    USE_ROTARY: tl.constexpr,
):
    # Padding
    PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL)
@@ -97,7 +254,7 @@ def _fwd_kernel_splitK(
        alibi_slope = None

    lo = splitk_idx * BLOCK_N_PER_SPLIT
-   if USE_CACHE_SEQLENs:
    if USE_CACHE_SEQLENS:
        cache_seqlen_last_idx = tl.load(Cache_seqlens + off_z)
        if NEW_KV:
            kv_len = cache_seqlen_last_idx + N_CTX_NEW
@@ -124,22 +281,55 @@ def _fwd_kernel_splitK(
        knew_base = K_new + k_head_idx * stride_kn_h + off_z * stride_kn_z + off_g_q * stride_kn_g

        # Determine the starting position for new data in the cache
-       if USE_CACHE_SEQLENs:
        if USE_CACHE_SEQLENS:
            start_idx = tl.load(Cache_seqlens + off_z)
        else:
            start_idx = N_CTX_K - N_CTX_NEW

        # Copy new Keys
        for i in range(0, N_CTX_NEW, BLOCK_N):
-           # Load from K_new
-           k_new_block = tl.load(
-               knew_base +
-               tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kn_d +
-               (tl.arange(0, BLOCK_N) + i)[None, :] * stride_kn_n,
-               mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) &
-               (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL),
-               other=0
-           )

            # Load from K_new and apply rotary to k
            if USE_ROTARY:
                k_new_block = rotary_kernel_splitk(
                    X=K_new,
                    seqlen_x=N_CTX_NEW,
                    head_dim=BLOCK_DMODEL,
                    rotary_dim=Rotary_dim,

                    COS=Rotary_cos,
                    SIN=Rotary_sin,
                    SEQLEN_OFFSET=Cache_seqlens,
                    SEQLEN_OFFSET_IS_TENSOR=IS_SEQLEN_OFFSETS_TENSOR,

                    batch_pid=off_z,
                    start_m=i,  # current block of tokens in new_k
                    group_pid=off_g_q,
                    head_pid=k_head_idx,

                    stride_batch=stride_kn_z,  # batch stride if not varlen else 0
                    stride_m=stride_kn_n,
                    stride_group=stride_kn_g,
                    stride_head=stride_kn_h,
                    stride_headdim=stride_kn_d,

                    INTERLEAVED=Rotary_interleaved,
                    CONJUGATE=Rotary_conjugate,
                    TRANSPOSE=True,

                    BLOCK_M=BLOCK_N,
                    BLOCK_K=BLOCK_DMODEL,
                )
            else:
                # Load from K_new
                k_new_block = tl.load(
                    knew_base +
                    tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kn_d +
                    (tl.arange(0, BLOCK_N) + i)[None, :] * stride_kn_n,
                    mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) &
                         (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL),
                    other=0,
                )

            # Store to K
            tl.store(
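
Note that the new keys are rotated with SEQLEN_OFFSET=Cache_seqlens: the chunk of k_new starting at token i is written to cache slot start_idx + i, so its rotary angle must be taken at that absolute position. A quick position-arithmetic check with hypothetical values:

cache_seqlens = [5, 9]  # existing cache length per batch element
i = 0                   # first BLOCK_N chunk of k_new
# rotary position of token m of this chunk for batch b matches its cache slot
pos = lambda b, m: cache_seqlens[b] + i + m
assert pos(0, 0) == 5 and pos(1, 3) == 12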
@@ -213,9 +403,41 @@ def _fwd_kernel_splitK(
    # 2^x instead of exp in the loop because CSE and LICM
    # don't work as expected with `exp` in the loop
    qk_scale = sm_scale * 1.44269504
-   # load q: it will stay in SRAM throughout
-   q = tl.load(  # noqa: F821
-       tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0, ))
    # load q, applying rotary after the load when enabled
    if USE_ROTARY:
        q = rotary_kernel_splitk(
            X=Q,
            seqlen_x=N_CTX_Q,
            head_dim=BLOCK_DMODEL,
            rotary_dim=Rotary_dim,

            COS=Rotary_cos,
            SIN=Rotary_sin,
            SEQLEN_OFFSET=Cache_seqlens,
            SEQLEN_OFFSET_IS_TENSOR=IS_SEQLEN_OFFSETS_TENSOR,

            batch_pid=off_z,
            start_m=start_m * BLOCK_M,
            group_pid=off_g_q,
            head_pid=off_h_q,

            stride_batch=(stride_qz if not IS_VARLEN else 0),  # batch stride if not varlen else 0
            stride_m=stride_qm,
            stride_group=stride_qg,
            stride_head=stride_qh,
            stride_headdim=stride_qd,

            INTERLEAVED=Rotary_interleaved,
            CONJUGATE=Rotary_conjugate,
            TRANSPOSE=False,

            BLOCK_M=BLOCK_M,
            BLOCK_K=BLOCK_DMODEL,
        )
    else:
        # load q: it will stay in SRAM throughout
        q = tl.load(  # noqa: F821
            tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0, ))
    q = (q * qk_scale).to(q.dtype)
    if PADDED_HEAD:
        q = tl.where(d_mask[None, :], q, 0.0)
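
The query path uses the same offset: during decode the query block holds the newest tokens, so its rotary positions must start at the current cache length. A host-side sketch, reusing the hypothetical rotary_ref and build_rotary_tables helpers from above:

import torch

L, s_new, d = 128, 1, 64                      # cache already holds L tokens
cos, sin = build_rotary_tables(L + s_new, d)
q_new = torch.randn(s_new, d)
# the kernel should rotate this query block as positions [L, L + s_new)
out_decode = rotary_ref(q_new, cos[L:L + s_new], sin[L:L + s_new])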
@@ -339,8 +561,8 @@ def load_k_v_group(
    V_block_ptr = tl.advance(V_block_ptr, (0, ACTUAL_BLOCK_DMODEL * group_id))

    # -- load k, v --
-   k = tl.load(K_block_ptr, boundary_check=(1, ) if BOUNDS_CHECKS_N else ())
-   v = tl.load(V_block_ptr, boundary_check=(0, ) if BOUNDS_CHECKS_N else ())
    k = tl.load(K_block_ptr, boundary_check=(1, ) if BOUNDS_CHECKS_N else ()).to(tl.float32)
    v = tl.load(V_block_ptr, boundary_check=(0, ) if BOUNDS_CHECKS_N else ()).to(tl.float32)

    return k, v
@@ -540,7 +762,11 @@ def get_split_k(B: int, G: int, H: int, Mk: int) -> int:
    split_k = max(split_k, 1)
    return split_k

-def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes, layout, cache_seqlens, cache_batch_idx, new_kv, k_new, v_new):
def attention_decode_forward_triton_impl(q, k, v,
                                         sm_scale, causal, alibi_slopes,
                                         layout, cache_seqlens, cache_batch_idx,
                                         new_kv, k_new, v_new,
                                         rotary_cos, rotary_sin, rotary_dim, rotary_interleaved, rotary_conjugate):
    # kernel config
    BLOCK_M = 16
    BLOCK_N = 64
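
A hypothetical end-to-end call into the extended wrapper. Shapes follow the (batch, seqlen, group, head, head_dim) convention named in the kernel comments; the layout string, dtypes, and sizes here are assumptions for illustration, not taken from this diff:

import torch

B, G, H, D = 2, 1, 8, 64
N_CACHE, N_NEW = 256, 1

q = torch.randn(B, N_NEW, G, H, D, device="cuda", dtype=torch.float16)
k = torch.zeros(B, N_CACHE, G, H, D, device="cuda", dtype=torch.float16)  # KV cache
v = torch.zeros_like(k)
k_new = torch.randn(B, N_NEW, G, H, D, device="cuda", dtype=torch.float16)
v_new = torch.randn_like(k_new)
cache_seqlens = torch.full((B,), 128, dtype=torch.int32, device="cuda")
rotary_cos, rotary_sin = build_rotary_tables(N_CACHE, D, device="cuda")  # hypothetical helper from above

out = attention_decode_forward_triton_impl(
    q, k, v,
    sm_scale=D ** -0.5, causal=True, alibi_slopes=None,
    layout="bsghd",  # assumed layout string
    cache_seqlens=cache_seqlens, cache_batch_idx=None,
    new_kv=True, k_new=k_new, v_new=v_new,
    rotary_cos=rotary_cos, rotary_sin=rotary_sin,
    rotary_dim=D, rotary_interleaved=False, rotary_conjugate=False,
)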
@@ -620,6 +846,13 @@ def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes
        Cache_seqlens=cache_seqlens,
        Cache_batch_idx=cache_batch_idx,
        Alibi_slopes=alibi_slopes,
        Rotary_cos=rotary_cos,
        Rotary_sin=rotary_sin,
        Rotary_dim=rotary_dim,
        Rotary_interleaved=rotary_interleaved,
        Rotary_conjugate=rotary_conjugate,
        IS_SEQLEN_OFFSETS_TENSOR=isinstance(cache_seqlens, torch.Tensor),
        IS_VARLEN=False,
Review comment: why is
Reply: Because with rotary you have the option to use varlen, but decode only uses batched. We don't have a varlen parameter to pass in from the decode tests.

        **_strides(q, "qz", "qm", "qg", "qh", "qd"),
        **_strides(k, "kz", "kn", "kg", "kh", "kd"),
        **_strides(v, "vz", "vn", "vg", "vh", "vd"),
@@ -641,12 +874,13 @@ def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes
        BLOCK_DMODEL=dim_padded,
        ACTUAL_BLOCK_DMODEL=dim_k,
        BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_cache_seqlens,
-       USE_CACHE_SEQLENs=use_cache_seqlens,
        USE_CACHE_SEQLENS=use_cache_seqlens,
        USE_CACHE_BATCH_IDX=cache_batch_idx is not None,
        NEW_KV=new_kv,
        IS_GQA=is_gqa,
        IS_CAUSAL=causal,
        USE_ALIBI=False if alibi_slopes is None else True,
        USE_ROTARY=rotary_cos is not None and rotary_sin is not None,
        num_warps=num_warps,
        num_stages=1,
    )
Review comment: Why do we have two versions of SEQLEN_OFFSET? It seems one is an int and the other is a tensor that we load from.