diff --git a/.github/workflows/amd_tests.yml b/.github/workflows/amd_tests.yml new file mode 100644 index 000000000..14cfa1627 --- /dev/null +++ b/.github/workflows/amd_tests.yml @@ -0,0 +1,68 @@ +name: AMD Perf Kernel Tests + +on: + workflow_dispatch: + pull_request: + branches: [main_perf] + merge_group: + branches: [main_perf] + types: [checks_requested] + push: + branches: [main_perf, micmelesse/upstream_pr] + +concurrency: + group: ${{ github.ref }} + cancel-in-progress: true + +permissions: read-all + +jobs: + Runner-Preparation-AMD: + runs-on: ubuntu-latest + timeout-minutes: 30 + outputs: + matrix-HIP: ${{ steps.set-matrix.outputs.matrix-HIP }} + steps: + - name: Prepare runner matrix + id: set-matrix + run: | + if [ x"${{ github.repository }}" == x"ROCm/flash-attention" ]; then + echo '::set-output name=matrix-HIP::[["self-hosted", "rocm"]]' + else + echo '::set-output name=matrix-HIP::[["ubuntu-latest"]]' + fi + + Integration-Tests-AMD: + needs: Runner-Preparation-AMD + if: needs.Runner-Preparation-AMD.outputs.matrix-HIP != '' + runs-on: ${{ matrix.runner }} + strategy: + matrix: + runner: ${{fromJson(needs.Runner-Preparation-AMD.outputs.matrix-HIP)}} + container: + image: rocm/pytorch:rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2 + options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Install Triton + run: | + pip uninstall -y triton + pip install matplotlib pandas pytest + git clone https://github.com/triton-lang/triton + cd triton + git checkout 2e9f2c2d20601c24b91a4c32a7b97ad1f8a55d88 + pip install --verbose -e python + cd .. + - name: Build + run: | + export FLASH_ATTENTION_USE_TRITON_ROCM="TRUE" + python setup.py install + - name: Flash Attention Tests + run: | + export FLASH_ATTENTION_USE_TRITON_ROCM="TRUE" + pytest tests/test_flash_attn.py + - name: AMD Kernel Tests + run: | + pytest -v -s flash_attn/flash_attn_triton_kernel_decode_amd.py::test_op_fwd + pytest -v -s flash_attn/flash_attn_triton_kernel_prefill_amd.py diff --git a/.gitignore b/.gitignore old mode 100644 new mode 100755 index c0a6c7cb1..6c304dc1d --- a/.gitignore +++ b/.gitignore @@ -19,9 +19,13 @@ var/ *.egg-info/ .installed.cfg *.egg +.eggs # IDE-related .idea/ # Dev -venv \ No newline at end of file +venv +.venv +scripts +log \ No newline at end of file diff --git a/.gitmodules b/.gitmodules index 6216182e7..e9fedd13a 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,4 @@ [submodule "csrc/composable_kernel"] path = csrc/composable_kernel url = https://github.com/ROCm/composable_kernel.git + \ No newline at end of file diff --git a/README.md b/README.md index 3e2e066cf..582e186ac 100644 --- a/README.md +++ b/README.md @@ -112,7 +112,7 @@ FlashAttention-2 with CUDA currently supports: 3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5. ### AMD ROCm Support -ROCm version uses [composable_kernel](https://github.com/ROCm/composable_kernel) as the backend. It provides the implementation of FlashAttention-2. +ROCm version has two backends. There is [composable_kernel](https://github.com/ROCm/composable_kernel) (ck) which is the default backend and a [Triton](https://github.com/triton-lang/triton) backend. They provide an implementation of FlashAttention-2. **Requirements:** - ROCm 6.0 and above. 
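A minimal usage sketch of the Triton path described in the next README hunk may help orient readers. The `flash_attn_func` call and tensor layout (batch, seqlen, nheads, headdim) follow the library's existing public API; the environment variable is the switch this PR wires into `flash_attn_interface.py`, the shapes and dtype are illustrative, and `torch.no_grad()` is used because the Triton backend is forward-only for now.

```
# Sketch: opt in to the Triton ROCm backend, then run a forward pass.
import os
os.environ["FLASH_ATTENTION_USE_TRITON_ROCM"] = "TRUE"  # must be set before importing flash_attn

import torch
from flash_attn import flash_attn_func  # backend is chosen at import time

# (batch, seqlen, nheads, headdim), fp16 on a ROCm device (exposed as "cuda" in PyTorch)
q = torch.randn(2, 128, 16, 64, dtype=torch.float16, device="cuda")
k = torch.randn(2, 128, 16, 64, dtype=torch.float16, device="cuda")
v = torch.randn(2, 128, 16, 64, dtype=torch.float16, device="cuda")

with torch.no_grad():  # backward is not implemented in the Triton backend yet
    out = flash_attn_func(q, k, v, causal=True)

print(out.shape)  # torch.Size([2, 128, 16, 64])
```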
@@ -121,10 +121,33 @@ We recommend the [Pytorch](https://hub.docker.com/r/rocm/pytorch) container from ROCm, which has all the required tools to install FlashAttention. -FlashAttention-2 with ROCm currently supports: +#### Composable Kernel Backend +FlashAttention-2 ROCm CK backend currently supports: 1. MI200 or MI300 GPUs. 2. Datatype fp16 and bf16 3. Forward's head dimensions up to 256. Backward head dimensions up to 128. +#### Triton Backend +The FlashAttention-2 ROCm Triton backend is a work in progress. +It currently supports the forward pass only; the backward pass is in progress. Some features, such as PagedAttention and sliding-window attention, are still missing. It can run on both MI and Navi machines. + +In order to use the Triton backend for ROCm, follow the steps below. + +First install the recommended Triton [commit](https://github.com/triton-lang/triton/commit/2e9f2c2d20601c24b91a4c32a7b97ad1f8a55d88). + +``` +git clone https://github.com/triton-lang/triton +cd triton +git checkout 2e9f2c2d20601c24b91a4c32a7b97ad1f8a55d88 +pip install --verbose -e python +``` +Then install and test Flash Attention with the environment variable `FLASH_ATTENTION_USE_TRITON_ROCM` set to `"TRUE"`. + +``` +export FLASH_ATTENTION_USE_TRITON_ROCM="TRUE" +cd flash-attention +python setup.py install +pytest tests/test_flash_attn.py +``` ## How to use FlashAttention
diff --git a/csrc/composable_kernel b/csrc/composable_kernel deleted file mode 160000 index 8182976c3..000000000 --- a/csrc/composable_kernel +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 8182976c37433808b5e3a27a6536d1b74b0c23a1
diff --git a/csrc/cutlass b/csrc/cutlass deleted file mode 160000 index 756c351b4..000000000 --- a/csrc/cutlass +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 756c351b4994854b2f8c6dded3821ebbb580876b
diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py old mode 100644 new mode 100755 index ecb3515c0..ef7342504 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -4,10 +4,15 @@ import torch import torch.nn as nn +import os # isort: off # We need to import the CUDA kernels after importing torch -import flash_attn_2_cuda as flash_attn_cuda +USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_USE_TRITON_ROCM", "FALSE") == "TRUE" +if USE_TRITON_ROCM: + from flash_attn import flash_attn_triton_interface_amd as flash_attn_gpu +else: + import flash_attn_2_cuda as flash_attn_gpu # isort: on @@ -49,7 +54,7 @@ def _flash_attn_forward( q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax ): q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd( + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd( q, k, v, @@ -87,7 +92,7 @@ def _flash_attn_varlen_forward( seqused_k=None, ): q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd( + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd( q, k, v, @@ -141,7 +146,7 @@ def _flash_attn_backward( dk, dv, softmax_d, - ) = flash_attn_cuda.bwd( + ) = flash_attn_gpu.bwd( dout, q, k, @@ -195,7 +200,7 @@ def _flash_attn_varlen_backward( dk, dv, softmax_d, - ) = flash_attn_cuda.varlen_bwd( + ) = flash_attn_gpu.varlen_bwd( dout, q, k, @@ -1149,15 +1154,20 @@ def flash_attn_with_kvcache( v=None, rotary_cos=None, rotary_sin=None, + rotary_cos_k=None, + rotary_sin_k=None, + rotary_interleaved=True, + rotary_inplace=False, +
rotary_conjugate=False, cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, cache_batch_idx: Optional[torch.Tensor] = None, cache_leftpad: Optional[torch.Tensor] = None, block_table: Optional[torch.Tensor] = None, softmax_scale=None, causal=False, + local=False, window_size=(-1, -1), # -1 means infinite context window softcap=0.0, # 0.0 means deactivated - rotary_interleaved=True, alibi_slopes=None, num_splits=0, return_softmax_lse=False, @@ -1249,6 +1259,7 @@ def flash_attn_with_kvcache( logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). """ + # assert ALIBI is not ROTARY ? assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" q, k, v = [maybe_contiguous(x) for x in (q, k, v)] @@ -1261,7 +1272,7 @@ def flash_attn_with_kvcache( cache_seqlens = maybe_contiguous(cache_seqlens) cache_batch_idx = maybe_contiguous(cache_batch_idx) block_table = maybe_contiguous(block_table) - out, softmax_lse = flash_attn_cuda.fwd_kvcache( + out, softmax_lse = flash_attn_gpu.fwd_kvcache( q, k_cache, v_cache, @@ -1270,6 +1281,12 @@ def flash_attn_with_kvcache( cache_seqlens, rotary_cos, rotary_sin, + rotary_cos_k, + rotary_sin_k, + rotary_interleaved, + rotary_inplace, + rotary_conjugate, + cache_seqlens, cache_batch_idx, cache_leftpad, block_table, @@ -1277,10 +1294,10 @@ def flash_attn_with_kvcache( None, softmax_scale, causal, + local, window_size[0], window_size[1], softcap, - rotary_interleaved, num_splits, ) return (out, softmax_lse) if return_softmax_lse else out diff --git a/flash_attn/flash_attn_triton_interface_amd.py b/flash_attn/flash_attn_triton_interface_amd.py new file mode 100755 index 000000000..0c6cbfdde --- /dev/null +++ b/flash_attn/flash_attn_triton_interface_amd.py @@ -0,0 +1,215 @@ +import torch +import triton +from .flash_attn_triton_kernel_prefill_amd import MetaData, get_shape_from_layout, attention_prefill +from .flash_attn_triton_kernel_decode_amd import attention_decode + +def fwd(q, + k, + v, + o, + alibi_slopes, + dropout_p, + softmax_scale, + causal, + window_size_left, + window_size_right, + softcap, + return_softmax, + gen_): + + if dropout_p != 0.0: + raise ValueError("dropout is not supported on AMD's Triton Backend yet") + + if o is None: + o = torch.empty_like(q) + + # Setup metadata + input_metadata = MetaData(sm_scale=softmax_scale) + input_metadata.max_seqlens_q = q.shape[1] + input_metadata.max_seqlens_k = k.shape[1] + input_metadata.layout = "bshd" + if return_softmax: + input_metadata.return_encoded_softmax = True + + batch, nheads_q, nheads_k, head_size = get_shape_from_layout(q, k, input_metadata) + + if causal: + input_metadata.need_causal() + + if alibi_slopes is not None: + input_metadata.need_alibi(alibi_slopes, batch, nheads_q) + + if dropout_p > 0.0: + input_metadata.need_dropout(dropout_p, return_softmax) + + # Check arguments + input_metadata.check_args(q, k, v, o) + tri_out, softmax_lse, softmax_dmask= attention_prefill(q, k, v, o, input_metadata) + + return tri_out, q , k , v, o, softmax_lse, softmax_dmask, None + +def bwd( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + alibi_slopes, + dropout_p, + softmax_scale, + causal, + window_size_left, + window_size_right, + deterministic, + gen_, + rng_state, +): + raise ValueError("bwd is not supported on AMD's Triton Backend yet") + +def varlen_fwd( + q, + k, + v, + o, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + leftpad_k, + 
block_table_, + alibi_slopes,\ + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + zero_tensors, + causal, + window_size_left, + window_size_right, + softcap, + return_softmax, + gen_): + + if dropout_p != 0.0: + raise ValueError("dropout is not supported on AMD's Triton Backend yet") + + if o is None: + o = torch.empty_like(q) + + # Setup metadata + input_metadata = MetaData(sm_scale=softmax_scale) + if return_softmax: + input_metadata.return_encoded_softmax = True + input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) # set layout to "thd" and other metdata + + # get shapes + batch, nheads_q, nheads_k, head_size = get_shape_from_layout(q, k, input_metadata) + + if causal: + input_metadata.need_causal() + + if alibi_slopes is not None: + input_metadata.need_alibi(alibi_slopes, batch, nheads_q) + + if dropout_p > 0.0: + input_metadata.need_dropout(dropout_p, return_softmax) + + # Check arguments + input_metadata.check_args(q, k, v, o) + + tri_out, softmax_lse, softmax_dmask= attention_prefill(q, k, v, o, input_metadata) + + return tri_out, q , k , v, o, softmax_lse, softmax_dmask, None + +def varlen_bwd( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + zero_tensors, + causal, + window_size_left, + window_size_right, + softcap, + deterministic, + gen_, + rng_state, +): + raise ValueError("varlen_bwd is not supported on AMD's Triton Backend yet") + +def fwd_kvcache( + q, + k_cache, + v_cache, + k, + v, + cache_seqlens, + rotary_cos, + rotary_sin, + rotary_cos_k, + rotary_sin_k, + rotary_interleaved, + rotary_inplace, + rotary_conjugate, + rotary_seqlen_offsets, + cache_batch_idx, + cache_leftpad, + block_table, + alibi_slopes, + out, + softmax_scale, + causal, + local, + window_size_left, + window_size_right, + softcap, + num_splits): + + if out is None: + out = torch.empty_like(q) + + # fill metadata + input_metadata = MetaData(sm_scale=softmax_scale) + input_metadata.layout = "bshd" + input_metadata.max_seqlens_q = q.shape[1] + input_metadata.max_seqlens_k = k_cache.shape[1] + input_metadata.cache_seqlens = cache_seqlens + input_metadata.cache_batch_idx = cache_batch_idx + + if k is not None and v is not None: + input_metadata.new_kv = True + input_metadata.seqlen_new = k.shape[1] + input_metadata.k_new = k + input_metadata.v_new = v + + if causal: + input_metadata.need_causal() + + if local: + input_metadata.need_local() + + if alibi_slopes is not None: + batch, _ , nheads_q, _= q.shape + input_metadata.need_alibi(alibi_slopes, batch, nheads_q) + + if torch.is_tensor(rotary_cos) and torch.is_tensor(rotary_sin): + input_metadata.need_rotary(rotary_cos, rotary_sin, rotary_cos_k, rotary_sin_k, rotary_interleaved, rotary_seqlen_offsets, rotary_inplace=rotary_inplace, rotary_conjugate=rotary_conjugate) + + # launch kernel + tri_out, softmax_lse = attention_decode(q, k_cache, v_cache, input_metadata) + return tri_out, softmax_lse diff --git a/flash_attn/flash_attn_triton_kernel_decode_amd.py b/flash_attn/flash_attn_triton_kernel_decode_amd.py new file mode 100755 index 000000000..c16151b05 --- /dev/null +++ b/flash_attn/flash_attn_triton_kernel_decode_amd.py @@ -0,0 +1,954 @@ +import math +from typing import Optional, Union +from einops import rearrange, repeat +from flash_attn.layers.rotary import apply_rotary_emb +import pytest +import torch +import sys + +import pdb + +import triton +import triton.language as tl +from 
flash_attn.flash_attn_triton_kernel_prefill_amd import MetaData + +def _strides(x: torch.Tensor, *stride_names: str): + if x is None: + return {f"stride_{s}": 0 for i, s in enumerate(stride_names)} + + assert x.ndim == len(stride_names) + return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)} + +@triton.jit +def _fwd_kernel_splitK( + Q, + K, + V, + sm_scale, + Out_splitK, # [B, H, split_k, Mq, K] + Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + K_new, + V_new, + Cache_seqlens, + Cache_batch_idx, + Alibi_slopes, + Rotary_cos, + Rotary_sin, + stride_qz, + stride_qm, + stride_qg, + stride_qh, + stride_qd, + stride_kz, + stride_kn, + stride_kg, + stride_kh, + stride_kd, + stride_vz, + stride_vn, + stride_vg, + stride_vh, + stride_vd, + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + stride_osk_d, + stride_mzhg, + stride_m2, + stride_ms, + stride_mm, + stride_kn_z, + stride_kn_n, + stride_kn_g, + stride_kn_h, + stride_kn_d, + stride_vn_z, + stride_vn_n, + stride_vn_g, + stride_vn_h, + stride_vn_d, + stride_az, + stride_ah, + Z, + N_CTX_Q, + N_CTX_K, + N_CTX_NEW, + BLOCK_N_PER_SPLIT, + H_q: tl.constexpr, + H_kv: tl.constexpr, + G_q: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + BOUNDS_CHECKS_N: tl.constexpr, + USE_CACHE_SEQLENs: tl.constexpr, + USE_CACHE_BATCH_IDX: tl.constexpr, + NEW_KV: tl.constexpr, + IS_GQA: tl.constexpr, + IS_CAUSAL: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_ROTARY: tl.constexpr, + ROTARY_INTERLEAVED: tl.constexpr +): + # Padding + PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) + if PADDED_HEAD: + d_mask = tl.arange(0, BLOCK_DMODEL) < ACTUAL_BLOCK_DMODEL + + start_m = tl.program_id(0) + off_zhg = tl.program_id(1) + off_z = off_zhg // (H_q * G_q) # batch + off_h_q = (off_zhg // G_q) % H_q # head + off_g_q = off_zhg % G_q # group (gca / mqa) + splitk_idx = tl.program_id(2) + + # pick batch index + if USE_CACHE_BATCH_IDX: + cache_batch_idx = tl.load(Cache_batch_idx + off_z) + else: + cache_batch_idx = off_z + + # Load ALiBi slope if enabled + if USE_ALIBI: + a_offset = off_z * stride_az + off_h_q * stride_ah + alibi_slope = tl.load(Alibi_slopes + a_offset) + else: + alibi_slope = None + + lo = splitk_idx * BLOCK_N_PER_SPLIT + if USE_CACHE_SEQLENs: + cache_seqlen_last_idx = tl.load(Cache_seqlens + off_z) + if NEW_KV: + kv_len = cache_seqlen_last_idx + N_CTX_NEW + else: + kv_len = cache_seqlen_last_idx + else: + kv_len = N_CTX_K + hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len) + + HEAD_RATIO: tl.constexpr = H_q // H_kv + if IS_GQA: + k_head_idx = off_h_q // HEAD_RATIO + v_head_idx = k_head_idx + else: + k_head_idx = off_h_q + v_head_idx = off_h_q + + # calculate base offset + k_base = K + k_head_idx * stride_kh + cache_batch_idx * stride_kz + off_g_q * stride_kg + v_base = V + v_head_idx * stride_vh + cache_batch_idx * stride_vz + off_g_q * stride_vg + + # Copy new Keys and Values into Cache + if NEW_KV: + knew_base = K_new + k_head_idx * stride_kn_h + off_z * stride_kn_z + off_g_q * stride_kn_g + + # Determine the starting position for new data in the cache + if USE_CACHE_SEQLENs: + start_idx = tl.load(Cache_seqlens + off_z) + else: + start_idx = N_CTX_K - N_CTX_NEW + + # Copy new Keys + for i in range(0, N_CTX_NEW, BLOCK_N): + # Load from K_new + k_new_block = tl.load( + knew_base + + tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kn_d + + (tl.arange(0, BLOCK_N) + i)[None, :] * stride_kn_n, + mask=(tl.arange(0, 
BLOCK_N)[None, :] + i < N_CTX_NEW) & + (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL), + other=0 + ) + + # Store to K + tl.store( + k_base + + tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kd + + (tl.arange(0, BLOCK_N) + i + start_idx)[None, :] * stride_kn, + k_new_block, + mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) & + (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL), + ) + + # Copy new Values + vnew_base = V_new + v_head_idx * stride_vn_h + off_z * stride_vn_z + off_g_q * stride_vn_g + for i in range(0, N_CTX_NEW, BLOCK_N): + # Load from V_new + v_new_block = tl.load( + vnew_base + + (tl.arange(0, BLOCK_N) + i)[:, None] * stride_vn_n + + tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vn_d, + mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) & + (tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL), + other=0 + ) + + # Store to V + tl.store( + v_base + + (tl.arange(0, BLOCK_N) + i + start_idx)[:, None] * stride_vn + + tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vd, + v_new_block, + mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) & + (tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL), + ) + + Q_block_ptr = tl.make_block_ptr( + base=Q + off_h_q * stride_qh + off_z * stride_qz + off_g_q * stride_qg, + shape=(N_CTX_Q, ACTUAL_BLOCK_DMODEL), + strides=(stride_qm, stride_qd), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + + K_block_ptr = tl.make_block_ptr( + base=k_base, + shape=(ACTUAL_BLOCK_DMODEL, hi), + strides=(stride_kd, stride_kn), + offsets=(0, lo), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1), + ) + V_block_ptr = tl.make_block_ptr( + base=v_base, + shape=(hi, ACTUAL_BLOCK_DMODEL), + strides=(stride_vn, stride_vd), + offsets=(lo, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + + K_scale_shift_block_ptr = None + V_scale_shift_block_ptr = None + + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # noqa: F821 + + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + log2_e = 1.44269504 + qk_scale = sm_scale * log2_e + + # load q: it will stay in SRAM throughout + q = tl.load( # noqa: F821 + tl.advance(Q_block_ptr, (0, 0)), + boundary_check=(0, ) + ) + q = (q * qk_scale) + if PADDED_HEAD: + q = tl.where(d_mask[None, :], q, 0.0) + + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + k, v = load_k_v_group( + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N, + 1, + BLOCK_DMODEL, + ACTUAL_BLOCK_DMODEL, + Q.dtype.element_ty, + 0, + ) + if PADDED_HEAD: + k = tl.where(d_mask[:, None], k, 0.0) + v = tl.where(d_mask[None, :], v, 0.0) + + if USE_ROTARY: + # rotate q and k before dot product + pass + + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + + if USE_ALIBI: + row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + col_idx = start_n + tl.arange(0, BLOCK_N) + + # Compute relative positions + relative_pos = row_idx[:, None] + kv_len - (N_CTX_Q + col_idx[None, :]) + relative_pos = tl.abs(relative_pos) + + # Compute ALiBi bias + alibi_bias = -1 * alibi_slope * relative_pos + qk += (alibi_bias * log2_e) + + # Apply causal mask if IS_CAUSAL is True + if IS_CAUSAL: + row_idx = start_m * BLOCK_M + 
tl.arange(0, BLOCK_M) + col_idx = start_n + tl.arange(0, BLOCK_N) + + # create a N_CTX_Q x kv_len causal mask + col_offset = N_CTX_Q - kv_len + causal_mask = row_idx[:, None] >= (col_offset + col_idx[None, :]) + + # Apply the mask + qk = tl.where(causal_mask, qk, float("-inf")) + + # TODO: This is slow, and only needed at the last iteration. + # Maybe we can unroll the last iteration instead? + if BOUNDS_CHECKS_N: + qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf")) + + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + if IS_CAUSAL: + alpha = tl.math.exp2(tl.where(m_i > float("-inf"), m_i - m_i_new, float("-inf"))) + else: + alpha = tl.math.exp2(m_i - m_i_new) + # cause of nan because subtracting infs + if IS_CAUSAL: + qk = tl.where(qk > float("-inf"), qk - m_i_new[:, None], float("-inf")) + else: + qk = qk - m_i_new[:, None] + + p = tl.math.exp2(qk) # p = e^(qk^T) + + # -- update m_i (current max) and l_i (sum of elements) -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + p = p.to(Q.dtype.element_ty) + + # -- scale and update acc -- + acc *= alpha[:, None] + acc += tl.dot(p.to(v.dtype), v) + + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s, + shape=(N_CTX_Q, BLOCK_DMODEL), + strides=(stride_osk_m, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + tl.store( + tl.advance(O_block_ptr, (0, 0)), + acc, + boundary_check=(0, ), + ) + # Write metadata for split-K reduction + Metadata_ptr = (Metadata + off_zhg * stride_mzhg + splitk_idx * stride_ms + start_m * BLOCK_M + + tl.arange(0, BLOCK_M)) + tl.store(Metadata_ptr, m_i) + tl.store(Metadata_ptr + stride_m2, l_i) + + +@triton.jit +def load_k_v_group( + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N: tl.constexpr, + PACKED_PER_VAL: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + dtype: tl.constexpr, + group_id: tl.constexpr, +): + #Load K/V for a given block + + # Advance to the current quantization group + K_block_ptr = tl.advance(K_block_ptr, (ACTUAL_BLOCK_DMODEL * group_id, 0)) + V_block_ptr = tl.advance(V_block_ptr, (0, ACTUAL_BLOCK_DMODEL * group_id)) + + # -- load k, v -- + k = tl.load(K_block_ptr, boundary_check=(1, ) if BOUNDS_CHECKS_N else ()).to(tl.float32) + v = tl.load(V_block_ptr, boundary_check=(0, ) if BOUNDS_CHECKS_N else ()).to(tl.float32) + + return k, v + + +@triton.jit +def cast_uint32_to_half2(scale_shift): + # Extract two float16 packed into one int32 + scale = scale_shift & 0xFFFF + shift = scale_shift >> 16 + scale = scale.to(tl.uint16).to(tl.float16, bitcast=True) + shift = shift.to(tl.uint16).to(tl.float16, bitcast=True) + return scale, shift + + +@triton.jit +def dequantize( + x_, + scale, + shift, + PACKED_PER_VAL: tl.constexpr = 8, +): + # PACKED_PER_VAL is the number of values packed into + # each element x_. For example, for int4 quantization + #and x_ of type int32, PACKED_PER_VAL is 8. 
+ + BLOCK_N: tl.constexpr = x_.shape[0] + BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1] + offsets = tl.arange(0, PACKED_PER_VAL) * 4 + quant_offset = (x_[:, None, :] >> offsets[None, :, None]) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL) + + quant_offset = tl.view(quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL)) + # Trick - instead of converting int4 to float16 we view it as float16 + # and then multiply by 32768 * 512 == 2**24 + quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True) + quant_offset = (quant_offset * 32768.0).to(tl.float16) + scale_512 = scale * 512 + + dequant = quant_offset * scale_512 + shift + return dequant + + +@triton.jit +def _splitK_reduce( + Out_splitK, # [B, H, split_k, Mq, K] + Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + Out, # [B, H, M, K] + LSE, # [B, H, M] + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + stride_osk_k, + stride_mzhg, + stride_m2, + stride_ms, + stride_mm, + stride_oz, + stride_oh, + stride_og, + stride_om, + stride_ok, + stride_lse_zhg, + stride_lse_m, + M_ceil: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + H: tl.constexpr, + G: tl.constexpr, + split_k: tl.constexpr, + splitK_pow2: tl.constexpr, + use_mask: tl.constexpr, + IS_CAUSAL: tl.constexpr, +): + off_zhg = tl.program_id(0) + off_z = off_zhg // (H * G) + off_h = (off_zhg // G) % H + off_g = off_zhg % G + off_m = tl.program_id(1) + off_k = tl.program_id(2) + + # read chunk + spk_idx = tl.arange(0, splitK_pow2) + kidx = tl.arange(0, BLOCK_SIZE) + + Metadata_ptr = (Metadata + stride_mzhg * off_zhg + spk_idx * stride_ms + off_m * stride_mm) + + o_ptr = (Out_splitK + off_zhg * stride_osk_zhg + stride_osk_m * off_m + off_k * BLOCK_SIZE + + stride_osk_s * spk_idx[:, None] + kidx[None, :] * stride_osk_k) + + # read max values of each splitK + if use_mask: + spk_mask = spk_idx < split_k + l_m = tl.load(Metadata_ptr, mask=spk_mask, other=float("-inf")) + l_sum = tl.load(Metadata_ptr + stride_m2, mask=spk_mask, other=0.0) + acc = tl.load(o_ptr, mask=spk_mask[:, None], other=0.0) + else: + l_m = tl.load(Metadata_ptr) + l_sum = tl.load(Metadata_ptr + stride_m2) + acc = tl.load(o_ptr) + + g_m = tl.max(l_m, axis=0) + + if IS_CAUSAL: + l_m_offset = l_m - g_m + alpha = tl.where(l_m_offset > float("-inf"), tl.math.exp2(l_m_offset), 0.0) + else: + alpha = tl.math.exp2(l_m - g_m) + # read sum + l_sum *= alpha + g_sum = tl.sum(l_sum, axis=0) + acc = acc * alpha[:, None] + + if IS_CAUSAL: + # Avoid division by zero + g_sum_safe = tl.where(g_sum > 0, g_sum, 1.0) + acc_out = tl.sum(acc, axis=0) / g_sum_safe + else: + acc_out = tl.sum(acc, axis=0) / g_sum + + # Store output + Out_ptr = (Out + stride_oz * off_z + stride_oh * off_h + stride_og * off_g + stride_om * off_m + + off_k * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)) + tl.store(Out_ptr, acc_out) + + # log constant + log2_e = 1.44269504 + + # Store lse + l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m + if IS_CAUSAL: + lse = tl.where(g_sum > 0, (g_m + tl.math.log2(g_sum)) / log2_e, g_m) + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, (g_m + tl.math.log2(g_sum)) / log2_e) + + +def quantize_kv_int4(k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: + # Scale and shift are such that quantization linearly maps + # int4 values range [0..15] to input values range min(k)..max(k) + # individually for every row + k = k.reshape(*k.shape[:-1], num_groups, k.shape[-1] // num_groups) + max_vals = torch.max(k, dim=-1, keepdim=True).values + min_vals = torch.min(k, dim=-1, keepdim=True).values + scale_k: 
torch.Tensor = (max_vals - min_vals) / 15 + + shift_k = torch.min(k, dim=-1, keepdim=True).values + scale_k = scale_k.to(torch.float16) + shift_k = shift_k.to(torch.float16) + + in_bytes = ((k - shift_k.expand(k.shape)) / scale_k.expand(k.shape)) + 0.5 + in_bytes = in_bytes.to(torch.uint8) + in_int4 = in_bytes & 0xF + in_int4_packed = in_int4[..., ::2] + (in_int4[..., 1::2] << 4) + scale_shift = torch.concat([scale_k.view(torch.uint8), shift_k.view(torch.uint8)], dim=-1) + k_quant = torch.concat( + [ + scale_shift.flatten(start_dim=-2), + in_int4_packed.flatten(start_dim=-2), + ], + dim=-1, + ).view(torch.int16) + return k_quant + + +def dequantize_kv_fp16(quant_k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: + k_i16 = quant_k.view(torch.int16) + k_ui8 = k_i16.view(torch.uint8) + + ss_size = num_groups * 4 + scale_shift_ui8 = k_ui8[..., 0:ss_size] + scale_shift_ui8 = scale_shift_ui8.reshape(*scale_shift_ui8.shape[:-1], num_groups, 4) + scale = scale_shift_ui8[..., 0:2].view(torch.float16) + shift = scale_shift_ui8[..., 2:4].view(torch.float16) + + kv_ui8 = k_ui8[..., ss_size:] + k_ui8 = kv_ui8.reshape(*kv_ui8.shape[:-1], num_groups, -1) + k1_i4 = k_ui8 & 0xF + k2_i4 = (k_ui8 & 0xF0) >> 4 + k_shape = k1_i4.shape + k1_f16 = k1_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape) + k2_f16 = k2_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape) + + out = torch.empty((*k1_f16.shape[:-1], k1_f16.shape[-1] * 2), dtype=torch.float16, device=quant_k.device) + out[..., ::2] = k1_f16 + out[..., 1::2] = k2_f16 + out = out.reshape(*k_shape[:-2], -1) + + return out + + +def get_split_k(B: int, G: int, H: int, Mk: int) -> int: + """Heuristic for the number of splits""" + bh = max(B * H, 1) # NOTE: Handle B*h=0 case + split_k = max(Mk, 1024) // bh + max_chunk_size = 64 + while split_k > 0 and Mk / split_k < max_chunk_size: + split_k = split_k // 2 + while B * H * G * split_k >= 1024: + split_k = split_k // 2 + split_k = min(split_k, 512) + split_k = max(split_k, 1) + return split_k + +def get_padded_headsize(size): + # Get closest power of 2 over or equal to 32. + padded_d_model = 1 << (size - 1).bit_length() + # Smallest head_dim supported is 16. If smaller, the tile in the + # kernel is padded - there is no padding in memory for any dims. 
+ padded_d_model = max(padded_d_model, 16) + return padded_d_model + +class _attention(torch.autograd.Function): + + OPERATOR = _fwd_kernel_splitK + SUPPORTED_DEVICES = {"cuda"} + CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0) + SUPPORTED_DTYPES = { + torch.half, + torch.bfloat16, + } + SUPPORTED_MAX_K = 128 + SUPPORTS_DROPOUT = False + SUPPORTS_CUSTOM_SCALE = True + SUPPORTS_BMGHK = True + NAME = "triton_splitKF" + + @staticmethod + def forward(cls, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_metadata: MetaData): + original_layout = input_metadata.layout + + # Rotary Embedding Implementation + if torch.is_tensor(input_metadata.rotary_cos) and torch.is_tensor(input_metadata.rotary_sin): + if input_metadata.causal or input_metadata.local: + q_ro = apply_rotary_emb( + q, + input_metadata.rotary_cos, + input_metadata.rotary_sin, + seqlen_offsets=input_metadata.cache_seqlens, + interleaved=input_metadata.rotary_interleaved, + ) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + input_metadata.rotary_cos, + input_metadata.rotary_sin, + seqlen_offsets=input_metadata.cache_seqlens, + interleaved=input_metadata.rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=input_metadata.max_seqlens_q, + ) + k_ro = apply_rotary_emb( + input_metadata.k_new, + input_metadata.rotary_cos, + input_metadata.rotary_sin, + seqlen_offsets=input_metadata.cache_seqlens, + interleaved=input_metadata.rotary_interleaved, + ) + + q, input_metadata.k_new = q_ro.to(q.dtype), k_ro.to(q.dtype) + + # kernels expects "bsghd" + if input_metadata.layout == "bshd": + q=q.unsqueeze(2) + k=k.unsqueeze(2) + v=v.unsqueeze(2) + + if input_metadata.new_kv: + input_metadata.k_new = input_metadata.k_new.unsqueeze(2) + input_metadata.v_new = input_metadata.v_new.unsqueeze(2) + + input_metadata.layout = "bsghd" + elif input_metadata.layout == "bhsd": + q=q.permute(0, 2, 1, 3).unsqueeze(2) + k=k.permute(0, 2, 1, 3).unsqueeze(2) + v=v.permute(0, 2, 1, 3).unsqueeze(2) + if input_metadata.new_kv: + input_metadata.k_new = input_metadata.k_new.permute(0, 2, 1, 3).unsqueeze(2) + input_metadata.v_new = input_metadata.v_new.permute(0, 2, 1, 3).unsqueeze(2) + + + input_metadata.layout = "bsghd" + elif input_metadata.layout == "bsghd": + pass + elif input_metadata.layout is None: + raise ValueError("Layout not given") + + assert input_metadata.layout == "bsghd" + + # get dims + batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_q = q.shape + _, seqlen_k, n_group_k, heads_per_group_k, dim_k = k.shape + _, seqlen_v, n_group_v, heads_per_group_v, dim_v = v.shape + + assert dim_q == dim_k == dim_v, f"Dimensions must match: {dim_q}, {dim_k}, {dim_v}" + + # get padded size + dim_padded = get_padded_headsize(dim_k) + + # Handle MQA/GQA case + if heads_per_group_q > heads_per_group_k: + input_metadata.is_gqa = True + elif heads_per_group_q < heads_per_group_k: + raise ValueError("heads_per_group_q < heads_per_group_k") + else: + input_metadata.is_gqa = False + + # context + cls.SPLIT_K: Optional[int] = None + cls.BLOCK_M = 16 + cls.BLOCK_N = 64 + + cls.NUM_QUANT_GROUPS = 1 # Default quantization is row-wise + + # attn_bias = inp.attn_bias + if input_metadata.cache_seqlens is not None: + cache_seqlens = input_metadata.cache_seqlens + else: + cache_seqlens = None + + assert dim_k == dim_q, f"Keys have head dim {dim_k} but queries have head dim {dim_q}" + + BLOCK_M = cls.BLOCK_M + BLOCK_N = cls.BLOCK_N + if cls.SPLIT_K is not None: + split_k = cls.SPLIT_K + else: + # Use heuristics + split_k = 
get_split_k(batch_size, n_group_q, heads_per_group_q, seqlen_k) # NOTE: should the split think about seqlens? + + seqlen_q_ceil = (seqlen_q + BLOCK_M - 1) // BLOCK_M * BLOCK_M + out_splitk = torch.empty([batch_size * n_group_q * heads_per_group_q, split_k, seqlen_q_ceil, dim_padded], dtype=torch.float32, device=q.device) + metadata = torch.empty([batch_size * n_group_q * heads_per_group_q, 2, split_k, seqlen_q_ceil], dtype=torch.float32, device=q.device) + lse = torch.empty((batch_size * n_group_q * heads_per_group_q, seqlen_q), device=q.device, dtype=torch.float32) + grid = (triton.cdiv(seqlen_q, BLOCK_M), batch_size * n_group_q * heads_per_group_q, split_k) + + num_warps = 1 + split_size = (seqlen_k + split_k - 1) // split_k + use_cache_seqlens = cache_seqlens is not None + + # TODO: enable quantization + _fwd_kernel_splitK[grid]( + Q=q, + K=k, + V=v, + sm_scale=input_metadata.sm_scale, + Out_splitK=out_splitk, + Metadata=metadata, + K_new = input_metadata.k_new, + V_new = input_metadata.v_new, + Cache_seqlens=cache_seqlens, + Cache_batch_idx=input_metadata.cache_batch_idx, + Alibi_slopes=input_metadata.alibi_slopes, + Rotary_cos = input_metadata.rotary_cos, + Rotary_sin = input_metadata.rotary_sin, + **_strides(q, "qz", "qm", "qg", "qh", "qd"), + **_strides(k, "kz", "kn", "kg", "kh", "kd"), + **_strides(v, "vz", "vn", "vg", "vh", "vd"), + **_strides(out_splitk, "osk_zhg", "osk_s", "osk_m", "osk_d"), + **_strides(metadata, "mzhg", "m2", "ms", "mm"), + **_strides(input_metadata.k_new, "kn_z", "kn_n", "kn_g", "kn_h", "kn_d"), + **_strides(input_metadata.v_new, "vn_z", "vn_n", "vn_g", "vn_h", "vn_d"), + **_strides(input_metadata.alibi_slopes, "az", "ah"), + Z=batch_size, + H_q=heads_per_group_q, + H_kv=heads_per_group_k, + G_q=n_group_q, + N_CTX_Q=seqlen_q, + N_CTX_K=seqlen_k, + N_CTX_NEW=input_metadata.k_new.shape[1] if input_metadata.new_kv else None, + BLOCK_N_PER_SPLIT=split_size, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=dim_padded, + ACTUAL_BLOCK_DMODEL=dim_k, + BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_cache_seqlens, + USE_CACHE_SEQLENs=use_cache_seqlens, + USE_CACHE_BATCH_IDX= input_metadata.cache_batch_idx is not None, + NEW_KV=input_metadata.new_kv, + IS_GQA=input_metadata.is_gqa, + IS_CAUSAL=input_metadata.causal, + USE_ALIBI=False if input_metadata.alibi_slopes is None else True, + USE_ROTARY=False if input_metadata.rotary_cos is None or input_metadata.rotary_sin is None else True, + ROTARY_INTERLEAVED = True if input_metadata.rotary_interleaved else False, + num_warps=num_warps, + num_stages=1, + ) + + out = torch.empty((batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_padded), device=q.device, dtype=q.dtype) + + # Merge together + splitK_pow2 = triton.next_power_of_2(split_k) + use_mask = splitK_pow2 > split_k + if batch_size * n_group_q * heads_per_group_q * seqlen_q >= 512: + k_block_num = 1 + else: + k_block_num = 2 + assert dim_padded % k_block_num == 0 + k_block_size = dim_padded // k_block_num + grid = (batch_size * n_group_q * heads_per_group_q, seqlen_q, k_block_num) + + _splitK_reduce[grid]( + out_splitk, + metadata, + out, + lse, + **_strides(out_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"), + **_strides(metadata, "mzhg", "m2", "ms", "mm"), + **_strides(out, "oz", "om", "og", "oh", "ok"), + **_strides(lse, "lse_zhg", "lse_m"), + M_ceil=seqlen_q_ceil, + BLOCK_SIZE=k_block_size, + G=n_group_q, + H=heads_per_group_q, + # TODO: Tune num_warps + split_k=split_k, + splitK_pow2=splitK_pow2, + use_mask=use_mask, + IS_CAUSAL=input_metadata.causal, + 
num_warps=4) + + lse = lse.reshape([batch_size, n_group_q, heads_per_group_q, seqlen_q]) + if q.ndim == 4: + # BMGHK -> BMHK + assert n_group_q == 1 + out = out[:, :, 0] + lse = lse[:, 0] + if seqlen_k == 0: + out.zero_() + out = out.reshape(batch_size, heads_per_group_q * n_group_q, -1, dim_padded).contiguous() + + # output is batch_size, heads_per_group_q * group_q, seqlen_q, dim_q + if original_layout == "bshd": + # out=out.transpose(1, 2).contiguous() # this screws up heads and data. + # the data is laid out properly. Just need to reshape dims + out = out.reshape(batch_size, seqlen_q, -1, dim_padded) + + return out.narrow(-1, 0, dim_k), lse + + +attention_decode = _attention.apply + + +def get_input_shapes(): + cases = [(max(1, 2**(16 - i)), 1, 2**i, 16, 1, 128) + for i in range(8, 18)] + [(max(1, 2**(16 - i)), 1, 2**i, 16, 2, 128) for i in range(8, 18)] + + return cases + + +@pytest.mark.parametrize('batch_size, seqlen_q, seqlen_k, group_q, group_k, dim', get_input_shapes()) +def test_op_fwd(batch_size, seqlen_q, seqlen_k, group_q, group_k, dim, dtype=torch.bfloat16): + print() + print(f"batch_size = {batch_size}, seqlen_q = {seqlen_q}, seqlen_k = {seqlen_k}, group_q = {group_q}, group_k = {group_k}, dim = {dim}") + torch.manual_seed(20) + query_group_head_size = (group_q + group_k - 1) // group_k + q = (torch.empty((batch_size, seqlen_q, group_k, query_group_head_size, dim), dtype=dtype, + device="cuda").normal_(mean=0., std=0.5).requires_grad_()) + k = (torch.empty((batch_size, seqlen_k, group_k, 1, dim), dtype=dtype, + device="cuda").normal_(mean=0., + std=0.5).requires_grad_()).expand(-1, -1, -1, query_group_head_size, -1) + v = (torch.empty((batch_size, seqlen_k, group_k, 1, dim), dtype=dtype, + device="cuda").normal_(mean=0., + std=0.5).requires_grad_()).expand(-1, -1, -1, query_group_head_size, -1) + scale = 1 / dim**0.5 + input_metadata = MetaData(sm_scale=scale) + input_metadata.layout = "bsghd" + tri_out, _ = attention_decode(q, k, v, input_metadata) + + q = q.reshape([batch_size, seqlen_q, -1, dim]).permute(0, 2, 1, 3) + k = k.reshape([batch_size, seqlen_k, -1, dim]).permute(0, 2, 1, 3) + v = v.reshape([batch_size, seqlen_k, -1, dim]).permute(0, 2, 1, 3) + attn = (q @ k.transpose(-1, -2) * scale).softmax(-1) + ref_out = attn @ v + + # compare + torch.testing.assert_close(ref_out, tri_out, atol=1e-3, rtol=0) + + +@pytest.mark.parametrize('B, Mq, Mkv, Hq, Hkv, K', get_input_shapes()) +def test_op_fwd_int4_kv(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16): + torch.manual_seed(2) + q = (torch.empty((B, Mq, Hkv, (Hq + Hkv - 1) // Hkv, K), dtype=dtype, + device="cuda").normal_(mean=1.0, std=0.5).requires_grad_()) + k = (torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, + device="cuda").normal_(mean=1.0, + std=0.5).requires_grad_()).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) + v = (torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, + device="cuda").normal_(mean=1.0, + std=0.5).requires_grad_()).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) + + num_groups = 1 + quant_k = (quantize_kv_int4(k, num_groups=num_groups).contiguous().view(torch.int32)) + quant_v = (quantize_kv_int4(v, num_groups=num_groups).contiguous().view(torch.int32)) + scale = 1 / K**0.5 + input_metadata = MetaData(sm_scale=scale) + input_metadata.layout = "bsghd" + tri_out, _ = attention_decode(q, quant_k, quant_v, input_metadata) + + q = q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3) + k = k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + v = v.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + attn = (q @ k.transpose(-1, -2) * 
scale).softmax(-1) + ref_out = attn @ v + # compare + torch.testing.assert_close(ref_out, tri_out, atol=2.1e-2, rtol=0) + + # since quantization introduces rounding error, use the + # dequantized kv as inputs to the ref implementation to reduce + # the tolerance to 1e-3 + dqk = dequantize_kv_fp16(quant_k, num_groups=num_groups) + dqv = dequantize_kv_fp16(quant_v, num_groups=num_groups) + dqk = dqk.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + dqv = dqv.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + dq_attn = (q @ dqk.transpose(-1, -2) * scale).softmax(-1) + dq_ref_out = dq_attn @ dqv + torch.testing.assert_close(dq_ref_out, tri_out, atol=1e-3, rtol=0) + + +def test_quantization(): + a = torch.randn((2, 4, 32), dtype=torch.float16, device='cuda') + qa = quantize_kv_int4(a, num_groups=4) + dqa = dequantize_kv_fp16(qa, num_groups=4) + torch.testing.assert_close(a, dqa, atol=1.5e-1, rtol=1e-1) + + +try: + FLASH_VER = 2 +except BaseException: + try: + FLASH_VER = 1 + except BaseException: + FLASH_VER = None +HAS_FLASH = FLASH_VER is not None + +configs = [] +for mode in ['fwd']: + # for D_HEAD in [128]: + for causal in [False]: + configs.append( + triton.testing.Benchmark( + x_names=['B', 'Mq', 'Mkv', 'Hq', 'Hkv', 'K'], x_vals=get_input_shapes(), line_arg='provider', + line_vals=['triton'] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []), styles=[('red', '-'), + ('blue', '-')], + ylabel='ms', plot_name=f'fused-attention-d{128}-{mode}-causal={causal}', args={ + # 'D_HEAD': D_HEAD, + 'dtype': torch.float16, 'mode': mode, 'causal': causal + })) + + +@triton.testing.perf_report(configs) +def bench_flash_attention(B, Mq, Mkv, Hq, Hkv, K, causal, mode, provider, dtype=torch.float16, device="cuda"): + assert mode in ['fwd', 'bwd'] + warmup = 100 + rep = 400 + ms = 0 + if provider == "triton": + q = torch.randn([B, Mq, Hkv, Hq // Hkv, K], device="cuda", dtype=dtype, requires_grad=False) + k = torch.randn([B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, + requires_grad=False).expand(-1, -1, -1, Hq // Hkv, -1) + v = torch.randn([B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, + requires_grad=False).expand(-1, -1, -1, Hq // Hkv, -1) + + sm_scale = 1.3 + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.layout = "bsghd" + fn = lambda: attention_decode(q, k, v, input_metadata) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + + # flops_per_matmul = 2 * B * Hq * (Mq * K * Mkv + Mq * Mkv * K) + # total_flops = 2 * flops_per_matmul + # totalBytes = ((B * Mkv * Hkv * K * 2) + (B * Mq * Hq * K) + (B * Mq * Hq * K)) * 2 + + # return totalBytes / ms * 1e-9 + return ms * 1000 + + +def main(): + bench_flash_attention.run(save_path='.', print_data=True) + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/flash_attn/flash_attn_triton_kernel_prefill_amd.py b/flash_attn/flash_attn_triton_kernel_prefill_amd.py new file mode 100755 index 000000000..44f3f3572 --- /dev/null +++ b/flash_attn/flash_attn_triton_kernel_prefill_amd.py @@ -0,0 +1,1597 @@ +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm +See https://tridao.me/publications/flash2/flash2.pdf + +Credits: +AMD Triton kernels team +OpenAI kernel team + +Currently only the forward kernel is supported, and contains these features: + +1) Fwd with causal masking +2) Arbitrary Q and KV sequence lengths +3) Arbitrary head sizes +4) Multi and grouped query attention +5) Variable sequence lengths +6) ALiBi and matrix bias + 
+""" + +import argparse +import pytest +import sys +import torch + +import triton +import triton.language as tl + + +class MetaData(): + cu_seqlens_q = None + cu_seqlens_k = None + max_seqlens_q = 0 + max_seqlens_k = 0 + bias = None + alibi_slopes = None + causal = False + local = False + num_contexts = 0 + varlen = False + layout = None + rotary_cos = None + rotary_sin = None + rotary_cos_k = None # cos/sin_k is cos/sin meant specifically for vector k. It's when we want to deliberately rotate each vector independently? + rotary_sin_k = None + rotary_interleaved = False + rotary_seqlen_offsets = None + rotary_inplace = False + rotary_conjugate = False + cache_seqlens = None + cache_batch_idx = None + new_kv = False + seqlen_new = None + k_new = None + v_new = None + dropout_p, return_encoded_softmax = 0.0, False + + def __repr__(self) -> str: + return (f"MetaData(\n" + f" sm_scale={self.sm_scale},\n" + f" cu_seqlens_q={self.cu_seqlens_q},\n" + f" cu_seqlens_k={self.cu_seqlens_k},\n" + f" max_seqlens_q={self.max_seqlens_q},\n" + f" max_seqlens_k={self.max_seqlens_k},\n" + f" bias={self.bias},\n" + f" alibi_slopes={self.alibi_slopes},\n" + f" causal={self.causal},\n" + f" local={self.local},\n" + f" num_contexts={self.num_contexts},\n" + f" varlen={self.varlen},\n" + f" layout={self.layout},\n" + f" rotary_cos={self.rotary_cos},\n" + f" rotary_sin={self.rotary_sin},\n" + f" rotary_cos={self.rotary_cos_k},\n" + f" rotary_sin={self.rotary_sin_k},\n" + f" rotary_interleaved={self.rotary_interleaved},\n" + f" rotary_seqlen_offsets={self.rotary_seqlen_offsets},\n" + f" rotary_inplace={self.rotary_inplace},\n" + f" rotary_conjugate={self.rotary_conjugate},\n" + f" cache_seqlens={self.cache_seqlens},\n" + f" cache_batch_idx={self.cache_batch_idx},\n" + f" new_kv={self.new_kv},\n" + f" seqlen_new={self.seqlen_new},\n" + f" k_new={self.k_new},\n" + f" v_new={self.v_new},\n" + f" dropout_p={self.dropout_p},\n" + f" return_encoded_softmax={self.return_encoded_softmax}\n" + f")") + + def __init__(self, sm_scale=1.0): + self.sm_scale = sm_scale + + def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k): + self.varlen = True + self.layout = 'thd' + self.cu_seqlens_q = cu_seqlens_q + self.cu_seqlens_k = cu_seqlens_k + # Without "varlen", there should still be one sequence. 
+ assert len(cu_seqlens_q) >= 2 + assert len(cu_seqlens_q) == len(cu_seqlens_k) + self.num_contexts = len(cu_seqlens_q) - 1 + for i in range(0, self.num_contexts): + self.max_seqlens_q = max(cu_seqlens_q[i + 1].item() - cu_seqlens_q[i].item(), self.max_seqlens_q) + self.max_seqlens_k = max(cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item(), self.max_seqlens_k) + + def need_bias(self, bias, batch, nheads, seqlen_q, seqlen_k): + assert bias.is_cuda + assert bias.dim() == 4 + assert bias.shape[0] == 1 + assert bias.shape[2:] == (seqlen_q, seqlen_k) + self.bias = bias + + def need_alibi(self, alibi_slopes, batch, nheads): + assert alibi_slopes.is_cuda + assert alibi_slopes.dim() == 2 + assert alibi_slopes.shape[0] == batch + assert alibi_slopes.shape[1] == nheads + self.alibi_slopes = alibi_slopes + + def need_rotary(self, rotary_cos, rotary_sin, rotary_cos_k, rotary_sin_k, rotary_interleaved, rotary_seqlen_offsets, rotary_inplace, rotary_conjugate): + assert rotary_cos.is_cuda + assert rotary_sin.is_cuda + assert rotary_cos.shape == rotary_sin.shape, "rotary_sin and rotary_cos shapes must match" + assert rotary_cos.dim() == 2 and rotary_sin.dim() == 2 + assert rotary_interleaved is not None + self.rotary_cos = rotary_cos + self.rotary_sin = rotary_sin + self.rotary_cos_k = rotary_cos_k + self.rotary_sin_k = rotary_sin_k + self.rotary_interleaved = rotary_interleaved + self.rotary_seqlen_offsets = rotary_seqlen_offsets + self.rotary_inplace = rotary_inplace + self.rotary_conjugate = rotary_conjugate + + + def need_causal(self): + self.causal = True + + def need_local(self): + self.local = True + + def need_dropout(self, dropout_p, return_encoded_softmax): + self.dropout_p = dropout_p + self.return_encoded_softmax = return_encoded_softmax + + def check_args(self, q, k, v, o): + assert q.dim() == k.dim() and q.dim() == v.dim() + + batch, nheads_q, nheads_k, head_size = get_shape_from_layout(q, k, self) + if self.varlen: + assert q.dim() == 3 + assert self.cu_seqlens_q is not None + assert self.cu_seqlens_k is not None + assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k) + # TODO: Remove once bias is supported with varlen + assert self.bias is None + # TODO:Remove once dropout is supported with varlen + assert self.dropout_p == 0.0 + # assert not self.return_encoded_softmax + else: + assert q.dim() == 4 + assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0 + assert self.cu_seqlens_q is None and self.cu_seqlens_k is None + assert k.shape == v.shape + assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] + # TODO: Change assert if we support qkl f8 and v f16 + assert q.dtype == k.dtype and q.dtype == v.dtype + assert head_size <= 256 + assert o.shape == q.shape + assert (nheads_q % nheads_k) == 0 + assert self.layout is not None + assert self.layout == 'thd' or not self.varlen + assert not (torch.is_tensor(self.rotary_cos) and not torch.is_tensor(self.rotary_sin)), "rotary_sin and rotary_cos must either both be None or both be Tensors" + assert not (torch.is_tensor(self.rotary_cos_k) and not torch.is_tensor(self.rotary_sin_k)), "rotary_sin_k and rotary_cos_k must either both be None or both be Tensors" + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + + +@triton.jit +def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): + ms = tl.arange(0, m) + ns = tl.arange(0, n) + return philox_offset + ms[:, None] * stride + ns[None, :] + + +@triton.jit +def dropout_rng(philox_seed, 
philox_offset, dropout_p, m, n, stride): + rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32) + # TODO: use tl.randint for better performance + return tl.rand(philox_seed, rng_offsets) + + +@triton.jit +def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride) + rng_keep = rng_output > dropout_p + return rng_keep + + +# Convenience function to load with optional boundary checks. +# "First" is the major dim, "second" is the minor dim. +@triton.jit +def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second): + if offset_first is not None and offset_second is not None: + mask = (offset_first[:, None] < boundary_first) & \ + (offset_second[None, :] < boundary_second) + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_first is not None: + mask = offset_first[:, None] < boundary_first + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_second is not None: + mask = offset_second[None, :] < boundary_second + tensor = tl.load(ptrs, mask=mask, other=0.0) + else: + tensor = tl.load(ptrs) + return tensor + + +@triton.jit +def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False): + # when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix + # for casual mask we want something like this where (1 is kept and 0 is masked) + # seqlen_q = 2 and seqlen_k = 5 + # 1 1 1 1 0 + # 1 1 1 1 1 + # seqlen_q = 5 and seqlen_k = 2 + # 0 0 + # 0 0 + # 0 0 + # 1 0 + # 1 1 + # for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal + # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False + # 1. offs_m[:,None] = [[0], + # [1], + # 2. offs_m[:,None] + seqlen_k = [[5], + # [6], + # 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3], + # [4], + # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], + # [4], [ 4, 3, 2, 1, 0]] + # 5. 
-1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1], + # [ -4, -3, -2, -1, 0]], + relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + if transpose: + return alibi_block.T + else: + return alibi_block + + +def compute_alibi_tensor(alibi_slopes, seqlen_q, seqlen_k): + q_idx = torch.arange(seqlen_q, dtype=torch.int32, device="cuda").unsqueeze(-1) # (N_CTX_Q, 1) + k_idx = torch.arange(seqlen_k, dtype=torch.int32, device="cuda").unsqueeze(0) # (1, N_CTX_K) + relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - k_idx) # (N_CTX_Q, N_CTX_K) + return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(-1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K) + + +@triton.jit +def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, start_m, + actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, batch_philox_offset, encoded_sm_ptrs, + block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, + IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, PADDED_HEAD: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr): + # loop over k, v, and update accumulator + for start_n in range(block_min, block_max, BLOCK_N): + # For padded blocks, we will overrun the tensor size if + # we load all BLOCK_N. For others, the blocks are all within range. + if MASK_STEPS: + k_offs_n = start_n + tl.arange(0, BLOCK_N) + else: + k_offs_n = None + k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL) + k = load_fn(k_ptrs, k_offs_k, k_offs_n, ACTUAL_BLOCK_DMODEL, actual_seqlen_k) + if PRE_LOAD_V: + # We can use the same offsets as k, just with dims transposed. + v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + # We start from end of seqlen_k so only the first iteration would need + # to be checked for padding if it is not a multiple of block_n + # TODO: This can be optimized to only be true for the padded block. + if MASK_STEPS: + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. + # last step might get wasted but that is okay. check if this masking works For + # that case. + if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): + boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) + size_n = start_n + OFFS_N[None, :] + mask = size_n < boundary_m[:, None] + qk = tl.where(mask, qk, float("-inf")) + + # -- compute qk ---- + qk += tl.dot(q, k) + + if IS_CAUSAL: + causal_boundary = start_n + offs_n_causal + causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] + qk = tl.where(causal_mask, qk, float("-inf")) + if bias_ptrs is not None: + bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None + bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, actual_seqlen_k) + # While bias is added after multiplying qk with sm_scale, + # our optimization to use 2^x instead of e^x results in an additional + # scale factor of log2(e) which we must also multiply the bias with. 
+ qk += (bias * 1.44269504089) + + if alibi_slope is not None: + # Compute the global position of each token within the sequence + global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + global_n_positions = start_n + tl.arange(0, BLOCK_N) + alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions, + global_n_positions) + qk += (alibi_block * 1.44269504089) # scale factor of log2(e) + + # softmax + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + + # CAVEAT: Must update l_ij before applying dropout + l_ij = tl.sum(p, 1) + if ENABLE_DROPOUT: + philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N + keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k) + if RETURN_ENCODED_SOFTMAX: + tl.store(encoded_sm_ptrs, tl.where(keep, p, -p)) + p = tl.where(keep, p, 0.0) + elif RETURN_ENCODED_SOFTMAX: + tl.store(encoded_sm_ptrs, p) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not PRE_LOAD_V: + v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) + # -- update m_i and l_i + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + acc += tl.dot(p.to(v.type.element_ty), v) + k_ptrs += BLOCK_N * stride_kn + v_ptrs += BLOCK_N * stride_vk + if bias_ptrs is not None: + bias_ptrs += BLOCK_N * stride_bn + if RETURN_ENCODED_SOFTMAX: + encoded_sm_ptrs += BLOCK_N + return acc, l_i, m_i + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': True}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=8), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + # TODO: This configs fails with head_size not pow2 with data mismatches. 
figure out why + # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + # triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + ], + key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], + use_cuda_graph=True, +) +@triton.jit +def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, + stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, + stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, + stride_sz, stride_sh, stride_sm, stride_sn, cu_seqlens_q, cu_seqlens_k, + dropout_p, philox_seed, philox_offset_base, encoded_softmax, alibi_slopes, HQ: tl.constexpr, + HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr): + start_m = tl.program_id(0) + off_h_q = tl.program_id(1) + off_z = tl.program_id(2) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + if VARLEN: + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + # We have a one-size-fits-all grid in id(0). Some seqlens might be too + # small for all start_m so for those we return early. + if start_m * BLOCK_M > seqlen_q: + return + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + else: + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = MAX_SEQLENS_Q + seqlen_k = MAX_SEQLENS_K + + # Now we compute whether we need to exit early due to causal masking. + # This is because for seqlen_q > seqlen_k, M rows of the attn scores + # are completely masked, resulting in 0s written to the output, and + # inf written to LSE. We don't need to do any GEMMs in this case. + # This block of code determines what N is, and if this WG is operating + # on those M rows. + n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + if (IS_CAUSAL): + # If seqlen_q == seqlen_k, the attn scores are a square matrix. + # If seqlen_q != seqlen_k, attn scores are rectangular which means + # the causal mask boundary is bottom right aligned, and ends at either + # the top edge (seqlen_q < seqlen_k) or left edge. + # This captures the decrease in n_blocks if we have a rectangular attn matrix + n_blocks_seqlen = cdiv_fn((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + # If we have no blocks after adjusting for seqlen deltas, this WG is part of + # the blocks that are all 0. We exit early. 
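+        # Rough example (illustrative only): with BLOCK_M = BLOCK_N = 64,
+        # seqlen_q = 512 and seqlen_k = 128, the program with start_m = 4 owns
+        # query rows 256..319, which may only attend to keys j <= i - 384 (all
+        # negative). n_blocks_seqlen = cdiv(320 + 128 - 512, 64) <= 0, so the
+        # branch below just writes zeros and an inf LSE and returns.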
+ if n_blocks <= 0: + o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + o_ptrs_mask = offs_m[:, None] < seqlen_q + # We still need to write 0s to the result + tl.store(o_ptrs, acc, mask=o_ptrs_mask) + # The tensor allocated for L is based on MAX_SEQLENS_Q as that is + # statically known. + l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # We store inf to LSE, not -inf because in the bwd pass, we subtract this + # from qk which makes it -inf, such that exp(qk - inf) = 0 for these masked blocks. + l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + l_ptrs_mask = offs_m < MAX_SEQLENS_Q + tl.store(l_ptrs, l, mask=l_ptrs_mask) + # TODO: Should dropout and return encoded softmax be handled here too? + return + + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE: tl.constexpr = HQ // HK + if GROUP_SIZE != 1: + off_h_k = off_h_q // GROUP_SIZE + else: + off_h_k = off_h_q + + n_extra_tokens = 0 + if seqlen_k < BLOCK_N: + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + n_extra_tokens = seqlen_k % BLOCK_N + PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) + + # Compute pointers for all the tensors used in this kernel. + q_offset = Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm + q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + k_offset = K + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn + k_ptrs = k_offset + offs_d[:, None] * stride_kk + offs_n[None, :] * stride_kn + v_offset = V + off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk + v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn + if USE_BIAS: + # Note: this might get large enough to overflow on some configs + bias_offset = off_h_q * stride_bh + bias_ptrs = bias + bias_offset + offs_m[:, None] * stride_bm + offs_n[None, :] * stride_bn + else: + bias_ptrs = None + + if USE_ALIBI: + a_offset = off_z * stride_az + off_h_q * stride_ah + alibi_slope = tl.load(alibi_slopes + a_offset) + else: + alibi_slope = None + + if ENABLE_DROPOUT: + off_hz = off_z * HQ + off_h_q + batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k + else: + batch_philox_offset = 0 + # We can ask to return the dropout mask without actually doing any dropout. In + # this case, we return an invalid pointer so indicate the mask is not valid. + if RETURN_ENCODED_SOFTMAX: + encoded_sm_offset = encoded_softmax + off_z * stride_sz + off_h_q * stride_sh + cu_seqlens_q_start * stride_sm + encoded_sm_ptrs = encoded_sm_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn + else: + encoded_sm_ptrs = None + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use 2^x in the loop as we do not + # have native e^x support in HW. + qk_scale = sm_scale * 1.44269504089 + # Q is loaded once at the beginning and shared by all N blocks. 
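+    # The mask guards both ways a Q tile can overrun real data: rows past
+    # seqlen_q in the last M block, and head-dim columns past ACTUAL_BLOCK_DMODEL
+    # when the head size is padded up to a power of two (e.g. head_size 33 runs
+    # with BLOCK_DMODEL 64, so columns 33..63 are loaded as 0.0 and contribute
+    # nothing to the dot products).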
+ q_ptrs_mask = offs_m[:, None] < seqlen_q + if PADDED_HEAD: + q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) + q = (q * qk_scale).to(q.type.element_ty) + + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked blocks. + # Additionally there might be one more due to dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block. + # In this case we might exceed n_blocks so pick the min. + masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false regardless of its actual + # value because there is no masking. Similarly we do not need padding. + if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, + start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset, + encoded_sm_ptrs, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, block_max, 0, 0, 0, alibi_slope, + # IS_CAUSAL, .... + False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, + ACTUAL_BLOCK_DMODEL) + block_min = block_max + block_max = n_blocks * BLOCK_N + + tl.debug_barrier() + # Remaining blocks, if any, are full / not masked. + if (masked_blocks > 0): + if IS_CAUSAL: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) + else: + offs_n_causal = 0 + k_ptrs += n_full_blocks * BLOCK_N * stride_kn + v_ptrs += n_full_blocks * BLOCK_N * stride_vk + if USE_BIAS: + bias_ptrs += n_full_blocks * BLOCK_N * stride_bn + if RETURN_ENCODED_SOFTMAX: + encoded_sm_ptrs += n_full_blocks * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, + start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset, + encoded_sm_ptrs, block_min, block_max, offs_n_causal, masked_blocks, + n_extra_tokens, alibi_slope, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, + offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, + ACTUAL_BLOCK_DMODEL) + # epilogue + acc = acc / l_i[:, None] + if ENABLE_DROPOUT: + acc = acc / (1 - dropout_p) + # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, + # then we have one block with a row of all NaNs which come from computing + # softmax over a row of all -infs (-inf - inf = NaN). We check for that here + # and store 0s where there are NaNs as these rows should've been zeroed out. 
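+    # Worked example (illustrative only): with BLOCK_M = 64, seqlen_q = 100 and
+    # seqlen_k = 20, causal_start_idx = 80 falls inside the block covering rows
+    # 64..127, so rows 64..79 (fully masked by the causal boundary, hence NaN
+    # after the exp2 softmax) get their accumulator overwritten with zeros below.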
+ end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + acc = acc.to(Out.type.element_ty) + if IS_CAUSAL: + if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: + out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] + z = 0.0 + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + # write back LSE + l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. + # This is only true for the last M block. For others, overflow_size will be -ve + overflow_size = end_m_idx - seqlen_q + if overflow_size > 0: + boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size, dtype=tl.int32) + l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary + tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + else: + tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + + # write back O + o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on + o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1) + if overflow_size > 0: + o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q) + if PADDED_HEAD: + o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) + + +@triton.jit +def _attn_bwd_preprocess( + Out, + DO, + Delta, + stride_oz, + stride_oh, + stride_om, + stride_on, + stride_doz, + stride_doh, + stride_dom, + stride_don, + seqlen_q, + head_dim, + BLOCK_M: tl.constexpr, + D_HEAD: tl.constexpr, +): + # off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + # off_n = tl.arange(0, D_HEAD) + off_m = tl.program_id(0) * BLOCK_M + off_h = tl.program_id(1) # head index + off_z = tl.program_id(2) # batch index + num_h = tl.num_programs(1) + o_offset = off_h * stride_oh + off_z * stride_oz + O_block_ptr = tl.make_block_ptr(base=Out + o_offset, shape=(seqlen_q, head_dim), strides=(stride_om, stride_on), + offsets=(off_m, 0), block_shape=(BLOCK_M, D_HEAD), order=(1, 0)) + do_offset = off_h * stride_doh + off_z * stride_doz + DO_block_ptr = tl.make_block_ptr(base=DO + do_offset, shape=(seqlen_q, head_dim), strides=(stride_dom, stride_don), + offsets=(off_m, 0), block_shape=(BLOCK_M, D_HEAD), order=(1, 0)) + # load + # o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + # do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + o = tl.load(O_block_ptr, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + do = tl.load(DO_block_ptr, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + # compute + delta = tl.sum(o * do, axis=1) + # write-back, shape (q.shape[0] * q.shape[1], q.shape[2]) + off_zh = off_z * num_h + off_h * 1 + # Check for OOB accesses + delta_ptrs = Delta + off_zh * seqlen_q + off_m + tl.arange(0, BLOCK_M) + overflow = off_m + BLOCK_M - seqlen_q + if overflow > 0: + boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow, dtype=tl.int32) + mask = boundary > tl.arange(0, BLOCK_M) + tl.store(delta_ptrs, delta, mask=mask) + else: + tl.store(delta_ptrs, delta) + + +@triton.jit +def _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, + # shared by Q/K/V/DO. 
+ stride_tok, stride_d, H, N_CTX, BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + # Filled in by the wrapper. + start_n, start_m, num_steps, MASK: tl.constexpr): + offs_m = start_m + tl.arange(0, BLOCK_M1) + offs_n = start_n + tl.arange(0, BLOCK_N1) + # offs_k = tl.arange(0, BLOCK_DMODEL) + QT_block_ptr = tl.make_block_ptr(base=Q, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_d, stride_tok), + offsets=(0, start_m), block_shape=(BLOCK_DMODEL, BLOCK_M1), order=(0, 1)) + DO_block_ptr = tl.make_block_ptr(base=DO, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_m, 0), block_shape=(BLOCK_M1, BLOCK_DMODEL), order=(1, 0)) + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + curr_m = start_m + step_m = BLOCK_M1 + for blk_idx in range(num_steps): + qT = tl.load(QT_block_ptr) + # Load m before computing qk to reduce pipeline stall. + offs_m = curr_m + tl.arange(0, BLOCK_M1) + m = tl.load(M + offs_m) + kqT = tl.dot(k, qT) + if alibi_slope is not None: + alibi_block = compute_alibi_block(alibi_slope, N_CTX, N_CTX, offs_m, offs_n, True) + kqT += alibi_block * 1.44269504089 + + pT = tl.math.exp2(kqT - m[None, :]) + # Autoregressive masking. + if MASK: + mask = (offs_m[None, :] >= offs_n[:, None]) + pT = tl.where(mask, pT, 0.0) + do = tl.load(DO_block_ptr) + # Compute dV. + ppT = pT + ppT = ppT.to(tl.float16) + dv += tl.dot(ppT, do) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do)) + dsT = pT * (dpT - Di[None, :]) + dsT = dsT.to(tl.float16) + dk += tl.dot(dsT, tl.trans(qT)) + # Increment pointers. + curr_m += step_m + QT_block_ptr = tl.advance(QT_block_ptr, (0, step_m)) + DO_block_ptr = tl.advance(DO_block_ptr, (step_m, 0)) + return dk, dv + + +@triton.jit +def _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, + # shared by Q/K/V/DO. + stride_tok, stride_d, H, N_CTX, BLOCK_M2: tl.constexpr, BLOCK_N2: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + # Filled in by the wrapper. + start_m, start_n, num_steps, MASK: tl.constexpr): + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_n = start_n + tl.arange(0, BLOCK_N2) + # offs_k = tl.arange(0, BLOCK_DMODEL) + KT_block_ptr = tl.make_block_ptr(base=K, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_d, stride_tok), + offsets=(0, start_n), block_shape=(BLOCK_DMODEL, BLOCK_N2), order=(0, 1)) + VT_block_ptr = tl.make_block_ptr(base=V, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_d, stride_tok), + offsets=(0, start_n), block_shape=(BLOCK_DMODEL, BLOCK_N2), order=(0, 1)) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + for blk_idx in range(num_steps): + kT = tl.load(KT_block_ptr) + qk = tl.dot(q, kT) + if alibi_slope is not None: + alibi_block = compute_alibi_block(alibi_slope, N_CTX, N_CTX, offs_m, offs_n) + qk += alibi_block * 1.44269504089 + + p = tl.math.exp2(qk - m) + # Autoregressive masking. + if MASK: + offs_n = curr_n + tl.arange(0, BLOCK_N2) + mask = (offs_m[:, None] >= offs_n[None, :]) + p = tl.where(mask, p, 0.0) + # Compute dP and dS. + vT = tl.load(VT_block_ptr) + dp = tl.dot(do, vT).to(tl.float32) + ds = p * (dp - Di[:, None]) + ds = ds.to(tl.float16) + # Compute dQ.0. + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. 
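+        # Scaling bookkeeping (a sketch based on the wrapper further below): K is
+        # passed in pre-multiplied by sm_scale / ln(2), so this loop accumulates
+        #   dq_acc = (dS @ K) * sm_scale / ln(2)
+        # and the caller's single `dq *= LN2` after the loop restores
+        #   dq = (dS @ K) * sm_scale, matching the natural-exp formulation.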
+ dq += tl.dot(ds, tl.trans(kT)) + # Increment pointers. + curr_n += step_n + KT_block_ptr = tl.advance(KT_block_ptr, (0, step_n)) + VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n)) + return dq + + +@triton.jit +def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D, + # shared by Q/K/V/DO. + stride_z, stride_h, stride_tok, stride_d, + # H = 16, N_CTX = 1024 + H, N_CTX, BLOCK_DMODEL: tl.constexpr, BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr, + BLOCK_M2: tl.constexpr, BLOCK_N2: tl.constexpr, BLK_SLICE_FACTOR: tl.constexpr, USE_ALIBI: tl.constexpr): + LN2: tl.constexpr = 0.6931471824645996 # = ln(2) + + bhid = tl.program_id(2) + off_chz = (bhid * N_CTX).to(tl.int64) + adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) + pid = tl.program_id(0) + + # offset pointers for batch/head + Q += adj + K += adj + V += adj + DO += adj + DQ += adj + DK += adj + DV += adj + M += off_chz + D += off_chz + + # offs_k = tl.arange(0, BLOCK_DMODEL) + + start_n = pid * BLOCK_N1 + # This assignment is important. It is what allows us to pick the diagonal + # blocks. Later, when we want to do the lower triangular, we update start_m + # after the first dkdv call. + start_m = start_n + + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR + # offs_n = start_n + tl.arange(0, BLOCK_N1) + + dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_n, 0), + block_shape=(BLOCK_N1, BLOCK_DMODEL), + order=(1, 0), + ) + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_n, 0), + block_shape=(BLOCK_N1, BLOCK_DMODEL), + order=(1, 0), + ) + + # load K and V: they stay in SRAM throughout the inner loop for dkdv. + k = tl.load(K_block_ptr) + v = tl.load(V_block_ptr) + + if USE_ALIBI: + a_offset = bhid + alibi_slope = tl.load(alibi_slopes + a_offset) + else: + alibi_slope = None + + # compute dK and dV for blocks close to the diagonal that need to be masked + num_steps = BLOCK_N1 // MASK_BLOCK_M1 + dk, dv = _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, stride_tok, stride_d, H, N_CTX, + MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=True) + + # compute dK and dV for blocks that don't need masking further from the diagonal + start_m += num_steps * MASK_BLOCK_M1 + num_steps = (N_CTX - start_m) // BLOCK_M1 + + dk, dv = _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, stride_tok, stride_d, H, N_CTX, + BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=False) + + DV_block_ptrs = tl.make_block_ptr(base=DV, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_n, 0), block_shape=(BLOCK_N1, BLOCK_DMODEL), order=(1, 0)) + tl.store(DV_block_ptrs, dv.to(v.dtype)) + + # Write back dK. 
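+    # q was loaded unscaled in the dk/dv loops (only K carried the sm_scale /
+    # ln(2) factor), so dk currently holds dS^T @ Q and the softmax scale is
+    # applied exactly once here at write-back.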
+ dk *= sm_scale + DK_block_ptrs = tl.make_block_ptr(base=DK, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_n, 0), block_shape=(BLOCK_N1, BLOCK_DMODEL), order=(1, 0)) + tl.store(DK_block_ptrs, dk.to(k.dtype)) + + # THIS BLOCK DOES DQ: + start_m = pid * BLOCK_M2 + end_n = start_m + BLOCK_M2 + + MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR + offs_m = start_m + tl.arange(0, BLOCK_M2) + + Q_block_ptr = tl.make_block_ptr(base=Q, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_m, 0), block_shape=(BLOCK_M2, BLOCK_DMODEL), order=(1, 0)) + + DO_block_ptr = tl.make_block_ptr(base=DO, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_m, 0), block_shape=(BLOCK_M2, BLOCK_DMODEL), order=(1, 0)) + q = tl.load(Q_block_ptr) + do = tl.load(DO_block_ptr) + dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32) + + m = tl.load(M + offs_m) + m = m[:, None] + + # Compute dQ for masked (diagonal) blocks. + # NOTE: This code scans each row of QK^T backward (from right to left, + # but inside each call to _attn_bwd_dq, from left to right), but that's + # not due to anything important. I just wanted to reuse the loop + # structure for dK & dV above as much as possible. + num_steps = BLOCK_M2 // MASK_BLOCK_N2 + dq = _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, stride_tok, stride_d, H, N_CTX, BLOCK_M2, MASK_BLOCK_N2, + BLOCK_DMODEL, start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, MASK=True) + end_n -= num_steps * MASK_BLOCK_N2 + # stage 2 + num_steps = end_n // BLOCK_N2 + dq = _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, stride_tok, stride_d, H, N_CTX, BLOCK_M2, BLOCK_N2, + BLOCK_DMODEL, start_m, end_n - num_steps * BLOCK_N2, num_steps, MASK=False) + # Write back dQ. + DQ_block_ptr = tl.make_block_ptr(base=DQ, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_m, 0), block_shape=(BLOCK_M2, BLOCK_DMODEL), order=(1, 0)) + dq *= LN2 + tl.store(DQ_block_ptr, dq.to(q.dtype)) + +def get_shape_from_layout(q, k, metadata): + if metadata.layout == 'thd': + nheads_q, nheads_k = q.shape[1], k.shape[1] + head_size = q.shape[-1] + batch = metadata.num_contexts + elif metadata.layout == 'bhsd': + batch, nheads_q, _, head_size = q.shape + nheads_k = k.shape[1] + elif metadata.layout == 'bshd': + batch, _, nheads_q, head_size = q.shape + nheads_k = k.shape[2] + else: + assert False, "Got unsupported layout." + return batch, nheads_q, nheads_k, head_size + + +# TODO: This can probably optimized to have fewer lines of code. +def get_strides_from_layout(q, k, v, o, metadata): + if metadata.layout == 'thd': + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) + o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + elif metadata.layout == 'bhsd': + q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3)) + k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3)) + v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3)) + o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3)) + elif metadata.layout == 'bshd': + q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) + k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) + v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) + o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + else: + assert False, 'Got unsupported layout.' 
+ return q_strides, k_strides, v_strides, o_strides + + +class _attention_prefill(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, o, metadata): + # NOTE: a large bias tensor leads to overflow during pointer arithmetic + if (metadata.bias is not None): + assert (metadata.bias.numel() < 2**31) + + if o is None: + o = torch.empty_like(q, dtype=v.dtype) + metadata.check_args(q, k, v, o) + + batch, nheads_q, nheads_k, head_size = get_shape_from_layout(q, k, metadata) + q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, metadata) + + # Get closest power of 2 over or equal to 32. + padded_d_model = 1 << (head_size - 1).bit_length() + # Smallest head_dim supported is 16. If smaller, the tile in the + # kernel is padded - there is no padding in memory for any dims. + padded_d_model = max(padded_d_model, 16) + + grid = lambda META: (triton.cdiv(metadata.max_seqlens_q, META['BLOCK_M']), nheads_q, batch) + + # encoded_softmax is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out + # to give a consistent starting point and then populate it with the output of softmax with the sign bit set according + # to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing + # only. This return holds no useful output aside from debugging. + if metadata.return_encoded_softmax: + encoded_softmax = torch.zeros((batch, nheads_q, metadata.max_seqlens_q, metadata.max_seqlens_k), device=q.device, + dtype=torch.float32) + softmax_strides = (encoded_softmax.stride(0), encoded_softmax.stride(1), encoded_softmax.stride(2), + encoded_softmax.stride(3)) + else: + encoded_softmax = None + softmax_strides = (0, 0 , 0 , 0) + + M = torch.empty((batch, nheads_q, metadata.max_seqlens_q), device=q.device, dtype=torch.float32) + + # Seed the RNG so we get reproducible results for testing. 
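+        # The values themselves are arbitrary; what matters is that they are
+        # fixed, so together with batch_philox_offset the dropout mask becomes a
+        # deterministic function of (batch, head, position) and can be
+        # regenerated when comparing against a reference implementation.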
+ philox_seed = 0x1BF52 + philox_offset = 0x1D4B42 + + if metadata.bias is not None: + bias_strides = (metadata.bias.stride(0), metadata.bias.stride(1), metadata.bias.stride(2), + metadata.bias.stride(3)) + else: + bias_strides = (0, 0, 0, 0) + + if metadata.alibi_slopes is not None: + alibi_strides = (metadata.alibi_slopes.stride(0), metadata.alibi_slopes.stride(1)) + else: + alibi_strides = (0, 0) + + attn_fwd[grid](q, k, v, metadata.bias, metadata.sm_scale, M, o, *q_strides, *k_strides, *v_strides, *o_strides, + *bias_strides, *alibi_strides, *softmax_strides, metadata.cu_seqlens_q, metadata.cu_seqlens_k, + dropout_p=metadata.dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, + encoded_softmax=encoded_softmax, alibi_slopes=metadata.alibi_slopes, HQ=nheads_q, HK=nheads_k, + ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=metadata.max_seqlens_q, + MAX_SEQLENS_K=metadata.max_seqlens_k, IS_CAUSAL=metadata.causal, VARLEN=metadata.varlen, + BLOCK_DMODEL=padded_d_model, USE_BIAS=False if metadata.bias is None else True, + USE_ALIBI=False if metadata.alibi_slopes is None else True, ENABLE_DROPOUT=metadata.dropout_p + > 0.0, RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax) + + ctx.save_for_backward(q, k, v, o, M) + ctx.grid = grid + ctx.sm_scale = metadata.sm_scale + ctx.BLOCK_DMODEL = head_size + ctx.causal = metadata.causal + ctx.alibi_slopes = metadata.alibi_slopes + ctx.dropout_p = metadata.dropout_p + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.encoded_softmax = encoded_softmax + ctx.return_encoded_softmax = metadata.return_encoded_softmax + return o, M, encoded_softmax + + @staticmethod + def backward(ctx, do, _): # expects bhsd + if torch.version.hip is not None: + BLOCK = 64 + else: + BLOCK = 128 + q, k, v, o, M = ctx.saved_tensors + assert do.is_contiguous() + assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() + seqlen_q = q.shape[2] + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + BATCH, N_HEAD, N_CTX = q.shape[:3] + PRE_BLOCK = 128 + # NUM_WARPS, NUM_STAGES = 4, 1 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32 + BLK_SLICE_FACTOR = 2 + RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) + arg_k = k + arg_k = arg_k * (ctx.sm_scale * RCP_LN2) + assert N_CTX % PRE_BLOCK == 0 + delta = torch.empty_like(M) + _, Lk, _ = q.shape[-1], k.shape[-1], v.shape[-1] + # padded_head = (Lk != ctx.BLOCK_DMODEL) + grid_preprocess = (triton.cdiv(do.shape[2], BLOCK), do.shape[1], do.shape[0]) + + _attn_bwd_preprocess[grid_preprocess]( + o, + do, + delta, + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), + do.stride(0), + do.stride(1), + do.stride(2), + do.stride(3), + seqlen_q, + head_dim=Lk, + BLOCK_M=BLOCK, + D_HEAD=ctx.BLOCK_DMODEL, + ) + grid = lambda META: (triton.cdiv(N_CTX, META['BLOCK_N1']), 1, BATCH * N_HEAD) + + _attn_bwd[grid]( + q, + arg_k, + v, + ctx.sm_scale, + ctx.alibi_slopes, + do, + dq, + dk, + dv, + M, + delta, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + N_HEAD, + N_CTX, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, + BLOCK_M1=BLOCK_M1, + BLOCK_N1=BLOCK_N1, + BLOCK_M2=BLOCK_M2, + BLOCK_N2=BLOCK_N2, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + USE_ALIBI=False if ctx.alibi_slopes is None else True, + ) + + return dq, dk, dv, M, None + + +attention_prefill = _attention_prefill.apply + + +def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout): + torch.manual_seed(20) + + # Initialize q, k, v + if layout == 'bhsd': + q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD) 
+ k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD) + elif layout == 'bshd': + q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD) + k_tensor_shape = (Z, N_CTX_K, HK, D_HEAD) + else: + assert False, 'Got unsupported tensor layout' + q = torch.randn(q_tensor_shape, dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn(k_tensor_shape, dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn(k_tensor_shape, dtype=dtype, device="cuda", requires_grad=True) + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.max_seqlens_q = N_CTX_Q + input_metadata.max_seqlens_k = N_CTX_K + input_metadata.layout = layout + return q, k, v, input_metadata + + +def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, equal_seqlens=False): + torch.manual_seed(20) + + # Random sequence lengths. Using N_CTX as kind of max of sum of individual seqs + if not equal_seqlens: + max_seqlens_q = N_CTX_Q // Z + max_seqlens_k = N_CTX_K // Z + seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z, ), dtype=torch.int32) + seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z, ), dtype=torch.int32) + else: + seqlens_q = torch.full((Z, ), N_CTX_Q // Z) + seqlens_k = torch.full((Z, ), N_CTX_K // Z) + + # Calculate cumulative sequence lengths + cu_seqlens_q = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_q.cumsum(dim=0, dtype=torch.int32)]) + cu_seqlens_k = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_k.cumsum(dim=0, dtype=torch.int32)]) + cu_seqlens_q = cu_seqlens_q.to(device="cuda") + cu_seqlens_k = cu_seqlens_k.to(device="cuda") + + # Initialize q, k, v with variable lengths + total_q = cu_seqlens_q[-1].item() + total_k = cu_seqlens_k[-1].item() + q = torch.randn((total_q, HQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + k = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + v = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) + return q, k, v, input_metadata + + +@pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [ + (4, 48, 24, 1024, 1024, 64), + (1, 24, 6, 8192, 8192, 64), + (1, 4, 2, 16384, 16384, 128), + (2, 16, 4, 1020, 987, 128), + (2, 16, 4, 15498, 2, 128), + (2, 16, 2, 7, 16219, 64), + (4, 48, 12, 1, 1, 64), + (4, 48, 48, 1, 1, 128), + (4, 48, 24, 3, 3, 128), + (4, 48, 48, 1001, 990, 64), + (1, 8, 8, 8081, 7099, 64), + (1, 4, 4, 16330, 15989, 128), + (4, 4, 1, 1024, 1024, 33), + (4, 4, 2, 65, 1018, 65), + (4, 4, 4, 128, 128, 65), + (4, 4, 4, 113, 123, 1), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('use_alibi', [True, False]) +@pytest.mark.parametrize('layout', ['bshd', 'bhsd']) +def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout, dtype=torch.float16): + torch.manual_seed(20) + q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout) + if causal: + input_metadata.need_causal() + + if use_alibi: + # for n heads the set of slopes is the geometric sequence that starts 2^(-8/n) + alibi_slopes = torch.tensor([2**(-8 / HQ * i) for i in range(1, HQ + 1)], dtype=torch.float32, + device="cuda").repeat(Z, 1) + input_metadata.need_alibi(alibi_slopes, Z, HQ) + else: + alibi_slopes = None + + o = torch.empty_like(q) + + # triton implementation + tri_out, _, _ = 
attention_prefill(q, k, v, o, input_metadata) + + # Transpose here if layout is bshd so we have same reference code for all layouts + if layout == 'bshd': + q = q.transpose(1, 2).clone() + k = k.transpose(1, 2).clone() + v = v.transpose(1, 2).clone() + # Replicate K and V if using MQA/GQA + if HQ != HK: + k = k.view(k.shape[0], k.shape[1], -1, k.shape[2], + k.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(k.shape[0], -1, k.shape[2], k.shape[3]) + v = v.view(v.shape[0], v.shape[1], -1, v.shape[2], + v.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(v.shape[0], -1, v.shape[2], v.shape[3]) + + scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * input_metadata.sm_scale + if causal: + mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q) + scores[:, :, mask == 0] = float("-inf") + if use_alibi: + scores += compute_alibi_tensor(alibi_slopes, N_CTX_Q, N_CTX_K) + + p = torch.softmax(scores, dim=-1) + if causal: + # If N_CTX_Q > N_CTX_K, there is at least one row of all -infs going into + # the softmax. This produces a row of NaNs as -inf - -inf == NaN. So we fix + # this by converting the NaNs to 0s, which is what they should be out of the softmax. + nan_mask = torch.isnan(p) + p[nan_mask == 1] = 0 + ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v) + # compare + if layout == 'bshd': + ref_out = ref_out.transpose(1, 2).clone() + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + + +@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ + (4, 48, 1024, 1024, 64), + (4, 12, 8192, 8192, 64), + (2, 4, 16384, 16384, 128), + (2, 16, 1020, 987, 128), + (2, 16, 15498, 2, 128), + (2, 4, 7, 16219, 64), + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 48, 1001, 990, 64), + (1, 8, 8081, 7099, 64), + (1, 8, 16330, 15989, 128), + (4, 4, 1024, 1024, 33), + (4, 4, 65, 1019, 65), + (4, 4, 128, 128, 65), + # TODO: This config fails. Disabled until triaged and fixed. + # (4, 4, 113, 123, 1), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('use_bias', [True]) +def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=torch.float16): + torch.manual_seed(20) + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + q, k, v, input_metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout='bhsd') + if causal: + input_metadata.need_causal() + if use_bias: + bias = torch.randn((1, H, N_CTX_Q, N_CTX_K), dtype=torch.float32, device="cuda") + input_metadata.need_bias(bias, Z, H, N_CTX_Q, N_CTX_K) + else: + bias = None + o = torch.empty_like(q) + + # triton implementation + tri_out, _, _ = attention_prefill(q, k, v, o, input_metadata) + # reference implementation:171 + + scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * sm_scale + if causal: + mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q) + scores[:, :, mask == 0] = float("-inf") + if use_bias: + scores += input_metadata.bias + p = torch.softmax(scores, dim=-1) + if causal: + # If N_CTX_Q > N_CTX_K, there is at least one row of all -infs going into + # the softmax. This produces a row of NaNs as -inf - -inf == NaN. So we fix + # this by converting the NaNs to 0s, which is what they should be out of the softmax. 
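+        # e.g. for the (2, 16, 4, 15498, 2, 128) config above, query rows
+        # 0..15495 have an all-zero causal mask, so their softmax rows come out
+        # as NaN and are zeroed here to match the kernel's handling of fully
+        # masked rows.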
+ nan_mask = torch.isnan(p) + p[nan_mask == 1] = 0 + ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v) + # compare + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + + +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 8192, 64), (4, 48, 256, 64), (4, 48, 512, 64), + (4, 48, 1024, 64), (8, 48, 4096, 64), (4, 48, 8192, 64), + (4, 48, 128, 128), (4, 48, 4096, 128), (4, 48, 16384, 128), + (4, 16, 1024, 128), (4, 16, 8192, 128), (32, 48, 8192, 128)]) +@pytest.mark.parametrize('causal', [True, False]) +def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): + + q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX, D_HEAD, dtype) + + tri_out = torch.empty_like(q) + ref_out = torch.empty_like(q) + + for i in range(0, input_metadata.num_contexts): + start_q, start_k = input_metadata.cu_seqlens_q[i], input_metadata.cu_seqlens_k[i] + end_q, end_k = input_metadata.cu_seqlens_q[i + 1], input_metadata.cu_seqlens_k[i + 1] + scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q], k[start_k:end_k]).float() + p = torch.softmax(scores * input_metadata.sm_scale, dim=-1).half() + ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v[start_k:end_k]) + attention_prefill(q, k, v, tri_out, input_metadata) + torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize('Z, HQ, HK, N_CTX, D_HEAD', [(2, 48, 24, 128, 64), (4, 48, 12, 256, 64), (4, 48, 4, 512, 64), + (4, 48, 2, 1024, 64), (8, 48, 6, 4096, 64), (4, 48, 8, 16384, 64), + (4, 64, 16, 128, 128), (4, 64, 4, 4096, 128), + (4, 64, 8, 16384, 128), (4, 16, 4, 1024, 128), + (4, 16, 2, 8192, 128), (32, 128, 32, 8192, 128)]) +@pytest.mark.parametrize('causal', [False]) +def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16): + q, k, v, input_metadata = varlen_input_helper(Z, HQ, HK, N_CTX, N_CTX, D_HEAD, dtype) + ref_out = torch.empty_like(q) + tri_out = torch.empty_like(q) + # Make KV look like HQ/HK "groups" of HK. Later, we will reshape so the + # size aligns with Q. 
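+    # Illustrative shapes for the first config above (HQ=48, HK=24): k is
+    # (total_k, 24, D_HEAD); view/expand makes it (total_k, 24, 2, D_HEAD) without
+    # copying, and the per-context reshape flattens it to (seqlen_k, 48, D_HEAD),
+    # so consecutive query heads share one KV head, mirroring the kernel's
+    # off_h_k = off_h_q // GROUP_SIZE mapping.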
+ k_ref = k.view(k.shape[0], k.shape[1], 1, k.shape[2]).expand(-1, -1, HQ // HK, -1) + v_ref = v.view(v.shape[0], v.shape[1], 1, v.shape[2]).expand(-1, -1, HQ // HK, -1) + for i in range(0, input_metadata.num_contexts): + start_q, start_k = input_metadata.cu_seqlens_q[i], input_metadata.cu_seqlens_k[i] + end_q, end_k = input_metadata.cu_seqlens_q[i + 1], input_metadata.cu_seqlens_k[i + 1] + k_curr = k_ref[start_k:end_k] + k_curr = k_curr.reshape(k_curr.shape[0], -1, k_curr.shape[3]) + v_curr = v_ref[start_k:end_k] + v_curr = v_curr.reshape(v_curr.shape[0], -1, v_curr.shape[3]) + scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q], k_curr).float() + p = torch.softmax(scores * input_metadata.sm_scale, dim=-1).half() + ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v_curr) + attention_prefill(q, k, v, tri_out, input_metadata) + torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ + # (1, 1, 1, 16) + (4, 48, 1024, 64), + (4, 48, 2048, 64), + (2, 48, 4096, 64), + (1, 16, 1024, 64), + (1, 16, 1024, 128), + #(1, 16, 8192, 63), + #(1, 16, 1022, 64), +]) +@pytest.mark.parametrize('qseqlen_not_equal_kseqlen', [None]) +@pytest.mark.parametrize('torch_sdpa_test', [False, True]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('use_alibi', [False, True]) +def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sdpa_test, use_alibi, + dtype=torch.float16): + pytest.skip("Prefill Backward Kernel is broken") + torch.manual_seed(20) + if qseqlen_not_equal_kseqlen is not None: + seqlen_q = qseqlen_not_equal_kseqlen + else: + seqlen_q = N_CTX + seqlen_k = N_CTX + + if causal and ((N_CTX - 1) & N_CTX): + pytest.skip() + if causal and seqlen_q != seqlen_k: + pytest.skip() + + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.max_seqlens_q = seqlen_q + input_metadata.max_seqlens_k = seqlen_k + input_metadata.layout = "bhsd" + + dropout_p = 0 + q = (torch.empty((Z, H, seqlen_q, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + k = (torch.empty((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + v = (torch.empty((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + o = torch.empty_like(q) + + if causal: + input_metadata.need_causal() + + if use_alibi and not torch_sdpa_test: + # for n heads the set of slopes is the geometric sequence that starts 2^(-8/n) + alibi_slopes = torch.tensor([2**(-8 / H * i) for i in range(1, H + 1)], dtype=torch.float32, + device="cuda").repeat(Z, 1) + input_metadata.need_alibi(alibi_slopes, Z, H) + dout = torch.randn_like(q) + + # reference implementation + if torch_sdpa_test: + ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q, k, v, dropout_p=dropout_p, + is_causal=causal, scale=sm_scale, + dropout_mask=None) + ref_out.backward(dout.to(device=ref_out.device, dtype=ref_out.dtype)) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + else: + M = torch.tril(torch.ones((seqlen_q, seqlen_k), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + if use_alibi: + p += compute_alibi_tensor(alibi_slopes, N_CTX, N_CTX) + if causal: + p[:, :, M == 0] = float("-inf") + + p = torch.softmax(p.float(), dim=-1).type(dtype=p.dtype) + ref_out = torch.matmul(p, v) + ref_out.backward(dout) + ref_dv, 
v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + + # # triton implementation + tri_out, _, _ = attention_prefill(q, k, v, o, input_metadata) + tri_out.backward(dout) + tri_dv, v.grad = v.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dq, q.grad = q.grad.clone(), None + # compare + if DEBUG: + print("tri_out:", tri_out) + print("ref_out:",ref_out ) + torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0) + # The current block size for MI200 series is 64x64. This results in + # larger differences in float results due to rounding. + + if dtype == torch.bfloat16: + ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0) + if dtype == torch.float32: + ATOL = 1e-3 * max(1.0, (seqlen_q + D_HEAD) / 64.0) + else: + ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0) + + RTOL = 0 + + torch.testing.assert_close(ref_dv, tri_dv, atol=ATOL, rtol=RTOL) + torch.testing.assert_close(ref_dk, tri_dk, atol=ATOL, rtol=RTOL) + torch.testing.assert_close(ref_dq, tri_dq, atol=ATOL, rtol=RTOL) + + +def nonvarlen_benchmark_configs(): + configs = [ + (16, 16, 16, 1024, 1024), + (8, 16, 16, 2048, 2048), + (4, 16, 16, 4096, 4096), + (2, 16, 16, 8192, 8192), + (1, 16, 16, 16384, 16384), + (2, 48, 48, 1024, 1024), + (2, 48, 48, 2048, 1024), + (2, 48, 48, 4096, 8192), + (2, 48, 48, 8192, 4096), + (2, 48, 48, 16384, 8192), + (8, 16, 16, 1989, 15344), + (4, 16, 16, 4097, 163), + (2, 16, 16, 8122, 2159), + (1, 16, 16, 16281, 7), + (2, 48, 48, 1021, 1020), + (2, 48, 48, 2001, 2048), + (2, 48, 48, 3996, 9639), + (2, 48, 48, 8181, 1021), + ] + return configs + + +def varlen_benchmark_configs(): + configs = [ + (2, 16, 4, 1024, 1024), + (8, 16, 2, 2048, 2048), + (4, 16, 8, 4096, 4096), + (2, 16, 4, 8192, 8192), + (2, 16, 8, 16384, 16384), + (2, 48, 12, 1024, 1024), + (2, 48, 24, 2048, 2048), + (2, 48, 8, 4096, 4096), + (2, 48, 4, 8192, 8192), + (2, 48, 2, 16384, 16384), + (2, 64, 32, 1024, 1024), + (4, 64, 16, 2048, 2048), + (4, 64, 8, 4096, 4096), + (4, 64, 32, 8192, 8192), + (4, 128, 16, 16384, 16384), + ] + return configs + + +def run_benchmark(custom, args): + + dtype = arg_to_torch_dtype[args.dtype] + hk = args.hq if not args.hk else args.hk + sk = args.sq if not args.sk else args.sk + head_size = 128 if not args.d else args.d + mode = 'fwd' + x_names = ['BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K'] + causal = args.causal + varlen = args.layout == 'thd' + configs = [] + if custom: + x_vals_list = [(args.b, args.hq, hk, args.sq, sk)] + else: + if varlen: + x_vals_list = varlen_benchmark_configs() + else: + x_vals_list = nonvarlen_benchmark_configs() + print_time = args.return_time + line_names = 'Time (ms)' if print_time else 'TFLOPS' + configs.append( + triton.testing.Benchmark(x_names=x_names, x_vals=x_vals_list, line_arg='provider', line_vals=['triton'], + line_names=[line_names], styles=[('red', '-')], ylabel='ms', + plot_name=f'fused-attention-{mode}-d{head_size}-layout{args.layout}', + args={'D_HEAD': head_size, 'dtype': dtype, 'causal': causal, 'mode': mode})) + + @triton.testing.perf_report(configs) + def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal, mode, provider, device="cuda"): + assert mode in ["fwd", "bwd"] + warmup = 25 + rep = 100 + # TODO: Enable bias after testing. 
+ # if use_bias: + # bias = torch.randn((1, H, N_CTX, N_CTX), dtype=torch.float32, device="cuda") + # input_metadata.need_bias(bias, BATCH, H, N_CTX, N_CTX) + # else: + # bias = None + # bias = None + + # Bwd pass only supports causal=True right now + if mode == 'bwd': + causal = True + + flops_per_matmul = 0 + if varlen: + q, k, v, input_metadata = varlen_input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, + args.equal_seqlens) + for i in range(0, input_metadata.num_contexts): + seqlen_q = input_metadata.cu_seqlens_q[i + 1] - input_metadata.cu_seqlens_q[i] + seqlen_k = input_metadata.cu_seqlens_k[i + 1] - input_metadata.cu_seqlens_k[i] + # x2 for 2 GEMMs + flops_per_matmul += seqlen_q.item() * seqlen_k.item() * HQ * D_HEAD * 2 + else: + q, k, v, input_metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, args.layout) + flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD + if causal: + input_metadata.need_causal() + o = torch.empty_like(q) + fn = lambda: attention_prefill(q, k, v, o, input_metadata) + if mode == 'bwd': + o, _, _= fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + total_flops = 2 * flops_per_matmul + # TODO: This needs to be fixed for unequal Q/K seqlens + if causal: + total_flops *= 0.5 + if mode == "bwd": + total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) + if print_time: + return ms + else: + return total_flops / ms * 1e-9 + + bench_flash_attention.run(save_path=".", print_data=True) + + +def supported_layouts(): + layouts = \ + 'bhsd: Q, K, V are individual tensors of [batch, num_heads, seqlen_q/k, head_size]' \ + 'bshd: Q, K, V are individual tensors of [batch, seqlen_q/k, num_heads, head_size]' \ + 'thd: Q, K, V are individual tensors of [total_q/k, num_heads, head_size]' \ + 'This layout is sometimes called "varlen" or "grouped" layout.' + return layouts + + +def parse_args(): + parser = argparse.ArgumentParser( + prog="Benchmark FlashAttention", + allow_abbrev=False, + ) + parser.add_argument("-b", type=int, default=0) + parser.add_argument("-hq", type=int, default=0) + parser.add_argument("-hk", type=int, default=0) + parser.add_argument("-sq", type=int, default=0) + parser.add_argument("-sk", type=int, default=0) + parser.add_argument("-equal_seqlens", action='store_true', default=False, + help='If specified, each context within the thd layout' \ + ' has same seqlen as sq and sk') + parser.add_argument("-d", type=int, default=0) + parser.add_argument("-causal", action='store_true', default=False) + parser.add_argument("-dtype", default='fp16') + parser.add_argument("-return_time", action='store_true', default=False) + parser.add_argument("-layout", type=str, default='bhsd', help=supported_layouts()) + return parser.parse_args() + + +arg_to_torch_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp32': torch.float32} + + +def main(): + args = parse_args() + custom_config = False + assert args.layout == 'thd' or not args.equal_seqlens, \ + "Equal sequence lengths arg must be used with the thd layout." + if args.b or args.hq or args.hk or args.sq or args.sk or args.d: + custom_config = True + assert args.b and args.hq and args.sq and args.d, \ + "If custom config is specified, please provide \ + all of batch, number of Q heads, Q sequence length \ + and head size." + + assert args.dtype in arg_to_torch_dtype, \ + "Only fp16, bf16 and f32 types currently supported." 
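+    # Example invocations (illustrative; the script path depends on your checkout):
+    #   python <this_file>.py -b 2 -hq 16 -sq 4096 -d 128 -causal
+    #   python <this_file>.py -layout thd -return_time
+    # The first benchmarks a single custom config; the second sweeps the built-in
+    # varlen configs and reports time in ms instead of TFLOPS.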
+ + run_benchmark(custom_config, args) + + +if __name__ == '__main__': + sys.exit(main()) \ No newline at end of file diff --git a/flash_attn/ops/triton/rotary.py b/flash_attn/ops/triton/rotary.py old mode 100644 new mode 100755 diff --git a/setup.py b/setup.py index fd67f645b..4228fc01d 100644 --- a/setup.py +++ b/setup.py @@ -61,7 +61,7 @@ SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE" # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE" - +USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_USE_TRITON_ROCM", "FALSE") == "TRUE" def get_platform(): """ @@ -313,81 +313,85 @@ def validate_and_update_archs(archs): TORCH_MAJOR = int(torch.__version__.split(".")[0]) TORCH_MINOR = int(torch.__version__.split(".")[1]) - # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h - # See https://github.com/pytorch/pytorch/pull/70650 - generator_flag = [] - torch_dir = torch.__path__[0] - if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): - generator_flag = ["-DOLD_GENERATOR_PATH"] - - check_if_rocm_home_none("flash_attn") - cc_flag = [] - - archs = os.getenv("GPU_ARCHS", "native").split(";") - validate_and_update_archs(archs) - - cc_flag = [f"--offload-arch={arch}" for arch in archs] - - # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as - # torch._C._GLIBCXX_USE_CXX11_ABI - # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 - if FORCE_CXX11_ABI: - torch._C._GLIBCXX_USE_CXX11_ABI = True - - sources = ["csrc/flash_attn_ck/flash_api.cpp", - "csrc/flash_attn_ck/mha_bwd.cpp", - "csrc/flash_attn_ck/mha_fwd.cpp", - "csrc/flash_attn_ck/mha_varlen_bwd.cpp", - "csrc/flash_attn_ck/mha_varlen_fwd.cpp"] + glob.glob( - f"build/fmha_*wd*.cpp" - ) - - rename_cpp_to_cu(sources) - - renamed_sources = ["csrc/flash_attn_ck/flash_api.cu", - "csrc/flash_attn_ck/mha_bwd.cu", - "csrc/flash_attn_ck/mha_fwd.cu", - "csrc/flash_attn_ck/mha_varlen_bwd.cu", - "csrc/flash_attn_ck/mha_varlen_fwd.cu"] + glob.glob(f"build/fmha_*wd*.cu") - extra_compile_args = { - "cxx": ["-O3", "-std=c++17"] + generator_flag, - "nvcc": - [ - "-O3","-std=c++17", - "-mllvm", "-enable-post-misched=0", - "-DCK_TILE_FMHA_FWD_FAST_EXP2=1", - "-fgpu-flush-denormals-to-zero", - "-DCK_ENABLE_BF16", - "-DCK_ENABLE_BF8", - "-DCK_ENABLE_FP16", - "-DCK_ENABLE_FP32", - "-DCK_ENABLE_FP64", - "-DCK_ENABLE_FP8", - "-DCK_ENABLE_INT8", - "-DCK_USE_XDL", - "-DUSE_PROF_API=1", - "-D__HIP_PLATFORM_HCC__=1", - # "-DFLASHATTENTION_DISABLE_BACKWARD", - ] - + generator_flag - + cc_flag - , - } - - include_dirs = [ - Path(this_dir) / "csrc" / "composable_kernel" / "include", - Path(this_dir) / "csrc" / "composable_kernel" / "library" / "include", - Path(this_dir) / "csrc" / "composable_kernel" / "example" / "ck_tile" / "01_fmha", - ] + if USE_TRITON_ROCM: + # Skip C++ extension compilation if using Triton Backend + pass + else: + # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h + # See https://github.com/pytorch/pytorch/pull/70650 + generator_flag = [] + torch_dir = torch.__path__[0] + if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): + generator_flag = ["-DOLD_GENERATOR_PATH"] + + check_if_rocm_home_none("flash_attn") + cc_flag = [] + + archs = os.getenv("GPU_ARCHS", "native").split(";") + 
validate_and_update_archs(archs) + + cc_flag = [f"--offload-arch={arch}" for arch in archs] + + # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as + # torch._C._GLIBCXX_USE_CXX11_ABI + # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 + if FORCE_CXX11_ABI: + torch._C._GLIBCXX_USE_CXX11_ABI = True + + sources = ["csrc/flash_attn_ck/flash_api.cpp", + "csrc/flash_attn_ck/mha_bwd.cpp", + "csrc/flash_attn_ck/mha_fwd.cpp", + "csrc/flash_attn_ck/mha_varlen_bwd.cpp", + "csrc/flash_attn_ck/mha_varlen_fwd.cpp"] + glob.glob( + f"build/fmha_*wd*.cpp" + ) - ext_modules.append( - CUDAExtension( - name="flash_attn_2_cuda", - sources=renamed_sources, - extra_compile_args=extra_compile_args, - include_dirs=include_dirs, + rename_cpp_to_cu(sources) + + renamed_sources = ["csrc/flash_attn_ck/flash_api.cu", + "csrc/flash_attn_ck/mha_bwd.cu", + "csrc/flash_attn_ck/mha_fwd.cu", + "csrc/flash_attn_ck/mha_varlen_bwd.cu", + "csrc/flash_attn_ck/mha_varlen_fwd.cu"] + glob.glob(f"build/fmha_*wd*.cu") + extra_compile_args = { + "cxx": ["-O3", "-std=c++17"] + generator_flag, + "nvcc": + [ + "-O3","-std=c++17", + "-mllvm", "-enable-post-misched=0", + "-DCK_TILE_FMHA_FWD_FAST_EXP2=1", + "-fgpu-flush-denormals-to-zero", + "-DCK_ENABLE_BF16", + "-DCK_ENABLE_BF8", + "-DCK_ENABLE_FP16", + "-DCK_ENABLE_FP32", + "-DCK_ENABLE_FP64", + "-DCK_ENABLE_FP8", + "-DCK_ENABLE_INT8", + "-DCK_USE_XDL", + "-DUSE_PROF_API=1", + "-D__HIP_PLATFORM_HCC__=1", + # "-DFLASHATTENTION_DISABLE_BACKWARD", + ] + + generator_flag + + cc_flag + , + } + + include_dirs = [ + Path(this_dir) / "csrc" / "composable_kernel" / "include", + Path(this_dir) / "csrc" / "composable_kernel" / "library" / "include", + Path(this_dir) / "csrc" / "composable_kernel" / "example" / "ck_tile" / "01_fmha", + ] + + ext_modules.append( + CUDAExtension( + name="flash_attn_2_cuda", + sources=renamed_sources, + extra_compile_args=extra_compile_args, + include_dirs=include_dirs, + ) ) - ) def get_package_version(): diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py old mode 100644 new mode 100755 index 72d55134e..070308070 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -1,4 +1,6 @@ import math +import os +import random import pytest import torch @@ -17,6 +19,17 @@ from flash_attn.flash_attn_interface import _get_block_size_n from flash_attn.layers.rotary import apply_rotary_emb + +# Test ROCM Triton Backend +USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_USE_TRITON_ROCM", "FALSE") == "TRUE" +if USE_TRITON_ROCM: + random.seed(42) + +def skip_config(**kwargs): + if 'd' in kwargs: + return random.random() < 0.20 + return False + MAX_HEADDIM_SM8x = 192 @@ -584,6 +597,18 @@ def get_dropout_fraction( @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) # @pytest.mark.parametrize("dropout_p", [0.0]) def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype): + if USE_TRITON_ROCM: + test_backward = False + + if dropout_p != 0.0: + pytest.skip("Dropout not supported in AMD's Triton Backend yet") + + if local == True: + pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet") + + if skip_config(seqlen=seqlen, d=d): + pytest.skip("Skipping configuration due to limited test time") + if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30: pytest.skip() # Reference implementation OOM device = "cuda" @@ -686,7 +711,8 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, 
causal, local, alibi, determ # do_o = (g.float() * out.float()).sum(-1) # dv_tmp = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, :64], g[:, :64]) # dv_tmp1 = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, 64:], g[:, 64:]) - if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): + test_backward = test_backward and ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) + if test_backward: (dqkv,) = torch.autograd.grad(out, qkv, g) (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g) (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g) @@ -709,7 +735,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ if not alibi: assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) - if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): + if test_backward: assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() @@ -733,6 +759,18 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ def test_flash_attn_varlen_qkvpacked( seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype ): + if USE_TRITON_ROCM: + test_backward = False + + if dropout_p != 0.0: + pytest.skip("Dropout not supported in AMD's Triton Backend yet") + + if local == True: + pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet") + + if skip_config(seqlen=seqlen, d=d): + pytest.skip("Skipping configuration due to limited test time") + if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30: pytest.skip() # Reference implementation OOM device = "cuda" @@ -833,7 +871,8 @@ def test_flash_attn_varlen_qkvpacked( print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}") g = torch.randn_like(out) - if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): + test_backward = test_backward and ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) + if test_backward: (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g) dqkv = dqkv_pad_fn(dqkv_unpad) (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g) @@ -857,7 +896,7 @@ def test_flash_attn_varlen_qkvpacked( if not alibi: assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) - if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): + if test_backward: assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() @@ -903,6 +942,21 @@ def test_flash_attn_varlen_qkvpacked( def test_flash_attn_output( seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap ): + if USE_TRITON_ROCM: + test_backward = False + + if dropout_p != 0.0: + pytest.skip("Dropout not supported on AMD's Triton Backend yet") + + if softcap != 0.0: + pytest.skip("softcap not supported on AMD's Triton Backend yet") + + if local == True: + pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet") + + if skip_config(seqlen_q=seqlen_q, seqlen_k=seqlen_k, d=d): + pytest.skip("Skipping configuration due to limited test time") + if ( max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 @@ -1070,7 +1124,8 @@ def test_flash_attn_output( g = torch.randn_like(out) do_o = (g.float() * out.float()).sum(-1) - if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): + test_backward = test_backward and ((d <= MAX_HEADDIM_SM8x or dropout_p == 0) or 
(is_sm80 or is_sm90)) + if test_backward: if kvpacked: ( dq, @@ -1126,7 +1181,7 @@ def test_flash_attn_output( if not alibi: assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) - if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): + if test_backward: assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() @@ -1172,6 +1227,22 @@ def test_flash_attn_output( def test_flash_attn_varlen_output( seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap ): + + if USE_TRITON_ROCM: + test_backward = False + + if dropout_p != 0.0: + pytest.skip("Dropout not supported in AMD's Triton Backend yet") + + if local == True: + pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet") + + if softcap != 0.0: + pytest.skip("softcap not supported on AMD's Triton Backend yet") + + if skip_config(seqlen_q=seqlen_q, seqlen_k=seqlen_k, d=d): + pytest.skip("Skipping configuration due to limited test time") + if ( max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 @@ -1386,7 +1457,8 @@ def test_flash_attn_varlen_output( print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}") g = torch.randn_like(out) - if ((d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90)): + test_backward = test_backward and ((d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90)) + if test_backward: if kvpacked: ( dq_unpad, @@ -1445,7 +1517,7 @@ def test_flash_attn_varlen_output( if not alibi: assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04) - if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): + if test_backward: assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() @@ -1480,6 +1552,15 @@ def test_flash_attn_varlen_output( ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): + if USE_TRITON_ROCM: + test_backward = False + + if local == True: + pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet") + + if skip_config(seqlen_q=seqlen_q, seqlen_k=seqlen_k, d=d): + pytest.skip("Skipping configuration due to limited test time") + if ( max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 @@ -1523,41 +1604,44 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): g = torch.randn_like(out) do_o = (g.float() * out.float()).sum(-1) - ( - dq, - dk, - dv, - ) = torch.autograd.grad(out, (q, k, v), g) - ( - dq_ref, - dk_ref, - dv_ref, - ) = torch.autograd.grad(out_ref, (q, k, v), g) - ( - dq_pt, - dk_pt, - dv_pt, - ) = torch.autograd.grad(out_pt, (q, k, v), g) - print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") - print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") - print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") - print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") - print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") - print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") - 
print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") - print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") - print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") - print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") - print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") - print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + test_backward = test_backward and ((d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90)) + if test_backward: + ( + dq, + dk, + dv, + ) = torch.autograd.grad(out, (q, k, v), g) + ( + dq_ref, + dk_ref, + dv_ref, + ) = torch.autograd.grad(out_ref, (q, k, v), g) + ( + dq_pt, + dk_pt, + dv_pt, + ) = torch.autograd.grad(out_pt, (q, k, v), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5 - assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5 - assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5 - assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 + if test_backward: + assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5 + assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5 + assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @@ -1593,6 +1677,21 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): def test_flash_attn_varlen_causal( seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype ): + if USE_TRITON_ROCM: + test_backward = False + + if local == True: + pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet") + + if paged_kv_block_size is not None: + pytest.skip("paged attention not supported on AMD's Triton Backend yet") + + if seqlen_q * seqlen_k >= 256 * 512: + pytest.skip(f"{seqlen_q}, {seqlen_k} leads to out of memory on AMD") + + if skip_config(seqlen_q=seqlen_q, seqlen_k=seqlen_k, d=d): + pytest.skip("Skipping configuration due to limited test time") + if ( max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 @@ -1686,7 +1785,7 @@ def test_flash_attn_varlen_causal( g = torch.randn_like(out) do_o = (g.float() * out.float()).sum(-1) - test_backward = block_table is None + test_backward = test_backward and ((d <= MAX_HEADDIM_SM8x or d > 224 or is_sm80 or is_sm90) and 
block_table is None) if test_backward: ( dq_unpad, @@ -1765,6 +1864,16 @@ def test_flash_attn_varlen_causal( def test_flash_attn_splitkv( seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, deterministic, dtype ): + + if USE_TRITON_ROCM: + test_backward = False + + if local == True: + pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet") + + if skip_config(seqlen_q=seqlen_q, seqlen_k=seqlen_k, d=d): + pytest.skip("Skipping configuration due to limited test time") + if swap_sq_sk: seqlen_q, seqlen_k = seqlen_k, seqlen_q device = "cuda" @@ -1817,79 +1926,92 @@ def test_flash_attn_splitkv( g = torch.randn_like(out) do_o = (g.float() * out.float()).sum(-1) - ( - dq, - dk, - dv, - ) = torch.autograd.grad(out, (q, k, v), g) - ( - dq_ref, - dk_ref, - dv_ref, - ) = torch.autograd.grad(out_ref, (q, k, v), g) - ( - dq_pt, - dk_pt, - dv_pt, - ) = torch.autograd.grad(out_pt, (q, k, v), g) - print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") - print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") - print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") - print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") - print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") - print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") - print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") - print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") - print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") - print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") - print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") - print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + test_backward = test_backward and ((d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90)) + if test_backward: + ( + dq, + dk, + dv, + ) = torch.autograd.grad(out, (q, k, v), g) + ( + dq_ref, + dk_ref, + dv_ref, + ) = torch.autograd.grad(out_ref, (q, k, v), g) + ( + dq_pt, + dk_pt, + dv_pt, + ) = torch.autograd.grad(out_pt, (q, k, v), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. 
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5 mult = 2 if not alibi else 8 - assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4 - assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4 - assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4 + if test_backward: + assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4 + assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4 + assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4 # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("num_splits", [1, 0]) -# @pytest.mark.parametrize("num_splits", [1]) +# @pytest.mark.parametrize("num_splits", [0]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) -@pytest.mark.parametrize("new_kv", [False, True]) +# @pytest.mark.parametrize("new_kv", [False, True]) # @pytest.mark.parametrize("new_kv", [False]) +@pytest.mark.parametrize("new_kv", [True]) @pytest.mark.parametrize("alibi", [False, True]) # @pytest.mark.parametrize("alibi", [False]) -@pytest.mark.parametrize("local", [False, True]) -# @pytest.mark.parametrize("local", [False]) +# @pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [False]) -@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) -# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) +# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) +@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) @pytest.mark.parametrize("rotary_interleaved", [False, True]) # @pytest.mark.parametrize("rotary_interleaved", [False]) @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) -# @pytest.mark.parametrize("rotary_fraction", [0.0]) -@pytest.mark.parametrize("paged_kv_block_size", [None, 256]) +# @pytest.mark.parametrize("rotary_fraction", [0.5, 1.0]) +# @pytest.mark.parametrize("rotary_fraction", [1.0]) +# @pytest.mark.parametrize("paged_kv_block_size", [None, 256]) # @pytest.mark.parametrize("paged_kv_block_size", [256, 512]) -# @pytest.mark.parametrize("paged_kv_block_size", [None]) +@pytest.mark.parametrize("paged_kv_block_size", [None]) @pytest.mark.parametrize("has_leftpad", [False, True]) # @pytest.mark.parametrize("has_leftpad", [True]) +# @pytest.mark.parametrize("has_leftpad", [False]) # @pytest.mark.parametrize("has_batch_idx", [False, True]) -@pytest.mark.parametrize("has_batch_idx", [False]) -@pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) -# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize("has_batch_idx", [False]) +@pytest.mark.parametrize("has_batch_idx", [True]) +# @pytest.mark.parametrize("d", [2, 8, 16, 32, 59, 64, 80, 128, 256]) +@pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) -# @pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("d", [16]) # 16 fails +# @pytest.mark.parametrize("d", [2]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ + # (1, 1), + # (1, 
2), + # (2, 2), + # (4, 4), + # (1, 4), (1, 128), (1, 339), (3, 1024), @@ -1898,12 +2020,13 @@ def test_flash_attn_splitkv( (3, 799), (64, 2048), (16, 20000), - (1, 128 * 1024), - (16, 128 * 1024), - (128, 128), + # (1, 128 * 1024), + # (16, 128 * 1024), + # (128, 128), ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +@pytest.mark.parametrize('DEBUG_ENABLED', [False]) def test_flash_attn_kvcache( seqlen_q, seqlen_k, @@ -1921,7 +2044,24 @@ def test_flash_attn_kvcache( mha_type, num_splits, dtype, + DEBUG_ENABLED ): + if USE_TRITON_ROCM: + if paged_kv_block_size is not None: + pytest.skip("paged attention not supported on AMD's Triton Backend yet") + + if local == True: + pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet") + + # if rotary_interleaved == True or rotary_fraction > 0.0: + # pytest.skip("rotary embedding not supported on AMD's Triton Backend yet") + + if has_leftpad == True: + pytest.skip("cache_leftpad not supported on AMD's Triton Backend yet") + + # if skip_config(seqlen_q=seqlen_q, seqlen_k=seqlen_k, d=d): + # pytest.skip("Skipping configuration due to limited test time") + if seqlen_q > seqlen_k and new_kv: pytest.skip() if not new_kv and rotary_fraction > 0.0: @@ -1930,27 +2070,46 @@ def test_flash_attn_kvcache( pytest.skip() if has_leftpad and paged_kv_block_size is not None: pytest.skip() + device = "cuda" # set seed torch.random.manual_seed(0) - batch_size = 2 + batch_size = 1 batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 - nheads = 6 + nheads = 1 # rotary_dim must be a multiple of 16, and must be <= d rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) - assert nheads % nheads_k == 0 + + # in case of GCA and nhead is < 3 skip since nheads_k (group size) will be 3 and you don't have enough heads for a single group + if mha_type == "gqa" and nheads < 3: + pytest.skip() + + assert nheads % nheads_k == 0, "num heads cannot be evenly split into groups" window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) - q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) + + if DEBUG_ENABLED: + q = torch.arange(seqlen_q, dtype=dtype, device="cuda").view(1, seqlen_q, 1, 1).expand(batch_size, seqlen_q, nheads, d).requires_grad_().contiguous().to(dtype=dtype) + else: + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) + seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() if new_kv: - k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) - v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) + if DEBUG_ENABLED or True: + k = torch.arange(seqlen_new, dtype=dtype, device="cuda").view(1, seqlen_new, 1, 1).expand(batch_size, seqlen_new, nheads_k, d).requires_grad_().contiguous().to(dtype=dtype) + v = torch.arange(seqlen_new, dtype=dtype, device="cuda").view(1, seqlen_new, 1, 1).expand(batch_size, seqlen_new, nheads_k, d).requires_grad_().contiguous().to(dtype=dtype) + else: + k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) + v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) else: k, v = None, None if paged_kv_block_size is None: - k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) - v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, 
dtype=dtype) + if DEBUG_ENABLED: + k_cache = torch.arange(seqlen_k, dtype=dtype, device="cuda").view(1, seqlen_k, 1, 1).expand(batch_size_cache, seqlen_k, nheads_k, d).requires_grad_().contiguous() + v_cache = torch.arange(seqlen_k, dtype=dtype, device="cuda").view(1, seqlen_k, 1, 1).expand(batch_size_cache, seqlen_k, nheads_k, d).requires_grad_().contiguous() + else: + k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) + v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) block_table = None else: ( @@ -2060,13 +2219,18 @@ def test_flash_attn_kvcache( v, rotary_cos=cos, rotary_sin=sin, + rotary_cos_k=None, + rotary_sin_k=None, + rotary_interleaved=rotary_interleaved, + rotary_inplace=False, + rotary_conjugate=False, cache_seqlens=cache_seqlens, cache_batch_idx=cache_batch_idx, cache_leftpad=cache_leftpad, block_table=block_table, causal=causal, + local=local, window_size=window_size, - rotary_interleaved=rotary_interleaved, alibi_slopes=alibi_slopes, num_splits=num_splits, ) @@ -2197,6 +2361,15 @@ def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) # @pytest.mark.parametrize("dropout_p", [0.0]) def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype): + if USE_TRITON_ROCM: + test_backward = False + + if dropout_p != 0.0: + pytest.skip("Dropout not supported in AMD's Triton Backend yet") + + if skip_config(seqlen_q=seqlen_q, seqlen_k=seqlen_k, d=d): + pytest.skip("Skipping configuration due to limited test time") + device = "cuda" # set seed torch.random.manual_seed(0) @@ -2208,7 +2381,8 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dty torch.random.manual_seed(42) out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True) g = torch.randn_like(out0) - if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): + test_backward = test_backward and ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) + if test_backward: ( dq0, dk0, @@ -2223,7 +2397,7 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dty assert torch.equal(out, out0) assert torch.equal(lse, lse0) - if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): + if test_backward: ( dq, dk, @@ -2248,6 +2422,13 @@ def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype): """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ, in the case where seqlen % 128 != 0. """ + if USE_TRITON_ROCM: + if True: + pytest.skip("Backward Attention not supported on AMD's Triton Backend yet") + + if skip_config(seqlen=seqlen, d=d): + pytest.skip("Skipping configuration due to limited test time") + device = "cuda" # set seed torch.random.manual_seed(0) @@ -2304,6 +2485,13 @@ def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype): """We previously had a bug where we were using the wrong strides of dout, which shows up when dout is not contiguous. 
""" + if USE_TRITON_ROCM: + if True: + pytest.skip("Backward Attention not supported on AMD's Triton Backend yet") + + if skip_config(seqlen=seqlen, d=d): + pytest.skip("Skipping configuration due to limited test time") + device = "cuda" # set seed torch.random.manual_seed(0) @@ -2356,6 +2544,13 @@ def test_flash_attn_bwd_varlen_overflow(d, causal, dtype): """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ, in the case where seqlen % 128 != 0 or varlen. """ + if USE_TRITON_ROCM: + if True: + pytest.skip("Backward Attention not supported on AMD's Triton Backend yet") + + if skip_config(d=d): + pytest.skip("Skipping configuration due to limited test time") + device = "cuda" # set seed torch.random.manual_seed(0) @@ -2411,6 +2606,15 @@ def test_flash_attn_bwd_varlen_overflow(d, causal, dtype): ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype): + if USE_TRITON_ROCM: + test_backward = False + + if local == True: + pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet") + + if skip_config(seqlen_q=seqlen_q, seqlen_k=seqlen_k, d=d): + pytest.skip("Skipping configuration due to limited test time") + if ( max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 @@ -2430,12 +2634,14 @@ def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, loc out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True) g = torch.randn_like(out) - dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True) - for _ in range(50): - dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True) - assert torch.equal(dv, dv0) - assert torch.equal(dk, dk0) - assert torch.equal(dq, dq0) + test_backward = test_backward and ((d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90)) + if test_backward: + dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True) + for _ in range(50): + dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True) + assert torch.equal(dv, dv0) + assert torch.equal(dk, dk0) + assert torch.equal(dq, dq0) @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @@ -2469,6 +2675,15 @@ def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, loc ) # @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)]) def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype): + if USE_TRITON_ROCM: + test_backward = False + + if local == True: + pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet") + + if skip_config(seqlen_q=seqlen_q, seqlen_k=seqlen_k, d=d): + pytest.skip("Skipping configuration due to limited test time") + if ( max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 @@ -2517,9 +2732,11 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus ) g = torch.randn_like(out) - dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True) - for _ in range(50): - dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True) - assert torch.equal(dv, dv0) - assert torch.equal(dk, dk0) - assert torch.equal(dq, dq0) + test_backward = test_backward and ((d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90)) + if 
test_backward: + dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True) + for _ in range(50): + dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True) + assert torch.equal(dv, dv0) + assert torch.equal(dk, dk0) + assert torch.equal(dq, dq0) diff --git a/tests/test_precision_error.py b/tests/test_precision_error.py new file mode 100644 index 000000000..a8bc9edf8 --- /dev/null +++ b/tests/test_precision_error.py @@ -0,0 +1,205 @@ +import torch +import triton +import triton.language as tl +import pytest +import pdb + +@triton.jit +def many_ops_triton(x_ptr, + y_ptr, + o_ptr, + M: tl.constexpr, + K: tl.constexpr, + N: tl.constexpr, + mult: tl.constexpr, + IMITATE_PYTORCH: tl.constexpr, + DTYPE: tl.constexpr, + DO_MULTIPLY: tl.constexpr, + DO_SIGMOID: tl.constexpr, + DO_COS: tl.constexpr, + DO_EXPONENT: tl.constexpr, + DO_SQRT: tl.constexpr + ): + """ + x_ptr: pointer to an (M, K) tensor [input] + y_ptr: pointer to an (K, N) tensor [input] + + o_ptr: pointer to an (M, N) tensor [output] + + M: int matrix shape + K: int matrix shape + N: int matrix shape + + mult: multiplication factor for multiplication operation + + IMITATE_PYTORCH: { + 0: no casting after ops, + 1: cast to original dtype after every op + } + DTYPE: { + 0: fp16, + 1: fp32, + 2: fp64 + } + """ + # Set input dtype (we will cast back to this for the output) + input_dtype = tl.float16 if DTYPE==0 else tl.float32 if DTYPE==1 else None + + x_block_range = tl.arange(0, M)[:, None]*K + tl.arange(0, K)[None, :] + y_block_range = tl.arange(0, K)[:, None]*N + tl.arange(0, N)[None, :] + x = tl.load(x_ptr + x_block_range) + y = tl.load(y_ptr + y_block_range) + + # Multiply + if DO_MULTIPLY: + x = x * mult + y = y * mult + if IMITATE_PYTORCH: + x = x.to(input_dtype) + y = y.to(input_dtype) + + # Sigmoid + if DO_SIGMOID: + x = tl.sigmoid(x + 0.0) # +0.0 cause tl.sigmoid requires a fp32 and 0.0 is fp32 by default so if dtype if fp16 will become fp32 + y = tl.sigmoid(y + 0.0) + if IMITATE_PYTORCH: + x = x.to(input_dtype) + y = y.to(input_dtype) + + # Cos + if DO_COS: + x = tl.cos(x + 0.0) # +0.0 because requires fp32 or fp64 + y = tl.cos(y + 0.0) + if IMITATE_PYTORCH: + x = x.to(input_dtype) + y = y.to(input_dtype) + + # Exponentiate + if DO_EXPONENT: + log2_e = 1.4426950408889634 # log2(e) + x = tl.exp2(log2_e * x) + y = tl.exp2(log2_e * y) + if IMITATE_PYTORCH: + x = x.to(input_dtype) + y = y.to(input_dtype) + + # Sqrt + if DO_SQRT: + x = tl.sqrt(x + 0.0) # +0.0 because requires fp32 or fp64 + y = tl.sqrt(y + 0.0) + if IMITATE_PYTORCH: + x = x.to(input_dtype) + y = y.to(input_dtype) + + # Matmul + o_block_range = tl.arange(0, M)[:, None]*N + tl.arange(0, N)[None, :] + o = tl.dot(x, y) # tl.dot always outputs input dtype. 
ALSO REQUIRES INPUT SHAPES M >= 16, N >= 16 and K >= 16 + if IMITATE_PYTORCH: + x = x.to(input_dtype) + y = y.to(input_dtype) + + # o = tl.dot(x, y, out_dtype=input_dtype) # FUSE CAST INTO DOT + + tl.store(o_ptr + o_block_range, o) + +def many_ops_torch(x: torch.Tensor, + y: torch.Tensor, + out: torch.Tensor, + M: int, + K: int, + N: int, + mult: float, + DO_MULTIPLY: bool, + DO_SIGMOID: bool, + DO_COS: bool, + DO_EXPONENT: bool, + DO_SQRT: bool + ): + + # Multiply + if DO_MULTIPLY: + x = x * mult + y = y * mult + + # Sigmoid + if DO_SIGMOID: + x = torch.sigmoid(x) + y = torch.sigmoid(y) + + # Cos + if DO_COS: + x = torch.cos(x) + y = torch.cos(y) + + # Exponentiate + if DO_EXPONENT: + x = torch.exp(x) + y = torch.exp(y) + + # Sqrt + if DO_SQRT: + x = torch.sqrt(x) + y = torch.sqrt(y) + + # Matmul + out[:] = torch.matmul(x, y) # stores in place + +@pytest.mark.parametrize("seed", [i for i in range(1)]) # seed for rand num generator +@pytest.mark.parametrize("M", [16, 32]) +@pytest.mark.parametrize("K", [16, 32, 64]) # 64 seems to cause some issues +@pytest.mark.parametrize("N", [16, 32]) +@pytest.mark.parametrize("mult", [0.001, 1.5251]) # mult = [0, 2.99] +@pytest.mark.parametrize("IMITATE_PYTORCH", [1]) # 0 = no casting (not imitating pytorch), 1 = cast after every op (imitating pytorch) +@pytest.mark.parametrize("DTYPE", [0]) # 0 = fp16, 1 = fp32 +@pytest.mark.parametrize("DO_MULTIPLY", [0, 1]) # Include multiplication +@pytest.mark.parametrize("DO_SIGMOID", [0, 1]) # Include sigmoid +@pytest.mark.parametrize("DO_COS", [0, 1]) # Include cosine +@pytest.mark.parametrize("DO_EXPONENT", [0, 1]) # Include exponentiation +@pytest.mark.parametrize("DO_SQRT", [0, 1]) # Include square root +def test_many_ops(seed, M, K, N, mult, IMITATE_PYTORCH, DTYPE, DO_MULTIPLY, DO_SIGMOID, DO_COS, DO_EXPONENT, DO_SQRT): + """ + Test reproducability of PyTorch results with a Triton kernel implementing various math operations. + + Each operation can be individually enabled or disabled using the respective parameters. The test will compare + the results from Triton and PyTorch to ensure they match within a specified tolerance. + + Args: + seed (int): Random seed for reproducibility. + M (int): Number of rows for the first input tensor. + K (int): Number of columns for the first input tensor and rows for the second. + N (int): Number of columns for the second input tensor. + mult (float): Multiplication factor for the input tensors. + IMITATE_PYTORCH (int): If 1, cast tensors back to their original dtype after each operation, if 0 does not cast until very end. + DTYPE (int): Data type of the input tensors (0 for fp16, 1 for fp32). + DO_MULTIPLY (int): If 1, include multiplication in the operations, if 0 does not. + DO_SIGMOID (int): If 1, include sigmoid activation in the operations, if 0 does not. + DO_COS (int): If 1, include cosine transformation in the operations, if 0 does not. + DO_EXPONENT (int): If 1, include exponentiation in the operations, if 0 does not. + DO_SQRT (int): If 1, include square root in the operations, if 0 does not. 
+ """ + + # Misc parameters + torch.set_printoptions(precision=6) + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + torch.manual_seed(seed) + + input_dtype = torch.float16 if DTYPE==0 else torch.float32 if DTYPE==1 else None + + x = torch.rand(M, K, dtype=input_dtype, device=device) + y = torch.rand(K, N, dtype=input_dtype, device=device) + + grid = (1,) + out = torch.zeros(M, N, dtype=input_dtype, device=device) + out_torch = torch.zeros(M, N, dtype=input_dtype, device=device) + + with torch.cuda.device(x.device): + many_ops_triton[grid](x, y, out, M, K, N, mult, IMITATE_PYTORCH, DTYPE, DO_MULTIPLY, DO_SIGMOID, DO_COS, DO_EXPONENT, DO_SQRT) + many_ops_torch(x, y, out_torch, M, K, N, mult, DO_MULTIPLY, DO_SIGMOID, DO_COS, DO_EXPONENT, DO_SQRT) + + print("torch", out_torch) + print("out", out) + + print("torch - out", (out_torch-out)) + + assert torch.allclose(out_torch, out, atol=0) # tensors must match exactly \ No newline at end of file