From 0c2fb252646f7bca12c645e4c79fb0515e0d99a5 Mon Sep 17 00:00:00 2001
From: Bruce Fontaine
Date: Thu, 5 Sep 2024 12:05:38 -0400
Subject: [PATCH] Fix ima for split-kv kernel (#20)

---
 csrc/flash_attn/flash_api.cpp | 37 +++++++++++++++++++++++------------
 1 file changed, 24 insertions(+), 13 deletions(-)

diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp
index fa64e7cd3..88710e02b 100644
--- a/csrc/flash_attn/flash_api.cpp
+++ b/csrc/flash_attn/flash_api.cpp
@@ -209,7 +209,7 @@ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
     return 1;
 }
 
-void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
+std::tuple<at::Tensor, at::Tensor> set_params_splitkv(Flash_fwd_params &params, const int batch_size,
                         const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
                         const int head_size_rounded, const float p_dropout,
                         const int num_splits, cudaDeviceProp *dprops, struct c10::TensorOptions opts) {
@@ -221,19 +221,24 @@ void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
     // In any case we don't expect seqlen_q to be larger than 64 for inference.
     const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64;
     params.num_splits = num_splits;
+    at::Tensor softmax_lse_accum;
+    at::Tensor out_accum;
+
     if (p_dropout == 0.0f) {  // SplitKV is not implemented for dropout
         if (num_splits < 1) {
             // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
             params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount * 2, num_n_blocks, 128);
         }
         if (params.num_splits > 1) {
-            at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
-            at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
+            softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
+            out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
             params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
             params.oaccum_ptr = out_accum.data_ptr();
         }
         TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
     }
+
+    return std::make_tuple(softmax_lse_accum, out_accum);
 }
 
 void set_params_alibi(Flash_fwd_params &params, c10::optional<at::Tensor> &alibi_slopes_, int batch_size, int num_heads){
@@ -394,10 +399,11 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
                      softcap
         );
 
-
-    set_params_splitkv(params, batch_size, num_heads,
-                       head_size, seqlen_k, seqlen_q,
-                       head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
+    // Keep references to these tensors to extend their lifetime
+    at::Tensor softmax_lse_accum, out_accum;
+    std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
+        params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
+        head_size_rounded, p_dropout, /*num_splits*/ 0, dprops, opts);
 
     // number of times random will be generated per thread, to offset philox counter in thc random
     // state
@@ -642,11 +648,14 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
         params.v_batch_stride = v_padded.stride(0);
     }
     params.page_block_size = page_block_size;
+    // Keep references to these tensors to extend their lifetime
+    at::Tensor softmax_lse_accum, out_accum;
     if (seqlenq_ngroups_swapped) {
         // Only apply split-k for decoding
-        set_params_splitkv(params, batch_size, num_heads,
-                           head_size, max_seqlen_k, max_seqlen_q,
-                           head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
+        std::tie(softmax_lse_accum, out_accum) =
+            set_params_splitkv(params, batch_size, num_heads, head_size,
+                               max_seqlen_k, max_seqlen_q, head_size_rounded,
+                               p_dropout, /*num_splits*/ 0, dprops, opts);
     }
 
     // number of times random will be generated per thread, to offset philox counter in thc random
@@ -936,9 +945,11 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_heads x head_size
         params.cache_batch_idx = reinterpret_cast<int *>(cache_batch_idx.data_ptr());
     }
 
-    set_params_splitkv(params, batch_size, num_heads,
-                       head_size, seqlen_k, seqlen_q,
-                       head_size_rounded, /*dropout*/0.f, num_splits, dprops, opts);
+    // Keep references to these tensors to extend their lifetime
+    at::Tensor softmax_lse_accum, out_accum;
+    std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
+        params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
+        head_size_rounded, /*dropout*/ 0.f, num_splits, dprops, opts);
 
     if (paged_KV) {
         params.block_table = block_table.data_ptr<int>();
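
Note on the class of bug this patch fixes: set_params_splitkv previously allocated softmax_lse_accum and out_accum as locals and stored their raw data_ptr() values into params; the tensors were destroyed when the function returned, so the split-kv kernel later read freed memory (the illegal memory access, "ima", in the subject). Returning the tensors and std::tie-ing them at each call site keeps the allocations alive across the kernel launch. The standalone sketch below reproduces the pattern outside of flash-attention; it is a minimal illustration in which std::vector stands in for at::Tensor, and the Params/Buffer names are hypothetical, not flash-attention APIs.

    // Standalone illustration of the dangling-pointer bug and the fix.
    // Buffer/Params are hypothetical stand-ins for at::Tensor and
    // Flash_fwd_params; only the lifetime pattern mirrors the patch.
    #include <tuple>
    #include <vector>

    struct Params {
        float *softmax_lseaccum_ptr = nullptr;
        float *oaccum_ptr = nullptr;
    };
    using Buffer = std::vector<float>;

    // Buggy shape: the accumulators are locals, so they are destroyed
    // when this function returns and the pointers stored in params dangle.
    void set_params_buggy(Params &params) {
        Buffer lse_accum(1024), out_accum(4096);
        params.softmax_lseaccum_ptr = lse_accum.data();
        params.oaccum_ptr = out_accum.data();
    }   // lse_accum/out_accum freed here -> later kernel reads freed memory

    // Fixed shape (what the patch does): hand the buffers back so the
    // caller keeps them alive across the kernel launch. Moving a
    // std::vector transfers ownership of the same heap storage (much as
    // copying an at::Tensor handle shares storage), so the raw pointers
    // already stored in params remain valid.
    std::tuple<Buffer, Buffer> set_params_fixed(Params &params) {
        Buffer lse_accum(1024), out_accum(4096);
        params.softmax_lseaccum_ptr = lse_accum.data();
        params.oaccum_ptr = out_accum.data();
        return std::make_tuple(std::move(lse_accum), std::move(out_accum));
    }

    int main() {
        Params params;
        // Mirrors the call sites in the patch: declare the accumulators in
        // the caller's scope and std::tie the results into them.
        Buffer lse_accum, out_accum;
        std::tie(lse_accum, out_accum) = set_params_fixed(params);
        // ... launch the split-kv kernel; params' pointers are still valid ...
    }

The same reasoning explains why returning by value is safe here: the pointers captured in params point at heap storage whose ownership moves to the caller's variables, not at the moved-from locals themselves.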