Integrated Rotary Positional Embeddings (RoPEs) into flash_attn_kvcache #83

Open · wants to merge 26 commits into base: main_perf

Changes from all commits (26 commits)
8f8f42c
Enable fwd and varlen_fwd on AMD (#63)
micmelesse Jun 19, 2024
18ed767
enable flash_attn_with_kvcache (#68)
micmelesse Aug 6, 2024
b2f6523
Fixes after rebase
micmelesse Aug 9, 2024
7b8a15c
enable packed layouts and all configs (#72)
micmelesse Aug 28, 2024
75b5360
Clean up for Upstream (#81)
micmelesse Sep 4, 2024
8945558
feat: pytest case for pytorch implmentation of RoPE
alexkranias-amd Sep 12, 2024
ffecd6d
feat: pytest progress
alexkranias-amd Sep 12, 2024
12c6a68
feat: added MetaData and incorporated Tri Dao RoPE in class _attentio…
alexkranias-amd Sep 17, 2024
15d0683
fix: rotate input_metadata.k_new instead of k (k_cache)
alexkranias-amd Sep 17, 2024
7b83552
feat: added files to gitignore
alexkranias-amd Sep 17, 2024
97f751b
fix: changed batch and head size back to main_perf sizes
alexkranias-amd Sep 18, 2024
2a137f8
test: found a failing test in flash_attn_kvcache
alexkranias-amd Sep 18, 2024
9709453
tests: added debug prints to flash_attn_kvcache tests
alexkranias-amd Sep 23, 2024
75daa99
tests: added debug prints to kvcache tests
alexkranias-amd Sep 23, 2024
eba41a0
tests: found a failing test (no DEBUG)
alexkranias-amd Sep 23, 2024
95e4aa7
test: isolated a failing case
alexkranias-amd Sep 23, 2024
1c05ba7
test: got isolated failing case to pass by reordering when scaling qk…
alexkranias-amd Sep 23, 2024
5923025
test: found deviation in scores
alexkranias-amd Sep 23, 2024
aadf908
test: added prints to see that sum is equivalent
alexkranias-amd Sep 23, 2024
527ecb1
test: reduced to failing test
alexkranias-amd Sep 24, 2024
1b0a841
test: added tests for tl.dot & tl.exp2 with casting
alexkranias-amd Sep 25, 2024
a99aa64
test: added precision error test
alexkranias-amd Sep 26, 2024
1e6796a
chore: removed unnecessary prints and flags
alexkranias-amd Sep 27, 2024
745c864
fix: restored csrc dir
alexkranias-amd Sep 27, 2024
704f976
fix: removed csrc from gitignore and add back to gitmodules
alexkranias-amd Sep 27, 2024
cf49217
fix: fixed fp16 issues
alexkranias-amd Oct 7, 2024
68 changes: 68 additions & 0 deletions .github/workflows/amd_tests.yml
@@ -0,0 +1,68 @@
name: AMD Perf Kernel Tests

on:
workflow_dispatch:
pull_request:
branches: [main_perf]
merge_group:
branches: [main_perf]
types: [checks_requested]
push:
branches: [main_perf, micmelesse/upstream_pr]

concurrency:
group: ${{ github.ref }}
cancel-in-progress: true

permissions: read-all

jobs:
Runner-Preparation-AMD:
runs-on: ubuntu-latest
timeout-minutes: 30
outputs:
matrix-HIP: ${{ steps.set-matrix.outputs.matrix-HIP }}
steps:
- name: Prepare runner matrix
id: set-matrix
run: |
if [ x"${{ github.repository }}" == x"ROCm/flash-attention" ]; then
echo '::set-output name=matrix-HIP::[["self-hosted", "rocm"]]'
else
echo '::set-output name=matrix-HIP::[["ubuntu-latest"]]'
fi

Integration-Tests-AMD:
needs: Runner-Preparation-AMD
if: needs.Runner-Preparation-AMD.outputs.matrix-HIP != ''
runs-on: ${{ matrix.runner }}
strategy:
matrix:
runner: ${{fromJson(needs.Runner-Preparation-AMD.outputs.matrix-HIP)}}
container:
image: rocm/pytorch:rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2
options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Triton
run: |
pip uninstall -y triton
pip install matplotlib pandas pytest
git clone https://github.com/triton-lang/triton
cd triton
git checkout 2e9f2c2d20601c24b91a4c32a7b97ad1f8a55d88
pip install --verbose -e python
cd ..
- name: Build
run: |
export FLASH_ATTENTION_USE_TRITON_ROCM="TRUE"
python setup.py install
- name: Flash Attention Tests
run: |
export FLASH_ATTENTION_USE_TRITON_ROCM="TRUE"
pytest tests/test_flash_attn.py
- name: AMD Kernel Tests
run: |
pytest -v -s flash_attn/flash_attn_triton_kernel_decode_amd.py::test_op_fwd
pytest -v -s flash_attn/flash_attn_triton_kernel_prefill_amd.py
6 changes: 5 additions & 1 deletion .gitignore
100644 → 100755
@@ -19,9 +19,13 @@ var/
*.egg-info/
.installed.cfg
*.egg
.eggs

# IDE-related
.idea/

# Dev
venv
venv
.venv
scripts
log
1 change: 1 addition & 0 deletions .gitmodules
@@ -4,3 +4,4 @@
[submodule "csrc/composable_kernel"]
path = csrc/composable_kernel
url = https://github.com/ROCm/composable_kernel.git

27 changes: 25 additions & 2 deletions README.md
@@ -112,7 +112,7 @@ FlashAttention-2 with CUDA currently supports:
3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.

### AMD ROCm Support
ROCm version uses [composable_kernel](https://github.com/ROCm/composable_kernel) as the backend. It provides the implementation of FlashAttention-2.
The ROCm version has two backends: [composable_kernel](https://github.com/ROCm/composable_kernel) (CK), which is the default, and a [Triton](https://github.com/triton-lang/triton) backend. Both provide an implementation of FlashAttention-2.

**Requirements:**
- ROCm 6.0 and above.
@@ -121,10 +121,33 @@ We recommend the
[Pytorch](https://hub.docker.com/r/rocm/pytorch)
container from ROCm, which has all the required tools to install FlashAttention.

FlashAttention-2 with ROCm currently supports:
#### Composable Kernel Backend
FlashAttention-2 ROCm CK backend currently supports:
1. MI200 or MI300 GPUs.
2. Datatype fp16 and bf16
3. Forward's head dimensions up to 256. Backward head dimensions up to 128.
#### Triton Backend
The FlashAttention-2 ROCm Triton backend is a work in progress.
It currently supports the forward pass only; some features, such as PagedAttention and sliding-window attention, are still missing. It runs on both MI and Navi machines. Backward-pass support is under development.

In order to use the Triton backend for ROCm, follow the steps below.

First, install Triton at the recommended [commit](https://github.com/triton-lang/triton/commit/2e9f2c2d20601c24b91a4c32a7b97ad1f8a55d88).

```
git clone https://github.com/triton-lang/triton
cd triton
git checkout 2e9f2c2d20601c24b91a4c32a7b97ad1f8a55d88
pip install --verbose -e python
```
Then install and test Flash Attention with the flag `FLASH_ATTENTION_USE_TRITON_ROCM` set to `"TRUE"`.

```
export FLASH_ATTENTION_USE_TRITON_ROCM="TRUE"
cd flash-attention
python setup.py install
pytest tests/test_flash_attn.py
```
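
The backend choice is made when `flash_attn` is imported (see `flash_attn/flash_attn_interface.py` in this PR), so the environment variable can also be set from Python before the import. The snippet below is a minimal sketch, assuming a ROCm GPU and illustrative tensor shapes:

```python
import os

# The flag is read at import time, so set it before importing flash_attn.
os.environ["FLASH_ATTENTION_USE_TRITON_ROCM"] = "TRUE"

import torch
from flash_attn import flash_attn_func

# Illustrative shapes: (batch, seqlen, nheads, headdim) in fp16 on the GPU.
q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_func(q, k, v, causal=True)  # forward pass only on the Triton backend
print(out.shape)  # torch.Size([2, 128, 8, 64])
```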


## How to use FlashAttention
1 change: 0 additions & 1 deletion csrc/composable_kernel
Submodule composable_kernel deleted from 818297
1 change: 0 additions & 1 deletion csrc/cutlass
Submodule cutlass deleted from 756c35
33 changes: 25 additions & 8 deletions flash_attn/flash_attn_interface.py
100644 → 100755
@@ -4,10 +4,15 @@

import torch
import torch.nn as nn
import os

# isort: off
# We need to import the CUDA kernels after importing torch
import flash_attn_2_cuda as flash_attn_cuda
USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_USE_TRITON_ROCM", "FALSE") == "TRUE"
if USE_TRITON_ROCM:
from flash_attn import flash_attn_triton_interface_amd as flash_attn_gpu
else:
import flash_attn_2_cuda as flash_attn_gpu

# isort: on

@@ -49,7 +54,7 @@ def _flash_attn_forward(
q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax
):
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd(
q,
k,
v,
@@ -87,7 +92,7 @@ def _flash_attn_varlen_forward(
seqused_k=None,
):
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(
q,
k,
v,
@@ -141,7 +146,7 @@ def _flash_attn_backward(
dk,
dv,
softmax_d,
) = flash_attn_cuda.bwd(
) = flash_attn_gpu.bwd(
dout,
q,
k,
@@ -195,7 +200,7 @@ def _flash_attn_varlen_backward(
dk,
dv,
softmax_d,
) = flash_attn_cuda.varlen_bwd(
) = flash_attn_gpu.varlen_bwd(
dout,
q,
k,
@@ -1149,15 +1154,20 @@ def flash_attn_with_kvcache(
v=None,
rotary_cos=None,
rotary_sin=None,
rotary_cos_k=None,
rotary_sin_k=None,
rotary_interleaved=True,
rotary_inplace=False,
rotary_conjugate=False,
cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
cache_batch_idx: Optional[torch.Tensor] = None,
cache_leftpad: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
softmax_scale=None,
causal=False,
local=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # 0.0 means deactivated
rotary_interleaved=True,
alibi_slopes=None,
num_splits=0,
return_softmax_lse=False,
@@ -1249,6 +1259,7 @@ def flash_attn_with_kvcache(
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
# assert ALIBI is not ROTARY ?
assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
@@ -1261,7 +1272,7 @@
cache_seqlens = maybe_contiguous(cache_seqlens)
cache_batch_idx = maybe_contiguous(cache_batch_idx)
block_table = maybe_contiguous(block_table)
out, softmax_lse = flash_attn_cuda.fwd_kvcache(
out, softmax_lse = flash_attn_gpu.fwd_kvcache(
q,
k_cache,
v_cache,
@@ -1270,17 +1281,23 @@
cache_seqlens,
rotary_cos,
rotary_sin,
rotary_cos_k,
rotary_sin_k,
rotary_interleaved,
rotary_inplace,
rotary_conjugate,
cache_seqlens,
cache_batch_idx,
cache_leftpad,
block_table,
alibi_slopes,
None,
softmax_scale,
causal,
local,
window_size[0],
window_size[1],
softcap,
rotary_interleaved,
num_splits,
)
return (out, softmax_lse) if return_softmax_lse else out
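
For context, here is a minimal usage sketch of the extended `flash_attn_with_kvcache` signature with RoPE applied to the new tokens. The shapes, the RoPE base of 10000, and the GPT-NeoX-style (non-interleaved) layout are illustrative assumptions, not requirements of this PR; the new `rotary_cos_k`, `rotary_sin_k`, `rotary_inplace`, and `rotary_conjugate` arguments are left at their defaults:

```python
import torch
from flash_attn import flash_attn_with_kvcache

device, dtype = "cuda", torch.float16
batch, nheads, headdim = 2, 8, 64
seqlen_cache, seqlen_new = 64, 1          # 64 cached tokens, 1 new decode token
rotary_dim = headdim

# Query and the new key/value tokens to append to the cache.
q = torch.randn(batch, seqlen_new, nheads, headdim, device=device, dtype=dtype)
k_new = torch.randn(batch, seqlen_new, nheads, headdim, device=device, dtype=dtype)
v_new = torch.randn(batch, seqlen_new, nheads, headdim, device=device, dtype=dtype)

# Preallocated KV cache and the current length of each sequence in the batch.
k_cache = torch.zeros(batch, seqlen_cache + seqlen_new, nheads, headdim, device=device, dtype=dtype)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.full((batch,), seqlen_cache, dtype=torch.int32, device=device)

# Precomputed RoPE tables of shape (max_seqlen, rotary_dim // 2).
inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2, device=device, dtype=torch.float32) / rotary_dim))
t = torch.arange(seqlen_cache + seqlen_new, device=device, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
rotary_cos, rotary_sin = freqs.cos().to(dtype), freqs.sin().to(dtype)

out = flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=k_new,
    v=v_new,
    rotary_cos=rotary_cos,
    rotary_sin=rotary_sin,
    cache_seqlens=cache_seqlens,
    rotary_interleaved=False,   # GPT-NeoX-style rotation
    causal=True,
)
print(out.shape)  # torch.Size([2, 1, 8, 64])
```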