enable flash_attn_with_kvcache (#68)
* Compress kvcache work

This is a combination of 11 commits.

kvcache work

This is a combination of 4 commits.

kvcache is not supported

save

save decode

save

clean up merge

save cases

save

save

save

save

key mask on triton side

fix q size issue

test combos

save

* fix causal. use cache_seqlens

* clean and test what works

* some configs work on new_kv but fails on 1,8

* cache overwrite correct

* new_kv works more or less

* test local

* work on paged kv attention

* prefill paged attention

* fix has_batch_idx and skip local and rotary emb

* save

* save

* save

* save

* handle new_kv when paged kv cache

* all except has_batch_idx works

* major options are green

* test all

* add tests

* save

* clean up

* minor clean up

* simplest config

* save debug true

* save

* refactor slightly

* save work

* need key masking

* force hip

* use is_hip

* save

* fix cache_seq_len issue

* work on new_kv

* pass new_kv data

* save

* benchmark fwd only

* disable debug

* pandas pdf

* save

* set methods

* record number of heads

* use configs

* flexible dim, n-heads, head dim

* better benchmarking

* basic inplace update working

* works up to 64

* new_kv supported!

* test case for has_batch_idx

* has_batch_idx works!

* save

* save

* save

* save ref

* fix mqa and gqa by duplicating

* GQA and MQA working by kernel modifications

* fix new_kv with gqa

* cache index

* deal with nans on fwd_splitk

* save

* causal working on basic case

* causal works!

* alibi works!

* clean up

* clean prefill changes

* remove bwd stuff

* limit decode test to test_op_fwd

* add ref

* use bfloat
micmelesse authored Aug 6, 2024
1 parent 508a92a commit 01a1329
Showing 7 changed files with 1,535 additions and 272 deletions.
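This commit wires the Triton decode kernel into the kv-cache entry point, so the public `flash_attn_with_kvcache` API becomes usable on AMD. Below is a minimal sketch of how that path is exercised, assuming the upstream flash-attn keyword names (`k`/`v` for appended tokens, `cache_seqlens` for per-batch cache lengths) carry over to this fork.

```python
# Minimal sketch, assuming upstream flash-attn keyword names carry over to this fork.
# q attends over k_cache/v_cache; k/v are the new tokens appended at cache_seqlens (new_kv path).
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, head_dim = 2, 8, 64
cache_len, seqlen_new = 128, 1
dev, dtype = "cuda", torch.bfloat16

q = torch.randn(batch, seqlen_new, nheads, head_dim, device=dev, dtype=dtype)
k_cache = torch.randn(batch, cache_len, nheads, head_dim, device=dev, dtype=dtype)
v_cache = torch.randn_like(k_cache)
k_new = torch.randn(batch, seqlen_new, nheads, head_dim, device=dev, dtype=dtype)
v_new = torch.randn_like(k_new)
cache_seqlens = torch.full((batch,), 100, dtype=torch.int32, device=dev)  # tokens already in the cache

out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k_new, v=v_new,
                              cache_seqlens=cache_seqlens, causal=True)
print(out.shape)  # (batch, seqlen_new, nheads, head_dim)
```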
9 changes: 7 additions & 2 deletions .github/workflows/amd_tests.yml
@@ -57,7 +57,12 @@ jobs:
- name: Build
run: |
python setup.py install
- name: Test
- name: AMD Kernel Tests
run: |
pytest flash_attn/flash_attn_triton_kernel_decode_amd.py::test_op_fwd
pytest flash_attn/flash_attn_triton_kernel_prefill_amd.py
- name: Flash Attention Tests
run: |
pytest tests/test_flash_attn.py::test_flash_attn_kvcache
pytest tests/test_flash_attn.py::test_flash_attn_output
pytest tests/test_flash_attn.py::test_flash_attn_varlen_output
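The same test selection can be run locally before pushing; a sketch assuming the repository root as the working directory:

```python
# Sketch: run the CI test selection locally (repository root as cwd).
import sys
import pytest

ret = pytest.main([
    "flash_attn/flash_attn_triton_kernel_decode_amd.py::test_op_fwd",
    "flash_attn/flash_attn_triton_kernel_prefill_amd.py",
    "tests/test_flash_attn.py::test_flash_attn_kvcache",
])
sys.exit(int(ret))
```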
3 changes: 2 additions & 1 deletion .gitignore
@@ -26,9 +26,10 @@ var/
# Dev
venv

# Other
# AMD
.eggs
.vscode
core
scripts
log*
*csv
321 changes: 207 additions & 114 deletions benchmarks/benchmark_flash_attention.py

Large diffs are not rendered by default.
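The commit message mentions forward-only benchmarking over a configurable dim, n-heads and head dim; since this diff is not rendered, here is a hedged sketch of that kind of sweep using Triton's `do_bench`. The function name `bench_decode` and the shape grid are illustrative assumptions, not the contents of the benchmark file.

```python
# Hedged sketch of a forward-only kv-cache benchmark sweep; bench_decode and the
# shape grid are illustrative assumptions, not the benchmark file's contents.
import torch
import triton
from flash_attn import flash_attn_with_kvcache

def bench_decode(batch=4, cache_len=2048, nheads=16, head_dim=64, dtype=torch.bfloat16):
    q = torch.randn(batch, 1, nheads, head_dim, device="cuda", dtype=dtype)
    k_cache = torch.randn(batch, cache_len, nheads, head_dim, device="cuda", dtype=dtype)
    v_cache = torch.randn_like(k_cache)
    cache_seqlens = torch.full((batch,), cache_len, dtype=torch.int32, device="cuda")
    fn = lambda: flash_attn_with_kvcache(q, k_cache, v_cache,
                                         cache_seqlens=cache_seqlens, causal=True)
    return triton.testing.do_bench(fn)  # runtime in ms

for nheads in (8, 16, 32):
    for head_dim in (64, 128):
        print(f"nheads={nheads} head_dim={head_dim}: "
              f"{bench_decode(nheads=nheads, head_dim=head_dim):.3f} ms")
```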

196 changes: 64 additions & 132 deletions flash_attn/flash_attn_triton_interface_amd.py
@@ -1,8 +1,9 @@
from .flash_attn_triton_kernel_amd import MetaData, attention, get_shape_from_layout, _attn_bwd_preprocess, _attn_bwd
import torch
import triton
from .flash_attn_triton_kernel_prefill_amd import MetaData, attention_prefill, get_shape_from_layout, _attn_bwd_preprocess, _attn_bwd
from .flash_attn_triton_kernel_decode_amd import attention_decode

DEBUG=False
DEBUG = False

def fwd(q,
k,
@@ -31,7 +32,7 @@ def fwd(q,
print("gen_:", gen_)

if dropout_p != 0.0:
raise ValueError("dropout is not supported on HIP")
raise ValueError("dropout is not supported on AMD yet")

if o is None:
o = torch.empty_like(q)
@@ -60,7 +61,7 @@ def fwd(q,
input_metadata.check_args(q, k, v, o)

# Perform the forward attention computation
tri_out, encoded_softmax = attention(q, k, v, o, input_metadata)
tri_out, encoded_softmax = attention_prefill(q, k, v, o, input_metadata)

softmax_lse = encoded_softmax
softmax_p = encoded_softmax
@@ -93,9 +94,23 @@ def varlen_fwd(
print("q:", q.shape)
print("k:", k.shape)
print("v:", v.shape)
print("cu_seqlens_q:", cu_seqlens_q)
print("cu_seqlens_k:", cu_seqlens_k)
print("block_table_:", block_table_)
print("alibi_slopes:", alibi_slopes)
print("max_seqlen_q:", max_seqlen_q)
print("max_seqlen_k:", max_seqlen_k)
print("dropout_p:", dropout_p)
print("softmax_scale:", softmax_scale)
print("zero_tensors:", zero_tensors)
print("causal:", causal)
print("window_size_left:", window_size_left)
print("window_size_right:", window_size_right)
print("return_softmax:", return_softmax)
print("gen_:", gen_)

if dropout_p != 0.0:
raise ValueError("dropout is not supported on HIP")
raise ValueError("dropout is not supported on AMD yet")

if o is None:
o = torch.empty_like(q)
@@ -123,14 +138,14 @@ def varlen_fwd(
input_metadata.check_args(q, k, v, o)

# Perform the forward attention computation
tri_out, encoded_softmax = attention(q, k, v, o, input_metadata)
tri_out, encoded_softmax = attention_prefill(q, k, v, o, input_metadata)

softmax_lse = encoded_softmax
softmax_p = encoded_softmax

return tri_out, q , k , v, o, softmax_lse, softmax_p, torch.get_rng_state()

def fwd_kvcache(
def fwd_kvcache(
q,
k_cache,
v_cache,
@@ -149,150 +164,67 @@ def fwd_kvcache(
window_size_right,
rotary_interleaved,
num_splits):
pass


def bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, alibi_slopes, dropout_p, softmax_scale, causal, window_size_left,
window_size_right, deterministic, gen_, rng_state):

if DEBUG:
print("flash_attn_triton_amd.py::bwd")
print("q:", q.shape)
print("k:", k.shape)
print("v:", v.shape)
print("softmax_lse:", softmax_lse)
print("dq:", dq.shape)
print("dk:", dk.shape)
print("dv:", dv.shape)
print()
print("flash_attn_triton_amd.py::fwd_kvcache")
print("q:", q, q.shape)
print("k_cache:", k_cache, k_cache.shape)
print("v_cache:", v_cache, v_cache.shape)
print("k:", k, k.shape if k is not None else None)
print("v:", v, v.shape if v is not None else None)
print("cache_seqlens:", cache_seqlens, cache_seqlens.size())
print("rotary_cos:", rotary_cos)
print("rotary_sin:", rotary_sin)
print("cache_batch_idx:", cache_batch_idx)
print("block_table:", block_table, block_table.shape if block_table is not None else None)
print("alibi_slopes:", alibi_slopes)
print("dropout_p:", dropout_p)
print("out:", out)
print("softmax_scale:", softmax_scale)
print("causal:", causal)
print("window_size_left:", window_size_left)
print("window_size_right:", window_size_right)
print("deterministic:", deterministic)
print("gen_:", gen_)
print("rng_state:", rng_state)

print("rotary_interleaved:", rotary_interleaved)
print("num_splits:", num_splits)

if out is None:
out = torch.empty_like(q)

# Ensure the tensors have requires_grad=True
q.requires_grad_()
k.requires_grad_()
v.requires_grad_()
out.requires_grad_()

# Create metadata object
metadata = MetaData(sm_scale=softmax_scale)
metadata.max_seqlens_q = q.shape[1]
metadata.max_seqlens_k = k.shape[1]
metadata.layout = "bshd"
# fill metadata
input_metadata = MetaData(sm_scale=softmax_scale)
input_metadata.layout = "bshd"
input_metadata.max_seqlens_q = q.shape[1]
input_metadata.max_seqlens_k = k_cache.shape[1]
input_metadata.cache_seqlens = cache_seqlens
input_metadata.cache_batch_idx = cache_batch_idx

if metadata == 'bshd':
q = q.transpose(1, 2).clone()
k = k.transpose(1, 2).clone()
v = v.transpose(1, 2).clone()
if k is not None and v is not None:
input_metadata.new_kv = True
input_metadata.seqlen_new = k.shape[1]
input_metadata.k_new = k
input_metadata.v_new = v

batch = q.shape[0]
nheads_q = q.shape[1]
BLOCK_DMODEL = q.shape[3]

# Setup metadata
if causal:
metadata.need_causal()
input_metadata.need_causal()

# if bias is not None:
# metadata.need_bias(bias, q.shape[0], q.shape[1], q.shape[2], k.shape[2])

return_softmax = True
if alibi_slopes is not None:
metadata.need_alibi(alibi_slopes, batch, nheads_q)

if dropout_p > 0.0:
metadata.need_dropout(dropout_p, return_softmax)
batch, _ , nheads_q, _= q.shape
input_metadata.need_alibi(alibi_slopes, batch, nheads_q)

# Check arguments
metadata.check_args(q, k, v, out)

# write your own version backward
M = torch.empty((batch, nheads_q, metadata.max_seqlens_q), device=q.device, dtype=torch.float32) # this passed from
# launch kernel
tri_out = attention_decode(q, k_cache, v_cache, input_metadata)

if torch.version.hip is not None:
BLOCK = 64
else:
BLOCK = 128
o = out
do = dout
sm_scale = softmax_scale
assert do.is_contiguous()
assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
seqlen_q = q.shape[2]
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
BATCH, N_CTX, N_HEAD = q.shape[:3]
PRE_BLOCK = 128
# NUM_WARPS, NUM_STAGES = 4, 1
BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32
BLK_SLICE_FACTOR = 2
RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
arg_k = k
arg_k = arg_k * (sm_scale * RCP_LN2)
if DEBUG:
print("N_CTX:", N_CTX)
# assert N_CTX % PRE_BLOCK == 0
print()
print("tri_out:", tri_out, tri_out.shape)

return tri_out, None

delta = torch.empty_like(M)
_, Lk, _ = q.shape[-1], k.shape[-1], v.shape[-1]
# padded_head = (Lk != ctx.BLOCK_DMODEL)
grid_preprocess = (triton.cdiv(do.shape[2], BLOCK), do.shape[1], do.shape[0])
_attn_bwd_preprocess[grid_preprocess](
o,
do,
delta,
o.stride(0),
o.stride(1),
o.stride(2),
o.stride(3),
do.stride(0),
do.stride(1),
do.stride(2),
do.stride(3),
seqlen_q,
head_dim=Lk,
BLOCK_M=BLOCK,
D_HEAD=BLOCK_DMODEL,
)
grid = lambda META: (triton.cdiv(N_CTX, META['BLOCK_N1']), 1, BATCH * N_HEAD)
_attn_bwd[grid](
q,
arg_k,
v,
sm_scale,
alibi_slopes,
do,
dq,
dk,
dv,
M,
delta,
q.stride(0),
q.stride(1),
q.stride(2),
q.stride(3),
N_HEAD,
N_CTX,
BLOCK_DMODEL= BLOCK_DMODEL,
BLOCK_M1=BLOCK_M1,
BLOCK_N1=BLOCK_N1,
BLOCK_M2=BLOCK_M2,
BLOCK_N2=BLOCK_N2,
BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,
USE_ALIBI=False if alibi_slopes is None else True,
)

return dq, dk, dv, None
def bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, alibi_slopes, dropout_p, softmax_scale, causal, window_size_left,
window_size_right, deterministic, gen_, rng_state):
raise ValueError("bwd is not supported on AMD yet")


def varlen_bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, *args, **kwargs):
pass
raise ValueError("varlen_bwd is not supported on AMD yet")
