Merge pull request #60 from ROCm/micmelesse/enable_fwd
Enable fwd and varlen_fwd on AMD
micmelesse authored Jun 19, 2024
2 parents 320fb59 + b6ea085 commit 8fdb1bc
Showing 7 changed files with 2,000 additions and 10 deletions.
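With this change, the public flash-attn entry points route to a Triton-based implementation on ROCm builds of PyTorch instead of the CUDA extension. A minimal sketch of the forward call this enables (shapes are arbitrary; the real coverage lives in tests/test_flash_attn.py, and only the forward paths are enabled here):

import torch
from flash_attn import flash_attn_func

# fp16 tensors of shape (batch, seqlen, nheads, headdim); on a ROCm build,
# device "cuda" is the HIP device and dispatch goes to the Triton backend.
q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)

out = flash_attn_func(q, k, v, causal=True)  # forward only in this PR
print(out.shape)  # torch.Size([2, 128, 8, 64])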
63 changes: 63 additions & 0 deletions .github/workflows/amd_tests.yml
@@ -0,0 +1,63 @@
name: AMD Perf Kernel Tests

on:
  workflow_dispatch:
  pull_request:
    branches: [main_perf]
  merge_group:
    branches: [main_perf]
    types: [checks_requested]
  push:
    branches: [main_perf]

concurrency:
  group: ${{ github.ref }}
  cancel-in-progress: ${{ github.ref != 'refs/heads/main_perf' }}

permissions: read-all


jobs:
  Runner-Preparation-AMD:
    runs-on: ubuntu-latest
    timeout-minutes: 30
    outputs:
      matrix-HIP: ${{ steps.set-matrix.outputs.matrix-HIP }}
    steps:
      - name: Prepare runner matrix
        id: set-matrix
        run: |
          if [ x"${{ github.repository }}" == x"ROCm/flash-attention" ]; then
            echo '::set-output name=matrix-HIP::[["self-hosted", "rocm"]]'
          else
            echo '::set-output name=matrix-HIP::[["ubuntu-latest"]]'
          fi
  Integration-Tests-AMD:
    needs: Runner-Preparation-AMD
    if: needs.Runner-Preparation-AMD.outputs.matrix-HIP != ''
    runs-on: ${{ matrix.runner }}
    strategy:
      matrix:
        runner: ${{fromJson(needs.Runner-Preparation-AMD.outputs.matrix-HIP)}}
    container:
      image: rocm/pytorch:rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2
      options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root
    steps:
      - name: Checkout
        uses: actions/checkout@v4
      - name: Install Triton
        run: |
          pip uninstall -y triton
          pip install matplotlib pandas pytest
          git clone https://github.com/triton-lang/triton
          cd triton
          pip install --verbose -e python
          cd ..
      - name: Build
        run: |
          python setup.py install
      - name: Test
        run: |
          pytest tests/test_flash_attn.py::test_flash_attn_output
          pytest tests/test_flash_attn.py::test_flash_attn_varlen_output
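The two pytest targets above exercise the dense and variable-length forward paths. Roughly, the varlen entry point packs all sequences into one tensor and takes cumulative sequence lengths, as in this sketch (shapes and sequence lengths chosen arbitrarily):

import torch
from flash_attn import flash_attn_varlen_func

nheads, headdim = 8, 64
seqlens = [100, 28]                  # two sequences packed into one tensor
total = sum(seqlens)
# cumulative sequence lengths, int32, shape (batch + 1,)
cu_seqlens = torch.tensor([0, 100, 128], device="cuda", dtype=torch.int32)

q = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)
v = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens), max_seqlen_k=max(seqlens),
    causal=True,
)
print(out.shape)  # torch.Size([128, 8, 64])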
9 changes: 8 additions & 1 deletion .gitignore
@@ -24,4 +24,11 @@ var/
.idea/

# Dev
venv
venv

# Other
.eggs
.vscode
core
scripts
log*
20 changes: 14 additions & 6 deletions flash_attn/flash_attn_interface.py
@@ -5,9 +5,17 @@
import torch
import torch.nn as nn

def is_hip():
    if torch.version.hip is not None:
        return True
    return False

# isort: off
# We need to import the CUDA kernels after importing torch
import flash_attn_2_cuda as flash_attn_cuda
if is_hip():
    from . import flash_attn_triton_interface_amd as flash_attn_gpu
else:
    import flash_attn_2_cuda as flash_attn_gpu

# isort: on

@@ -48,7 +56,7 @@ def _flash_attn_forward(
):
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd(
        q,
        k,
        v,
@@ -83,7 +91,7 @@ def _flash_attn_varlen_forward(
):
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(
        q,
        k,
        v,
@@ -130,7 +138,7 @@ def _flash_attn_backward(
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
    dq, dk, dv, softmax_d, = flash_attn_gpu.bwd(
        dout,
        q,
        k,
@@ -178,7 +186,7 @@ def _flash_attn_varlen_backward(
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
    dq, dk, dv, softmax_d, = flash_attn_gpu.varlen_bwd(
        dout,
        q,
        k,
@@ -1194,7 +1202,7 @@ def flash_attn_with_kvcache(
    cache_seqlens = maybe_contiguous(cache_seqlens)
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    block_table = maybe_contiguous(block_table)
    out, softmax_lse = flash_attn_cuda.fwd_kvcache(
    out, softmax_lse = flash_attn_gpu.fwd_kvcache(
        q,
        k_cache,
        v_cache,
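The net effect of the interface change is that every call site now goes through a single flash_attn_gpu module chosen once at import time. A quick way to check which backend was selected on a given machine (a sketch relying only on the names introduced in the diff above):

import torch
from flash_attn import flash_attn_interface

# torch.version.hip is set on ROCm builds and None on CUDA builds,
# which is exactly the check is_hip() performs.
print("ROCm/HIP build:", torch.version.hip is not None)
print("attention backend:", flash_attn_interface.flash_attn_gpu.__name__)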