Commit
Merge pull request #85 from ROCm/ck_tile/seed_offset_gpu
Support dropout seed offset as pointer
rocking5566 authored Nov 6, 2024
2 parents 91eb950 + 86f9b1b commit 23cb26b
Showing 6 changed files with 56 additions and 60 deletions.
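
In outline, this PR moves the dropout RNG state out of host-side argument values and into GPU memory: the fmha argument structs stop carrying resolved `drop_seed`/`drop_offset` values and instead carry a pair of device pointers into a 2-element int64 `rng_state` tensor. The sketch below is illustrative only; the struct names are simplified stand-ins, not the exact `fmha_fwd_args`/`fmha_bwd_args` definitions.

#include <cstdint>
#include <utility>

// Before: the host dereferenced the Philox seed/offset and passed the values.
struct drop_rng_by_value {
    uint64_t drop_seed;
    uint64_t drop_offset;
};

// After: the kernel receives pointers into a device-side uint64_t[2],
// built by the callers as {rng_state_ptr, rng_state_ptr + 1}.
struct drop_rng_by_pointer {
    std::pair<uint64_t*, uint64_t*> drop_seed_offset;
};

Passing pointers means the seed and offset never have to be read back to the host, which matters when the Philox state is graph-captured (`arg.captured_`) and only exists on the device.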
2 changes: 1 addition & 1 deletion csrc/composable_kernel
Submodule composable_kernel updated 465 files
22 changes: 11 additions & 11 deletions csrc/flash_attn_ck/flash_common.hpp
@@ -22,17 +22,17 @@
 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
 
 namespace flash {
-// Copy from PyTorch
-// https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17
-inline std::tuple<uint64_t, uint64_t> unpack(at::PhiloxCudaState arg) {
-    if (arg.captured_) {
-        // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
-        // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
-        // For most threads' reads it will hit in cache, so it shouldn't hurt performance.
-        return std::make_tuple(static_cast<uint64_t>(*arg.seed_.ptr), static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
-    } else {
-        return std::make_tuple(arg.seed_.val, arg.offset_.val);
-    }
+inline __global__ void ParsePhiloxCudaState(at::PhiloxCudaState arg, uint64_t* rng_state)
+{
+    // Imitate from PyTorch
+    // https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17
+    if (arg.captured_) {
+        rng_state[0] = static_cast<uint64_t>(*arg.seed_.ptr);
+        rng_state[1] = static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_);
+    } else {
+        rng_state[0] = arg.seed_.val;
+        rng_state[1] = arg.offset_.val;
+    }
 }
 
 inline int num_splits_heuristic_ck(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
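
For context, a self-contained HIP sketch of the pattern this new kernel enables: a one-block launch writes the seed/offset into a device `uint64_t[2]`, and the host then forwards the two pointers instead of the values. `PhiloxStateStub` and `write_rng_state` are illustrative stand-ins, not code from this repository or from PyTorch.

#include <hip/hip_runtime.h>
#include <cstdint>
#include <cstdio>
#include <utility>

// Simplified stand-in for at::PhiloxCudaState: seed/offset already resolved.
struct PhiloxStateStub {
    uint64_t seed;
    uint64_t offset;
};

// Same idea as flash::ParsePhiloxCudaState: every thread writes the same two
// words, so a single small block is plenty.
__global__ void write_rng_state(PhiloxStateStub arg, uint64_t* rng_state) {
    rng_state[0] = arg.seed;
    rng_state[1] = arg.offset;
}

int main() {
    uint64_t* rng_state = nullptr;  // device-side uint64_t[2], like the rng_state tensor
    hipMalloc(reinterpret_cast<void**>(&rng_state), 2 * sizeof(uint64_t));

    PhiloxStateStub state{/*seed=*/1234, /*offset=*/0};
    hipLaunchKernelGGL(write_rng_state, dim3(1), dim3(64), 0, 0, state, rng_state);

    // Downstream code receives the pointers, not the values.
    std::pair<uint64_t*, uint64_t*> drop_seed_offset{rng_state, rng_state + 1};

    // For illustration only: read the two words back through the pointer pair.
    uint64_t seed = 0, offset = 0;
    hipMemcpy(&seed, drop_seed_offset.first, sizeof(uint64_t), hipMemcpyDeviceToHost);
    hipMemcpy(&offset, drop_seed_offset.second, sizeof(uint64_t), hipMemcpyDeviceToHost);
    std::printf("seed=%llu offset=%llu\n",
                (unsigned long long)seed, (unsigned long long)offset);

    hipFree(rng_state);
    return 0;
}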
25 changes: 13 additions & 12 deletions csrc/flash_attn_ck/mha_bwd.cpp
@@ -49,8 +49,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
                                    at::Tensor dv,
                                    float softmax_scale,
                                    float p_dropout,
-                                   uint64_t drop_seed,
-                                   uint64_t drop_offset)
+                                   std::pair<uint64_t*, uint64_t*> drop_seed_offset)
 {
     // q: (batch_size, seqlen_q, nheads, hdim)
     ck_tile::index_t batch_stride_q = q.stride(0);
@@ -191,7 +190,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
                      static_cast<ck_tile::index_t>(mask.type),
                      p_dropout,
                      p_undrop,
-                     {drop_seed, drop_offset}};
+                     drop_seed_offset};
 }
 
 std::vector<at::Tensor>
@@ -213,7 +212,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
         const float /*softcap*/,
         const bool deterministic,
         c10::optional<at::Generator> gen_,
-        c10::optional<at::Tensor> &rng_state)
+        c10::optional<at::Tensor> &rng_state_)
 {
 #ifdef FLASHATTENTION_DISABLE_BACKWARD
     TORCH_CHECK(false, "This flash attention build does not support backward.");
@@ -337,21 +336,24 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
         gen_, at::cuda::detail::getDefaultCUDAGenerator());
 
-    uint64_t drop_seed = 1, drop_offset = 0;
     int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
+    at::Tensor rng_state;
 
-    if (rng_state.has_value()) {
-        uint64_t* d = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
-        drop_seed = d[0];
-        drop_offset = d[1];
+    if (rng_state_.has_value()) {
+        rng_state = rng_state_.value();
     } else if(is_dropout) {
+        rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
         auto philox_args = gen->philox_cuda_state(counter_offset);
-        std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
+        hipLaunchKernelGGL(
+            flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0,
+            philox_args, reinterpret_cast<uint64_t*>(rng_state.data_ptr()));
     }
 
     if (seqlen_q > 0) {
+        auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+        auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
         ck_tile::stream_config stream_config{stream};
 
         auto traits =
@@ -380,8 +382,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
             dv_expanded,
             softmax_scale,
             p_dropout,
-            drop_seed,
-            drop_offset);
+            drop_seed_offset);
 
         float t = fmha_bwd(traits, args, stream_config);
         TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd");
20 changes: 8 additions & 12 deletions csrc/flash_attn_ck/mha_fwd.cpp
@@ -46,8 +46,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
                                    at::Tensor dropout_randval,
                                    float softmax_scale,
                                    float p_dropout,
-                                   uint64_t drop_seed,
-                                   uint64_t drop_offset)
+                                   std::pair<uint64_t*, uint64_t*> drop_seed_offset)
 {
     // q: (batch_size, seqlen_q, nheads, d)
     // k: (batch_size, seqlen_k, nheads_k, d)
@@ -137,7 +136,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
                      static_cast<ck_tile::index_t>(mask.type),
                      p_dropout,
                      has_dropout_randval,
-                     {drop_seed, drop_offset}};
+                     drop_seed_offset};
 }
 
 std::vector<at::Tensor>
@@ -255,24 +254,22 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num
         p = torch::empty({ 0 }, opts);
     }
 
-    uint64_t drop_seed = 1, drop_offset = 0;
     int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
-    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
-    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+    auto rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
+    auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
 
     if (p_dropout > 0.0) {
         auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
             gen_, at::cuda::detail::getDefaultCUDAGenerator());
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
         auto philox_args = gen->philox_cuda_state(counter_offset);
-        std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
+        hipLaunchKernelGGL(
+            flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0, philox_args, rng_state_ptr);
     }
 
-    rng_state[0] = *(reinterpret_cast<int64_t*>(&drop_seed));
-    rng_state[1] = *(reinterpret_cast<int64_t*>(&drop_offset));
-
     if (seqlen_k > 0) {
+        auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
         auto stream = at::cuda::getCurrentHIPStream().stream();
         ck_tile::stream_config stream_config{stream};
 
@@ -305,8 +302,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num
             p,
             softmax_scale,
             p_dropout,
-            drop_seed,
-            drop_offset);
+            drop_seed_offset);
 
         float t = fmha_fwd(traits, args, stream_config);
         TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");
27 changes: 15 additions & 12 deletions csrc/flash_attn_ck/mha_varlen_bwd.cpp
@@ -51,8 +51,7 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
                                           at::Tensor dv,
                                           float softmax_scale,
                                           float p_dropout,
-                                          uint64_t drop_seed,
-                                          uint64_t drop_offset)
+                                          std::pair<uint64_t*, uint64_t*> drop_seed_offset)
 {
     ck_tile::index_t total_q = q.size(0);
     ck_tile::index_t total_k = k.size(0);
@@ -197,7 +196,7 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
                      static_cast<ck_tile::index_t>(mask.type),
                      p_dropout,
                      p_undrop,
-                     {drop_seed, drop_offset}};
+                     drop_seed_offset};
 }
 
 std::vector<at::Tensor>
@@ -224,7 +223,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
                const float /*softcap*/,
                const bool deterministic,
                c10::optional<at::Generator> gen_,
-               c10::optional<at::Tensor> &rng_state)
+               c10::optional<at::Tensor> &rng_state_)
 {
 #ifdef FLASHATTENTION_DISABLE_BACKWARD
     TORCH_CHECK(false, "This flash attention build does not support backward.");
@@ -362,21 +361,26 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
         gen_, at::cuda::detail::getDefaultCUDAGenerator());
 
-    uint64_t drop_seed = 1, drop_offset = 0;
     int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
+    at::Tensor rng_state;
 
-    if (rng_state.has_value()) {
-        uint64_t* d = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
-        drop_seed = d[0];
-        drop_offset = d[1];
+    if (rng_state_.has_value()) {
+        rng_state = rng_state_.value();
     } else if(is_dropout) {
+        rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
        // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
         auto philox_args = gen->philox_cuda_state(counter_offset);
-        std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
+        hipLaunchKernelGGL(
+            flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0,
+            philox_args, reinterpret_cast<uint64_t*>(rng_state.data_ptr()));
+    } else {
+        rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
     }
 
     if (max_seqlen_q > 0) {
+        auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+        auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
         ck_tile::stream_config stream_config{stream};
 
         auto traits =
@@ -407,8 +411,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
             dv_expanded,
             softmax_scale,
             p_dropout,
-            drop_seed,
-            drop_offset);
+            drop_seed_offset);
 
         float t = fmha_bwd(traits, args, stream_config);
         TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd");
20 changes: 8 additions & 12 deletions csrc/flash_attn_ck/mha_varlen_fwd.cpp
@@ -47,8 +47,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
                                           at::Tensor dropout_randval,
                                           float softmax_scale,
                                           float p_dropout,
-                                          uint64_t drop_seed,
-                                          uint64_t drop_offset)
+                                          std::pair<uint64_t*, uint64_t*> drop_seed_offset)
 {
     // q: (total_q, nheads, d)
     // k: (total_k, nheads_k, d)
@@ -140,7 +139,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
                      static_cast<ck_tile::index_t>(mask.type),
                      p_dropout,
                      has_dropout_randval,
-                     {drop_seed, drop_offset}};
+                     drop_seed_offset};
 }
 
 std::vector<at::Tensor>
@@ -281,24 +280,22 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
         if (return_dropout_randval) {p.zero_();}
     }
 
-    uint64_t drop_seed = 1, drop_offset = 0;
     int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
-    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
-    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+    auto rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
+    auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
 
     if (p_dropout > 0.0) {
         auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
             gen_, at::cuda::detail::getDefaultCUDAGenerator());
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
         auto philox_args = gen->philox_cuda_state(counter_offset);
-        std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
+        hipLaunchKernelGGL(
+            flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0, philox_args, rng_state_ptr);
     }
 
-    rng_state[0] = *(reinterpret_cast<int64_t*>(&drop_seed));
-    rng_state[1] = *(reinterpret_cast<int64_t*>(&drop_offset));
-
     if (max_seqlen_k > 0) {
+        auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
         auto stream = at::cuda::getCurrentHIPStream().stream();
         ck_tile::stream_config stream_config{stream};
 
@@ -332,8 +329,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
             p,
             softmax_scale,
             p_dropout,
-            drop_seed,
-            drop_offset);
+            drop_seed_offset);
 
         float t = fmha_fwd(traits, args, stream_config);
         TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");
