Problem Description
This is due to the core logic of the FA (FlashAttention) algorithm:
aotriton/tritonsrc/fwd_kernel_inner.py
Lines 95 to 97 in f6b28a9
`qk_scale * tl.max(qk, 1)` uses `vmul`, but `qk * qk_scale - m_ij[:, None]` uses `vfma`. `vfma` has higher precision than `vmul` + `vadd` because it rounds only once.

More specifically, suppose each vector contains only one element, and
qk = round(133120.0 * 133120.0) = 17720934400.0
qk_scale = round(0.25 * 1.44269502162933349609) = 0.360673755407333374023
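Then, treating `m_ij` as the scaled max of this single element, `vmul` rounds the product to fp32 before it is used:
m_ij = round(qk_scale * qk) = round(6391475959.375) = 6391475712.0
fp32 values near 6.4e9 are spaced 512 apart, so the excess of 247.375 is rounded away. `vfma`, however, keeps the exact product and rounds only the final difference:
qk * qk_scale - m_ij = round(6391475959.375 - 6391475712.0) = 247.375
With `vmul` + `vadd`, the same expression would instead give round(6391475712.0 - 6391475712.0) = 0.0.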
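The arithmetic above can be checked on the CPU with NumPy. This is a minimal sketch that assumes the GPU's `vmul` and `vfma` follow IEEE-754 fp32 round-to-nearest-even; the fused step is emulated in float64 rather than with a real FMA instruction, which is exact for these particular inputs. The names `fma_out` and `mul_add_out` are illustrative only.

```python
import numpy as np

f32 = np.float32

# Inputs from the example above (both are exact fp32 values).
qk = f32(133120.0) * f32(133120.0)                  # 17720934400.0
qk_scale = f32(0.25) * f32(1.44269502162933349609)  # fp32 log2(e) / 4

# vmul path: the product is rounded to fp32 before it becomes the row max.
m_ij = qk_scale * qk  # float32 * float32 -> float32, rounded once

# vfma path: qk * qk_scale - m_ij with a single rounding at the end.
# Emulated in float64, which holds the fp32 x fp32 product and this
# difference exactly, so the only rounding is the final cast to fp32.
fma_out = f32(np.float64(qk) * np.float64(qk_scale) - np.float64(m_ij))

# vmul + vadd path for comparison: product rounded, then subtracted.
mul_add_out = (qk * qk_scale) - m_ij

with np.errstate(over="ignore"):
    p_fma = np.exp2(fma_out)          # exp2(247.375) overflows fp32
    p_mul_add = np.exp2(mul_add_out)  # exp2(0.0)

print(m_ij, fma_out, mul_add_out, p_fma, p_mul_add)
```

The unfused path keeps the exponent argument at 0, so `p` stays finite; only the fused path produces a positive argument and overflows.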
Therefore, `p` in the code above evaluates to `exp2(247.375) = inf`, which leads to NaN in the subsequent steps.
Operating System
N/A
CPU
N/A
GPU
MI300X
ROCm Version
ROCm 6.2.3
ROCm Component
No response
Steps to Reproduce
No response
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
No response
Additional Information
No response