From 13ba75ad32a996edb7a25ce044fa5d97c4eee2e9 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Fri, 14 Jun 2024 08:44:31 -0500 Subject: [PATCH] keep interface and kernel seperate --- flash_attn/flash_attn_interface.py | 2 +- flash_attn/flash_attn_triton_interface_amd.py | 290 ++++++++++++++++ ...amd.py => flash_attn_triton_kernel_amd.py} | 325 +----------------- 3 files changed, 304 insertions(+), 313 deletions(-) create mode 100644 flash_attn/flash_attn_triton_interface_amd.py rename flash_attn/{flash_attn_triton_amd.py => flash_attn_triton_kernel_amd.py} (88%) diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 24a317f9e..272f73dbf 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -13,7 +13,7 @@ def is_hip(): # isort: off # We need to import the CUDA kernels after importing torch if is_hip(): - from . import flash_attn_triton_amd as flash_attn_gpu + from . import flash_attn_triton_interface_amd as flash_attn_gpu else: import flash_attn_2_cuda as flash_attn_gpu diff --git a/flash_attn/flash_attn_triton_interface_amd.py b/flash_attn/flash_attn_triton_interface_amd.py new file mode 100644 index 000000000..d27619ac3 --- /dev/null +++ b/flash_attn/flash_attn_triton_interface_amd.py @@ -0,0 +1,290 @@ +from .flash_attn_triton_kernel_amd import MetaData, attention, get_shape_from_layout, _attn_bwd_preprocess, _attn_bwd +import torch +import triton + +# /////////////////////////////////////////// Interface ////////////////////////////////////////////////////////// +DEBUG=False + + +def fwd(q, + k, + v, + o, + alibi_slopes, + dropout_p, + softmax_scale, + causal, + window_size_left, + window_size_right, + return_softmax, + gen_): + if DEBUG: + print("flash_attn_triton_amd.py::fwd") + print("q:", q.shape) + print("k:", k.shape) + print("v:", v.shape) + print("alibi_slopes:", alibi_slopes) + print("dropout_p:", dropout_p) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("return_softmax:", return_softmax) + print("gen_:", gen_) + + if dropout_p != 0.0: + raise ValueError("dropout is not supported on HIP") + + + if o is None: + o = torch.empty_like(q) + + + # Create metadata object + metadata = MetaData(sm_scale=softmax_scale) + metadata.max_seqlens_q = q.shape[1] + metadata.max_seqlens_k = k.shape[1] + metadata.layout = "bshd" + + # Setup metadata + if causal: + metadata.need_causal() + # if bias is not None: + # metadata.need_bias(bias, q.shape[0], q.shape[1], q.shape[2], k.shape[2]) + if alibi_slopes is not None: + metadata.need_alibi(alibi_slopes, q.shape[0], q.shape[2]) + if dropout_p > 0.0: + metadata.need_dropout(dropout_p, return_softmax) + + # Check arguments + metadata.check_args(q, k, v, o) + + # Perform the forward attention computation + tri_out, encoded_softmax = attention(q, k, v, o, metadata) + + softmax_lse = encoded_softmax + softmax_p = encoded_softmax + + return tri_out, q , k , v, o, softmax_lse, softmax_p, torch.get_rng_state() + +def varlen_fwd( + q, + k, + v, + o, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + block_table_, + alibi_slopes,\ + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + zero_tensors, + causal, + window_size_left, + window_size_right, + return_softmax, + gen_): + + if DEBUG: + print("flash_attn_triton_amd.py::varlen_fwd") + print("q:", q.shape) + print("k:", k.shape) + print("v:", v.shape) + + if dropout_p != 0.0: + raise ValueError("dropout 
is not supported on HIP") + + + + if o is None: + o = torch.empty_like(q) + + + + # create metadata object + input_metadata = MetaData(sm_scale=softmax_scale) + input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) + + # get shapes + batch, nheads_q, nheads_k, head_size = get_shape_from_layout(q, k, input_metadata) + + # Setup metadata + if causal: + input_metadata.need_causal() + # if bias is not None: + # metadata.need_bias(bias, q.shape[0], q.shape[1], q.shape[2], k.shape[2]) + if alibi_slopes is not None: + input_metadata.need_alibi(alibi_slopes, batch, nheads_q) + if dropout_p > 0.0: + input_metadata.need_dropout(dropout_p, return_softmax) + + # Check arguments + input_metadata.check_args(q, k, v, o) + + # Perform the forward attention computation + tri_out, encoded_softmax = attention(q, k, v, o, input_metadata) + + softmax_lse = encoded_softmax + softmax_p = encoded_softmax + + return tri_out, q , k , v, o, softmax_lse, softmax_p, torch.get_rng_state() + +def fwd_kvcache(*args, **kwargs): + pass + + +def bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, alibi_slopes, dropout_p, softmax_scale, causal, window_size_left, + window_size_right, deterministic, gen_, rng_state): + if DEBUG: + print("flash_attn_triton_amd.py::bwd") + print("q:", q.shape) + print("k:", k.shape) + print("v:", v.shape) + print("softmax_lse:", softmax_lse) + print("dq:", dq.shape) + print("dk:", dk.shape) + print("dv:", dv.shape) + print("alibi_slopes:", alibi_slopes) + print("dropout_p:", dropout_p) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("deterministic:", deterministic) + print("gen_:", gen_) + print("rng_state:", rng_state) + + + + if out is None: + out = torch.empty_like(q) + + + # Ensure the tensors have requires_grad=True + q.requires_grad_() + k.requires_grad_() + v.requires_grad_() + out.requires_grad_() + + # Create metadata object + metadata = MetaData(sm_scale=softmax_scale) + metadata.max_seqlens_q = q.shape[1] + metadata.max_seqlens_k = k.shape[1] + metadata.layout = "bshd" + + if metadata == 'bshd': + q = q.transpose(1, 2).clone() + k = k.transpose(1, 2).clone() + v = v.transpose(1, 2).clone() + + batch = q.shape[0] + nheads_q = q.shape[1] + BLOCK_DMODEL = q.shape[3] + + # Setup metadata + if causal: + metadata.need_causal() + # if bias is not None: + # metadata.need_bias(bias, q.shape[0], q.shape[1], q.shape[2], k.shape[2]) + + return_softmax = True + if alibi_slopes is not None: + metadata.need_alibi(alibi_slopes, batch, nheads_q) + if dropout_p > 0.0: + metadata.need_dropout(dropout_p, return_softmax) + + # Check arguments + metadata.check_args(q, k, v, out) + + + + # tri_out, _ = attention(q, k, v, out, metadata) + # tri_out.requires_grad_() + # dout.requires_grad_() + # tri_out.backward(dout) + + # write your own version backward + M = torch.empty((batch, nheads_q, metadata.max_seqlens_q), device=q.device, dtype=torch.float32) # this passed from + + if torch.version.hip is not None: + BLOCK = 64 + else: + BLOCK = 128 + o = out + do = dout + sm_scale = softmax_scale + assert do.is_contiguous() + assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() + seqlen_q = q.shape[2] + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + BATCH, N_CTX, N_HEAD = q.shape[:3] + PRE_BLOCK = 128 + # NUM_WARPS, NUM_STAGES = 4, 1 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32 + BLK_SLICE_FACTOR = 2 + RCP_LN2 = 
1.4426950408889634 # = 1.0 / ln(2) + arg_k = k + arg_k = arg_k * (sm_scale * RCP_LN2) + if DEBUG: + print("N_CTX:", N_CTX) + # assert N_CTX % PRE_BLOCK == 0 + + delta = torch.empty_like(M) + _, Lk, _ = q.shape[-1], k.shape[-1], v.shape[-1] + # padded_head = (Lk != ctx.BLOCK_DMODEL) + grid_preprocess = (triton.cdiv(do.shape[2], BLOCK), do.shape[1], do.shape[0]) + _attn_bwd_preprocess[grid_preprocess]( + o, + do, + delta, + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), + do.stride(0), + do.stride(1), + do.stride(2), + do.stride(3), + seqlen_q, + head_dim=Lk, + BLOCK_M=BLOCK, + D_HEAD=BLOCK_DMODEL, + ) + grid = lambda META: (triton.cdiv(N_CTX, META['BLOCK_N1']), 1, BATCH * N_HEAD) + _attn_bwd[grid]( + q, + arg_k, + v, + sm_scale, + alibi_slopes, + do, + dq, + dk, + dv, + M, + delta, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + N_HEAD, + N_CTX, + BLOCK_DMODEL= BLOCK_DMODEL, + BLOCK_M1=BLOCK_M1, + BLOCK_N1=BLOCK_N1, + BLOCK_M2=BLOCK_M2, + BLOCK_N2=BLOCK_N2, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + USE_ALIBI=False if alibi_slopes is None else True, + ) + + return dq, dk, dv, None + + +def varlen_bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, *args, **kwargs): + pass \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd.py b/flash_attn/flash_attn_triton_kernel_amd.py similarity index 88% rename from flash_attn/flash_attn_triton_amd.py rename to flash_attn/flash_attn_triton_kernel_amd.py index ed49c61d8..a13fc6c33 100644 --- a/flash_attn/flash_attn_triton_amd.py +++ b/flash_attn/flash_attn_triton_kernel_amd.py @@ -119,9 +119,7 @@ def check_args(self, q, k, v, o): assert q.dtype == k.dtype and q.dtype == v.dtype assert head_size <= 256 assert o.shape == q.shape - print("nheads_q", nheads_q) - print("nheads_k", nheads_k) - # assert (nheads_q % nheads_k) == 0 + assert (nheads_q % nheads_k) == 0 assert self.layout is not None assert self.layout == 'thd' or not self.varlen @@ -935,13 +933,7 @@ def forward(ctx, q, k, v, o, metadata): alibi_strides = (metadata.alibi_slopes.stride(0), metadata.alibi_slopes.stride(1)) else: alibi_strides = (0, 0) - - print("_attention::forward") - print("q:", q.shape) - print("k:", k.shape) - print("v:", v.shape) - print("metadata:", metadata) attn_fwd[grid](q, k, v, metadata.bias, metadata.sm_scale, M, o, *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, *alibi_strides, metadata.cu_seqlens_q, metadata.cu_seqlens_k, dropout_p=metadata.dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, @@ -952,24 +944,21 @@ def forward(ctx, q, k, v, o, metadata): USE_ALIBI=False if metadata.alibi_slopes is None else True, ENABLE_DROPOUT=metadata.dropout_p > 0.0, RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax) - if ctx: - ctx.save_for_backward(q, k, v, o, M) - ctx.grid = grid - ctx.sm_scale = metadata.sm_scale - ctx.BLOCK_DMODEL = head_size - ctx.causal = metadata.causal - ctx.alibi_slopes = metadata.alibi_slopes - ctx.dropout_p = metadata.dropout_p - ctx.philox_seed = philox_seed - ctx.philox_offset = philox_offset - ctx.encoded_softmax = encoded_softmax - ctx.return_encoded_softmax = metadata.return_encoded_softmax + ctx.save_for_backward(q, k, v, o, M) + ctx.grid = grid + ctx.sm_scale = metadata.sm_scale + ctx.BLOCK_DMODEL = head_size + ctx.causal = metadata.causal + ctx.alibi_slopes = metadata.alibi_slopes + ctx.dropout_p = metadata.dropout_p + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.encoded_softmax = encoded_softmax + ctx.return_encoded_softmax = 
metadata.return_encoded_softmax return o, encoded_softmax @staticmethod def backward(ctx, do, _): - print("_attention::backward") - print("do:", do.shape) if torch.version.hip is not None: BLOCK = 64 else: @@ -1512,295 +1501,7 @@ def supported_layouts(): 'This layout is sometimes called "varlen" or "grouped" layout.' return layouts -# /////////////////////////////////////////// Interface ////////////////////////////////////////////////////////// -DEBUG=False - - -def fwd(q, - k, - v, - o, - alibi_slopes, - dropout_p, - softmax_scale, - causal, - window_size_left, - window_size_right, - return_softmax, - gen_): - if DEBUG: - print("flash_attn_triton_amd.py::fwd") - print("q:", q.shape) - print("k:", k.shape) - print("v:", v.shape) - print("alibi_slopes:", alibi_slopes) - print("dropout_p:", dropout_p) - print("softmax_scale:", softmax_scale) - print("causal:", causal) - print("window_size_left:", window_size_left) - print("window_size_right:", window_size_right) - print("return_softmax:", return_softmax) - print("gen_:", gen_) - - if dropout_p != 0.0: - raise ValueError("dropout is not supported on HIP") - - - if o is None: - o = torch.empty_like(q) - - - # Create metadata object - metadata = MetaData(sm_scale=softmax_scale) - metadata.max_seqlens_q = q.shape[1] - metadata.max_seqlens_k = k.shape[1] - metadata.layout = "bshd" - - # Setup metadata - if causal: - metadata.need_causal() - # if bias is not None: - # metadata.need_bias(bias, q.shape[0], q.shape[1], q.shape[2], k.shape[2]) - if alibi_slopes is not None: - metadata.need_alibi(alibi_slopes, q.shape[0], q.shape[2]) - if dropout_p > 0.0: - metadata.need_dropout(dropout_p, return_softmax) - - # Check arguments - metadata.check_args(q, k, v, o) - - # Perform the forward attention computation - tri_out, encoded_softmax = attention(q, k, v, o, metadata) - - softmax_lse = encoded_softmax - softmax_p = encoded_softmax - - return tri_out, q , k , v, o, softmax_lse, softmax_p, torch.get_rng_state() - -def varlen_fwd( - q, - k, - v, - o, - cu_seqlens_q, - cu_seqlens_k, - seqused_k, - block_table_, - alibi_slopes,\ - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - zero_tensors, - causal, - window_size_left, - window_size_right, - return_softmax, - gen_): - - print("flash_attn_triton_amd.py::varlen_fwd") - print("q:", q.shape) - print("k:", k.shape) - print("v:", v.shape) - - if dropout_p != 0.0: - raise ValueError("dropout is not supported on HIP") - - - - if o is None: - o = torch.empty_like(q) - - - - # create metadata object - input_metadata = MetaData(sm_scale=softmax_scale) - input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) - - # get shapes - batch, nheads_q, nheads_k, head_size = get_shape_from_layout(q, k, input_metadata) - - # Setup metadata - if causal: - input_metadata.need_causal() - # if bias is not None: - # metadata.need_bias(bias, q.shape[0], q.shape[1], q.shape[2], k.shape[2]) - if alibi_slopes is not None: - input_metadata.need_alibi(alibi_slopes, batch, nheads_q) - if dropout_p > 0.0: - input_metadata.need_dropout(dropout_p, return_softmax) - - # Check arguments - input_metadata.check_args(q, k, v, o) - - # Perform the forward attention computation - tri_out, encoded_softmax = attention(q, k, v, o, input_metadata) - - softmax_lse = encoded_softmax - softmax_p = encoded_softmax - - return tri_out, q , k , v, o, softmax_lse, softmax_p, torch.get_rng_state() - -def fwd_kvcache(*args, **kwargs): - pass - - -def bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, alibi_slopes, dropout_p, softmax_scale, 
causal, window_size_left, - window_size_right, deterministic, gen_, rng_state): - if DEBUG: - print("flash_attn_triton_amd.py::bwd") - print("q:", q.shape) - print("k:", k.shape) - print("v:", v.shape) - print("softmax_lse:", softmax_lse) - print("dq:", dq.shape) - print("dk:", dk.shape) - print("dv:", dv.shape) - print("alibi_slopes:", alibi_slopes) - print("dropout_p:", dropout_p) - print("softmax_scale:", softmax_scale) - print("causal:", causal) - print("window_size_left:", window_size_left) - print("window_size_right:", window_size_right) - print("deterministic:", deterministic) - print("gen_:", gen_) - print("rng_state:", rng_state) - - - - if out is None: - out = torch.empty_like(q) - - - # Ensure the tensors have requires_grad=True - q.requires_grad_() - k.requires_grad_() - v.requires_grad_() - out.requires_grad_() - - # Create metadata object - metadata = MetaData(sm_scale=softmax_scale) - metadata.max_seqlens_q = q.shape[1] - metadata.max_seqlens_k = k.shape[1] - metadata.layout = "bshd" - - if metadata == 'bshd': - q = q.transpose(1, 2).clone() - k = k.transpose(1, 2).clone() - v = v.transpose(1, 2).clone() - - batch = q.shape[0] - nheads_q = q.shape[1] - BLOCK_DMODEL = q.shape[3] - - # Setup metadata - if causal: - metadata.need_causal() - # if bias is not None: - # metadata.need_bias(bias, q.shape[0], q.shape[1], q.shape[2], k.shape[2]) - - return_softmax = True - if alibi_slopes is not None: - metadata.need_alibi(alibi_slopes, batch, nheads_q) - if dropout_p > 0.0: - metadata.need_dropout(dropout_p, return_softmax) - - # Check arguments - metadata.check_args(q, k, v, out) - - - - # tri_out, _ = attention(q, k, v, out, metadata) - # tri_out.requires_grad_() - # dout.requires_grad_() - # tri_out.backward(dout) - - # write your own version backward - M = torch.empty((batch, nheads_q, metadata.max_seqlens_q), device=q.device, dtype=torch.float32) # this passed from - - if torch.version.hip is not None: - BLOCK = 64 - else: - BLOCK = 128 - o = out - do = dout - sm_scale = softmax_scale - assert do.is_contiguous() - assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() - seqlen_q = q.shape[2] - dq = torch.empty_like(q) - dk = torch.empty_like(k) - dv = torch.empty_like(v) - BATCH, N_CTX, N_HEAD = q.shape[:3] - PRE_BLOCK = 128 - # NUM_WARPS, NUM_STAGES = 4, 1 - BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32 - BLK_SLICE_FACTOR = 2 - RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) - arg_k = k - arg_k = arg_k * (sm_scale * RCP_LN2) - print("N_CTX:", N_CTX) - # assert N_CTX % PRE_BLOCK == 0 - - delta = torch.empty_like(M) - _, Lk, _ = q.shape[-1], k.shape[-1], v.shape[-1] - # padded_head = (Lk != ctx.BLOCK_DMODEL) - grid_preprocess = (triton.cdiv(do.shape[2], BLOCK), do.shape[1], do.shape[0]) - _attn_bwd_preprocess[grid_preprocess]( - o, - do, - delta, - o.stride(0), - o.stride(1), - o.stride(2), - o.stride(3), - do.stride(0), - do.stride(1), - do.stride(2), - do.stride(3), - seqlen_q, - head_dim=Lk, - BLOCK_M=BLOCK, - D_HEAD=BLOCK_DMODEL, - ) - grid = lambda META: (triton.cdiv(N_CTX, META['BLOCK_N1']), 1, BATCH * N_HEAD) - _attn_bwd[grid]( - q, - arg_k, - v, - sm_scale, - alibi_slopes, - do, - dq, - dk, - dv, - M, - delta, - q.stride(0), - q.stride(1), - q.stride(2), - q.stride(3), - N_HEAD, - N_CTX, - BLOCK_DMODEL= BLOCK_DMODEL, - BLOCK_M1=BLOCK_M1, - BLOCK_N1=BLOCK_N1, - BLOCK_M2=BLOCK_M2, - BLOCK_N2=BLOCK_N2, - BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, - USE_ALIBI=False if alibi_slopes is None else True, - ) - - return dq, dk, dv, None - - -def varlen_bwd(dout, 
q, k, v, out, softmax_lse, dq, dk, dv, *args, **kwargs): - pass - - - -# /////////////////////////////////////////// CLI ////////////////////////////////////////////////////////// def parse_args(): parser = argparse.ArgumentParser( prog="Benchmark FlashAttention", @@ -1844,4 +1545,4 @@ def main(): if __name__ == '__main__': - sys.exit(main()) + sys.exit(main()) \ No newline at end of file
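
With this split, `flash_attn_interface.py` imports `flash_attn_triton_interface_amd` as `flash_attn_gpu` on ROCm builds, so the public API is unchanged and dispatch is decided at import time by `is_hip()`. Below is a minimal usage sketch of that path through the public `flash_attn_func` wrapper; it assumes a ROCm/HIP build of PyTorch with Triton available, and the tensor shapes are illustrative only, not part of this patch.

import torch
from flash_attn import flash_attn_func  # on HIP builds this ends up calling flash_attn_triton_interface_amd.fwd

# illustrative shapes: (batch, seqlen, nheads, head_dim) in the "bshd" layout the interface expects
batch, seqlen, nheads, head_dim = 2, 128, 8, 64
q = torch.randn(batch, seqlen, nheads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# dropout_p must stay 0.0: the Triton interface raises ValueError for nonzero dropout on HIP
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
print(out.shape)  # torch.Size([2, 128, 8, 64])

Because the routing happens inside the library's import block, no call-site changes are needed when switching between the CUDA extension and the Triton AMD backend.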