Commit

clean up
micmelesse committed Jun 14, 2024
1 parent 13ba75a commit b6ea085
Showing 1 changed file with 41 additions and 33 deletions.
74 changes: 41 additions & 33 deletions flash_attn/flash_attn_triton_interface_amd.py
@@ -2,10 +2,8 @@
import torch
import triton

# /////////////////////////////////////////// Interface //////////////////////////////////////////////////////////
DEBUG=False


def fwd(q,
k,
v,
@@ -35,32 +33,34 @@ def fwd(q,
if dropout_p != 0.0:
raise ValueError("dropout is not supported on HIP")


if o is None:
o = torch.empty_like(q)

# Setup metadata
input_metadata = MetaData(sm_scale=softmax_scale)
input_metadata.max_seqlens_q = q.shape[1]
input_metadata.max_seqlens_k = k.shape[1]
input_metadata.layout = "bshd"

# Create metadata object
metadata = MetaData(sm_scale=softmax_scale)
metadata.max_seqlens_q = q.shape[1]
metadata.max_seqlens_k = k.shape[1]
metadata.layout = "bshd"
batch, nheads_q, nheads_k, head_size = get_shape_from_layout(q, k, input_metadata)

# Setup metadata
if causal:
metadata.need_causal()
input_metadata.need_causal()

# if bias is not None:
# metadata.need_bias(bias, q.shape[0], q.shape[1], q.shape[2], k.shape[2])
# input_metadata.need_bias(bias, batch, nheads_q, input_metadata.max_seqlens_q, input_metadata.max_seqlens_k)

if alibi_slopes is not None:
metadata.need_alibi(alibi_slopes, q.shape[0], q.shape[2])
input_metadata.need_alibi(alibi_slopes, batch, nheads_q)

if dropout_p > 0.0:
metadata.need_dropout(dropout_p, return_softmax)
input_metadata.need_dropout(dropout_p, return_softmax)

# Check arguments
metadata.check_args(q, k, v, o)
input_metadata.check_args(q, k, v, o)

# Perform the forward attention computation
tri_out, encoded_softmax = attention(q, k, v, o, metadata)
tri_out, encoded_softmax = attention(q, k, v, o, input_metadata)

softmax_lse = encoded_softmax
softmax_p = encoded_softmax
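
For context on the hunk above: the forward interface wraps the Triton kernel behind a MetaData object that records the softmax scale, the per-batch maximum sequence lengths, the "bshd" layout, and the optional causal/ALiBi/dropout flags before dispatching to attention(). A minimal usage sketch of that pattern follows; the method names are taken from the diff, but the import path, tensor shapes, and dtype are assumptions for illustration, not part of this commit.

import torch
from flash_attn.flash_attn_triton_interface_amd import MetaData, attention  # assumed import path

# "bshd" layout: (batch, seqlen, heads, head_dim); sizes are made up for the example
q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
o = torch.empty_like(q)

meta = MetaData(sm_scale=q.shape[-1] ** -0.5)
meta.max_seqlens_q = q.shape[1]
meta.max_seqlens_k = k.shape[1]
meta.layout = "bshd"
meta.need_causal()                 # optional flags mirror the branches in the hunk
meta.check_args(q, k, v, o)
out, encoded_softmax = attention(q, k, v, o, meta)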
@@ -96,28 +96,26 @@ def varlen_fwd(

if dropout_p != 0.0:
raise ValueError("dropout is not supported on HIP")



if o is None:
o = torch.empty_like(q)



# create metadata object
# Setup metadata
input_metadata = MetaData(sm_scale=softmax_scale)
input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k)

# get shapes
batch, nheads_q, nheads_k, head_size = get_shape_from_layout(q, k, input_metadata)

# Setup metadata
if causal:
input_metadata.need_causal()

# if bias is not None:
# metadata.need_bias(bias, q.shape[0], q.shape[1], q.shape[2], k.shape[2])
# input_metadata.need_bias(bias, batch, nheads_q, q.shape[2], k.shape[2])

if alibi_slopes is not None:
input_metadata.need_alibi(alibi_slopes, batch, nheads_q)

if dropout_p > 0.0:
input_metadata.need_dropout(dropout_p, return_softmax)

@@ -132,7 +130,25 @@

return tri_out, q , k , v, o, softmax_lse, softmax_p, torch.get_rng_state()
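
The varlen path above differs from fwd mainly in how sequence boundaries are described: sequences are packed rather than padded, and their extents are carried by the cumulative-length tensors handed to set_varlen_params. A small sketch of how such offset tensors are typically built; the sequence lengths are made-up values for illustration.

import torch

# Three packed query sequences of lengths 5, 3 and 8 (made-up values)
seqlens_q = torch.tensor([5, 3, 8], dtype=torch.int32, device="cuda")
cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))
# cu_seqlens_q == [0, 5, 8, 16]; sequence i occupies rows cu_seqlens_q[i]:cu_seqlens_q[i+1]
# cu_seqlens_k is built the same way from the key/value lengths, then both are passed
# to input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) as in the hunk above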

def fwd_kvcache(*args, **kwargs):
def fwd_kvcache(
q,
k_cache,
v_cache,
k,
v,
cache_seqlens,
rotary_cos,
rotary_sin,
cache_batch_idx,
block_table,
alibi_slopes,
out,
softmax_scale,
causal,
window_size_left,
window_size_right,
rotary_interleaved,
num_splits):
pass
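
fwd_kvcache gains an explicit argument list here but its body is still a stub (pass). As a point of reference only, the k/v/cache_seqlens arguments imply the usual decode-time step of appending new keys and values into the cache at each batch element's current length; a generic sketch of that step follows, and it is not this repo's implementation.

import torch

def append_to_kv_cache(k_cache, v_cache, k_new, v_new, cache_seqlens):
    # k_cache, v_cache: (batch, max_cache_len, heads, head_dim)
    # k_new, v_new:     (batch, new_len, heads, head_dim); cache_seqlens: (batch,) current lengths
    new_len = k_new.shape[1]
    for b in range(k_new.shape[0]):
        start = int(cache_seqlens[b])
        k_cache[b, start:start + new_len] = k_new[b]
        v_cache[b, start:start + new_len] = v_new[b]
    return cache_seqlens + new_len  # updated lengths after the append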


@@ -157,12 +173,9 @@ def bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, alibi_slopes, dropout_p, so
print("gen_:", gen_)
print("rng_state:", rng_state)



if out is None:
out = torch.empty_like(q)


# Ensure the tensors have requires_grad=True
q.requires_grad_()
k.requires_grad_()
@@ -187,25 +200,20 @@ def bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, alibi_slopes, dropout_p, so
# Setup metadata
if causal:
metadata.need_causal()

# if bias is not None:
# metadata.need_bias(bias, q.shape[0], q.shape[1], q.shape[2], k.shape[2])

return_softmax = True
if alibi_slopes is not None:
metadata.need_alibi(alibi_slopes, batch, nheads_q)

if dropout_p > 0.0:
metadata.need_dropout(dropout_p, return_softmax)

# Check arguments
metadata.check_args(q, k, v, out)



# tri_out, _ = attention(q, k, v, out, metadata)
# tri_out.requires_grad_()
# dout.requires_grad_()
# tri_out.backward(dout)

# write your own version backward
M = torch.empty((batch, nheads_q, metadata.max_seqlens_q), device=q.device, dtype=torch.float32) # this passed from
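
The M buffer above is sized (batch, nheads_q, max_seqlens_q), matching the per-query-row softmax statistics (a log-sum-exp per row) that flash-attention backward kernels typically keep so they can recompute attention probabilities instead of materializing the full probability matrix. A hedged plain-PyTorch illustration of that recomputation idea, with assumed (batch, heads, seqlen, head_dim) shapes; this is not the Triton kernel itself.

import torch

def recompute_probs(q, k, M, sm_scale):
    # scores: (batch, heads, seqlen_q, seqlen_k); M: (batch, heads, seqlen_q) log-sum-exp per query row
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) * sm_scale
    return torch.exp(scores - M.unsqueeze(-1))  # reproduces softmax(scores) row by row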

