diff --git a/csrc/composable_kernel b/csrc/composable_kernel
index 11b7a4db0..464abd235 160000
--- a/csrc/composable_kernel
+++ b/csrc/composable_kernel
@@ -1 +1 @@
-Subproject commit 11b7a4db005dc38e60b1ea045d03a92d2a8f9cd0
+Subproject commit 464abd235e27c33422aa52ed2044af8fbcc3a88d
diff --git a/csrc/flash_attn_ck/flash_common.hpp b/csrc/flash_attn_ck/flash_common.hpp
index 1c7c2f062..cc86546ea 100644
--- a/csrc/flash_attn_ck/flash_common.hpp
+++ b/csrc/flash_attn_ck/flash_common.hpp
@@ -22,17 +22,17 @@
 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

 namespace flash {
-// Copy from PyTorch
-// https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17
-inline std::tuple<uint64_t, uint64_t> unpack(at::PhiloxCudaState arg) {
-    if (arg.captured_) {
-        // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
-        // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
-        // For most threads' reads it will hit in cache, so it shouldn't hurt performance.
-        return std::make_tuple(static_cast<uint64_t>(*arg.seed_.ptr), static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
-    } else {
-        return std::make_tuple(arg.seed_.val, arg.offset_.val);
-    }
+inline __global__ void ParsePhiloxCudaState(at::PhiloxCudaState arg, uint64_t* rng_state)
+{
+    // Imitate from PyTorch
+    // https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17
+    if (arg.captured_) {
+        rng_state[0] = static_cast<uint64_t>(*arg.seed_.ptr);
+        rng_state[1] = static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_);
+    } else {
+        rng_state[0] = arg.seed_.val;
+        rng_state[1] = arg.offset_.val;
+    }
 }

 inline int num_splits_heuristic_ck(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
diff --git a/csrc/flash_attn_ck/mha_bwd.cpp b/csrc/flash_attn_ck/mha_bwd.cpp
index 1859137f8..e4a4b2a6b 100644
--- a/csrc/flash_attn_ck/mha_bwd.cpp
+++ b/csrc/flash_attn_ck/mha_bwd.cpp
@@ -49,8 +49,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
                                    at::Tensor dv,
                                    float softmax_scale,
                                    float p_dropout,
-                                   uint64_t drop_seed,
-                                   uint64_t drop_offset)
+                                   std::pair<uint64_t*, uint64_t*> drop_seed_offset)
 {
     // q: (batch_size, seqlen_q, nheads, hdim)
     ck_tile::index_t batch_stride_q = q.stride(0);
@@ -191,7 +190,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
                         static_cast<ck_tile::index_t>(mask.type),
                         p_dropout,
                         p_undrop,
-                        {drop_seed, drop_offset}};
+                        drop_seed_offset};
 }

 std::vector<at::Tensor>
@@ -213,7 +212,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
         const float /*softcap*/,
         const bool deterministic,
         c10::optional<at::Generator> gen_,
-        c10::optional<at::Tensor> &rng_state)
+        c10::optional<at::Tensor> &rng_state_)
 {
 #ifdef FLASHATTENTION_DISABLE_BACKWARD
     TORCH_CHECK(false, "This flash attention build does not support backward.");
@@ -337,21 +336,24 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
         gen_, at::cuda::detail::getDefaultCUDAGenerator());

-    uint64_t drop_seed = 1, drop_offset = 0;
     int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
+    at::Tensor rng_state;

-    if (rng_state.has_value()) {
-        uint64_t* d = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
-        drop_seed = d[0];
-        drop_offset = d[1];
+    if (rng_state_.has_value()) {
+        rng_state = rng_state_.value();
     } else if(is_dropout) {
+        rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
         auto philox_args = gen->philox_cuda_state(counter_offset);
-        std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
+        hipLaunchKernelGGL(
+            flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0,
+            philox_args, reinterpret_cast<uint64_t*>(rng_state.data_ptr()));
     }

     if (seqlen_q > 0) {
+        auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+        auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
         ck_tile::stream_config stream_config{stream};

         auto traits =
@@ -380,8 +382,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
                 dv_expanded,
                 softmax_scale,
                 p_dropout,
-                drop_seed,
-                drop_offset);
+                drop_seed_offset);

         float t = fmha_bwd(traits, args, stream_config);
         TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd");
diff --git a/csrc/flash_attn_ck/mha_fwd.cpp b/csrc/flash_attn_ck/mha_fwd.cpp
index a6b33b4ab..7202cf2c8 100644
--- a/csrc/flash_attn_ck/mha_fwd.cpp
+++ b/csrc/flash_attn_ck/mha_fwd.cpp
@@ -46,8 +46,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
                                    at::Tensor dropout_randval,
                                    float softmax_scale,
                                    float p_dropout,
-                                   uint64_t drop_seed,
-                                   uint64_t drop_offset)
+                                   std::pair<uint64_t*, uint64_t*> drop_seed_offset)
 {
     // q: (batch_size, seqlen_q, nheads, d)
     // k: (batch_size, seqlen_k, nheads_k, d)
@@ -137,7 +136,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
                         static_cast<ck_tile::index_t>(mask.type),
                         p_dropout,
                         has_dropout_randval,
-                        {drop_seed, drop_offset}};
+                        drop_seed_offset};
 }

 std::vector<at::Tensor>
@@ -255,10 +254,9 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num
         p = torch::empty({ 0 }, opts);
     }

-    uint64_t drop_seed = 1, drop_offset = 0;
     int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
-    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
-    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+    auto rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
+    auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());

     if (p_dropout > 0.0) {
         auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
@@ -266,13 +264,12 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
         auto philox_args = gen->philox_cuda_state(counter_offset);
-        std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
+        hipLaunchKernelGGL(
+            flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0, philox_args, rng_state_ptr);
     }

-    rng_state[0] = *(reinterpret_cast<int64_t*>(&drop_seed));
-    rng_state[1] = *(reinterpret_cast<int64_t*>(&drop_offset));
-
     if (seqlen_k > 0) {
+        auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
         auto stream = at::cuda::getCurrentHIPStream().stream();
         ck_tile::stream_config stream_config{stream};
@@ -305,8 +302,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num
                 p,
                 softmax_scale,
                 p_dropout,
-                drop_seed,
-                drop_offset);
+                drop_seed_offset);

         float t = fmha_fwd(traits, args, stream_config);
         TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");
diff --git a/csrc/flash_attn_ck/mha_varlen_bwd.cpp b/csrc/flash_attn_ck/mha_varlen_bwd.cpp
index 531d735ed..2e5dd7b51 100644
--- a/csrc/flash_attn_ck/mha_varlen_bwd.cpp
+++ b/csrc/flash_attn_ck/mha_varlen_bwd.cpp
@@ -51,8 +51,7 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
                                           at::Tensor dv,
                                           float softmax_scale,
                                           float p_dropout,
-                                          uint64_t drop_seed,
-                                          uint64_t drop_offset)
+                                          std::pair<uint64_t*, uint64_t*> drop_seed_offset)
 {
     ck_tile::index_t total_q = q.size(0);
     ck_tile::index_t total_k = k.size(0);
@@ -197,7 +196,7 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
                         static_cast<ck_tile::index_t>(mask.type),
                         p_dropout,
                         p_undrop,
-                        {drop_seed, drop_offset}};
+                        drop_seed_offset};
 }

 std::vector<at::Tensor>
@@ -224,7 +223,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
                const float /*softcap*/,
                const bool deterministic,
                c10::optional<at::Generator> gen_,
-               c10::optional<at::Tensor> &rng_state)
+               c10::optional<at::Tensor> &rng_state_)
 {
 #ifdef FLASHATTENTION_DISABLE_BACKWARD
     TORCH_CHECK(false, "This flash attention build does not support backward.");
@@ -362,21 +361,26 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
         gen_, at::cuda::detail::getDefaultCUDAGenerator());

-    uint64_t drop_seed = 1, drop_offset = 0;
     int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
+    at::Tensor rng_state;

-    if (rng_state.has_value()) {
-        uint64_t* d = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
-        drop_seed = d[0];
-        drop_offset = d[1];
+    if (rng_state_.has_value()) {
+        rng_state = rng_state_.value();
     } else if(is_dropout) {
+        rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
         auto philox_args = gen->philox_cuda_state(counter_offset);
-        std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
+        hipLaunchKernelGGL(
+            flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0,
+            philox_args, reinterpret_cast<uint64_t*>(rng_state.data_ptr()));
+    } else {
+        rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
     }

     if (max_seqlen_q > 0) {
+        auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+        auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
         ck_tile::stream_config stream_config{stream};

         auto traits =
@@ -407,8 +411,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
                 dv_expanded,
                 softmax_scale,
                 p_dropout,
-                drop_seed,
-                drop_offset);
+                drop_seed_offset);

         float t = fmha_bwd(traits, args, stream_config);
         TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd");
diff --git a/csrc/flash_attn_ck/mha_varlen_fwd.cpp b/csrc/flash_attn_ck/mha_varlen_fwd.cpp
index 6e30aa74a..7e8a347d4 100644
--- a/csrc/flash_attn_ck/mha_varlen_fwd.cpp
+++ b/csrc/flash_attn_ck/mha_varlen_fwd.cpp
@@ -47,8 +47,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
                                           at::Tensor dropout_randval,
                                           float softmax_scale,
                                           float p_dropout,
-                                          uint64_t drop_seed,
-                                          uint64_t drop_offset)
+                                          std::pair<uint64_t*, uint64_t*> drop_seed_offset)
 {
     // q: (total_q, nheads, d)
     // k: (total_k, nheads_k, d)
@@ -140,7 +139,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
                         static_cast<ck_tile::index_t>(mask.type),
                         p_dropout,
                         has_dropout_randval,
-                        {drop_seed, drop_offset}};
+                        drop_seed_offset};
 }

 std::vector<at::Tensor>
@@ -281,10 +280,9 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
         if (return_dropout_randval) {p.zero_();}
     }

-    uint64_t drop_seed = 1, drop_offset = 0;
     int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
-    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
-    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+    auto rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
+    auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());

     if (p_dropout > 0.0) {
         auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
@@ -292,13 +290,12 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
         auto philox_args = gen->philox_cuda_state(counter_offset);
-        std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
+        hipLaunchKernelGGL(
+            flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0, philox_args, rng_state_ptr);
     }

-    rng_state[0] = *(reinterpret_cast<int64_t*>(&drop_seed));
-    rng_state[1] = *(reinterpret_cast<int64_t*>(&drop_offset));
-
     if (max_seqlen_k > 0) {
+        auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
         auto stream = at::cuda::getCurrentHIPStream().stream();
         ck_tile::stream_config stream_config{stream};
@@ -332,8 +329,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
                 p,
                 softmax_scale,
                 p_dropout,
-                drop_seed,
-                drop_offset);
+                drop_seed_offset);

         float t = fmha_fwd(traits, args, stream_config);
         TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");
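
Note on the pattern above: the patch replaces the host-side `flash::unpack()` of `at::PhiloxCudaState` with a one-block device kernel, `flash::ParsePhiloxCudaState`, that writes the Philox seed and offset into a 2-element `kInt64` `rng_state` tensor, and the `get_ck_fmha_*_args` helpers now take a `std::pair` of `uint64_t` pointers into that tensor instead of plain values. The sketch below is illustrative only and not part of the patch: it assumes a HIP build of PyTorch with the `flash_common.hpp` above on the include path, and the helper name `make_drop_seed_offset` is made up for the example.

```cpp
#include <mutex>
#include <utility>

#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <hip/hip_runtime.h>

#include "flash_common.hpp" // provides flash::ParsePhiloxCudaState (added above)

// Hypothetical helper: fills `rng_state` (a 2-element kInt64 CUDA tensor) on the
// device and returns the {seed, offset} pointer pair expected by the new
// get_ck_fmha_*_args signatures. Must be compiled as HIP because it launches a kernel.
inline std::pair<uint64_t*, uint64_t*>
make_drop_seed_offset(at::CUDAGeneratorImpl* gen, int64_t counter_offset, at::Tensor& rng_state)
{
    auto* rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
    {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        auto philox_args = gen->philox_cuda_state(counter_offset);
        // Unpack on the device: if the generator state was captured by a HIP graph,
        // the offset is read at kernel launch time rather than on the host.
        hipLaunchKernelGGL(flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0,
                           philox_args, rng_state_ptr);
    }
    // The fmha argument structs dereference these pointers inside the kernels.
    return std::make_pair(rng_state_ptr, rng_state_ptr + 1);
}
```

Passing pointers rather than values also lets the backward pass consume the same `rng_state` tensor produced by the forward pass (or the caller-supplied `rng_state_`) without copying the seed/offset back to the host.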