Skip to content

Commit

Permalink
Enable fwd and varlen_fwd on AMD (#63)
Browse files Browse the repository at this point in the history
* flash_attn_func works

Compress

This is a combination of 12 commits.

add scripts

save

add our kernel

import our kernel

round trip

use bshd layout

figure out segfault

fix

show backward failure with prints

save backward work

run forward only

test smallest config on everything

add test

fix

remove pre commit

install triton

skip dropout

pin d

32 factor d

just run power of 2

remove timeout

run serially

clean up

clean up 2

* Varlen works

This is a combination of 6 commits.

save

some tests passing

enable more

enable everything

move around

alibi works

* keep interface and kernel seperate

* clean up
  • Loading branch information
micmelesse authored Jun 19, 2024
1 parent 320fb59 commit 508a92a
Show file tree
Hide file tree
Showing 7 changed files with 2,000 additions and 10 deletions.
63 changes: 63 additions & 0 deletions .github/workflows/amd_tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
name: AMD Perf Kernel Tests

on:
workflow_dispatch:
pull_request:
branches: [main_perf]
merge_group:
branches: [main_perf]
types: [checks_requested]
push:
branches: [main_perf]

concurrency:
group: ${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main_perf' }}

permissions: read-all


jobs:
Runner-Preparation-AMD:
runs-on: ubuntu-latest
timeout-minutes: 30
outputs:
matrix-HIP: ${{ steps.set-matrix.outputs.matrix-HIP }}
steps:
- name: Prepare runner matrix
id: set-matrix
run: |
if [ x"${{ github.repository }}" == x"ROCm/flash-attention" ]; then
echo '::set-output name=matrix-HIP::[["self-hosted", "rocm"]]'
else
echo '::set-output name=matrix-HIP::[["ubuntu-latest"]]'
fi
Integration-Tests-AMD:
needs: Runner-Preparation-AMD
if: needs.Runner-Preparation-AMD.outputs.matrix-HIP != ''
runs-on: ${{ matrix.runner }}
strategy:
matrix:
runner: ${{fromJson(needs.Runner-Preparation-AMD.outputs.matrix-HIP)}}
container:
image: rocm/pytorch:rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2
options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Triton
run: |
pip uninstall -y triton
pip install matplotlib pandas pytest
git clone https://github.com/triton-lang/triton
cd triton
pip install --verbose -e python
cd ..
- name: Build
run: |
python setup.py install
- name: Test
run: |
pytest tests/test_flash_attn.py::test_flash_attn_output
pytest tests/test_flash_attn.py::test_flash_attn_varlen_output
9 changes: 8 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,11 @@ var/
.idea/

# Dev
venv
venv

# Other
.eggs
.vscode
core
scripts
log*
20 changes: 14 additions & 6 deletions flash_attn/flash_attn_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,17 @@
import torch
import torch.nn as nn

def is_hip():
if torch.version.hip is not None:
return True
return False

# isort: off
# We need to import the CUDA kernels after importing torch
import flash_attn_2_cuda as flash_attn_cuda
if is_hip():
from . import flash_attn_triton_interface_amd as flash_attn_gpu
else:
import flash_attn_2_cuda as flash_attn_gpu

# isort: on

Expand Down Expand Up @@ -48,7 +56,7 @@ def _flash_attn_forward(
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd(
q,
k,
v,
Expand Down Expand Up @@ -83,7 +91,7 @@ def _flash_attn_varlen_forward(
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(
q,
k,
v,
Expand Down Expand Up @@ -130,7 +138,7 @@ def _flash_attn_backward(
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
dq, dk, dv, softmax_d, = flash_attn_gpu.bwd(
dout,
q,
k,
Expand Down Expand Up @@ -178,7 +186,7 @@ def _flash_attn_varlen_backward(
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
dq, dk, dv, softmax_d, = flash_attn_gpu.varlen_bwd(
dout,
q,
k,
Expand Down Expand Up @@ -1194,7 +1202,7 @@ def flash_attn_with_kvcache(
cache_seqlens = maybe_contiguous(cache_seqlens)
cache_batch_idx = maybe_contiguous(cache_batch_idx)
block_table = maybe_contiguous(block_table)
out, softmax_lse = flash_attn_cuda.fwd_kvcache(
out, softmax_lse = flash_attn_gpu.fwd_kvcache(
q,
k_cache,
v_cache,
Expand Down
Loading

0 comments on commit 508a92a

Please sign in to comment.