[Issue]: Large bf16 inputs lead to NaN #54

Open
xinyazhang opened this issue Nov 7, 2024 · 0 comments

@xinyazhang (Collaborator)

Problem Description

The NaN is caused by the core softmax logic of the FlashAttention (FA) algorithm:

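```python
# softmax
m_ij = tl.maximum(m_i, qk_scale * tl.max(qk, 1))  # new row max; this product is rounded on its own (vmul)
p = tl.math.exp2(qk * qk_scale - m_ij[:, None])   # the multiply and the subtract are fused (vfma)
```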

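`qk_scale * tl.max(qk, 1)` uses `vmul`, but `qk * qk_scale - m_ij[:, None]` uses `vfma`.
`vfma` is more precise than `vmul` + `vadd` because it rounds only once. Here that mismatch is exactly the problem: `m_ij` holds the *rounded* product, while `vfma` subtracts it from the *exact* product, so for the element that attains the row maximum the exponent argument is not 0.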
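How large that nonzero exponent can get follows from a short bound. This sketch assumes round-to-nearest binary32 arithmetic without underflow, and that the element being exponentiated is the one attaining the row maximum (so `m_i` is smaller). Let a = qk, b = qk_scale, s = ab be the exact product, and fl(·) denote rounding to binary32:

$$
m_{ij} = \mathrm{fl}(s), \qquad
\mathrm{fma}(a,\, b,\, -m_{ij}) = \mathrm{fl}\bigl(s - \mathrm{fl}(s)\bigr) = s - \mathrm{fl}(s)
$$

$$
\bigl|s - \mathrm{fl}(s)\bigr| \le \tfrac{1}{2}\,\mathrm{ulp}(s) = 2^{e-24}
\qquad \text{for } 2^{e} \le |s| < 2^{e+1}
$$

The last equality in the first line holds because the rounding error of a binary32 product is itself exactly representable. An unfused `vmul` followed by `vsub` would instead give fl(s) − fl(s) = 0. Since exp2 overflows binary32 once its argument reaches 128 = 2^7, the bound permits overflow as soon as 2^(e−24) ≥ 2^7, i.e. once the scaled score |s| reaches 2^31 ≈ 2.1 × 10^9.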

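More specifically, suppose every vector contains only one element, with qk = round(133120.0 * 133120.0) = 17720934400.0 and qk_scale = round(0.25 * 1.44269502162933349609375) = 0.3606737554073333740234375 (the softmax scale 0.25 times the fp32 value of log2(e)). Both roundings are exact: 133120 = 65 · 2^11 is representable in bf16, its square 4225 · 2^22 fits in the fp32 significand, and multiplying by 0.25 only shifts the exponent. Assume also that m_i < qk_scale * qk (e.g. m_i = -inf on the first iteration), so m_ij = vmul(qk_scale, qk).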
Then

qk * qk_scale - m_ij
= qk * qk_scale - vmul(qk_scale, qk)
= fma(qk, qk_scale, -round(qk * qk_scale))
= fma(qk, qk_scale, -round(17720934400.0 * 0.3606737554073333740234375))
= fma(qk, qk_scale, -round(6391475959.375))
= fma(qk, qk_scale, -6391475712.0)
= round(6391475959.375 - 6391475712.0)
= 247.375

Between 2^32 and 2^33, adjacent fp32 values are 512 apart, so vmul rounds 6391475959.375 down to 6391475712.0 (a loss of 247.375 < 256). The fused subtraction then recovers exactly that lost 247.375.
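The arithmetic above can be reproduced on a CPU by emulating binary32 rounding with exact rational arithmetic. This is a minimal sketch, not the kernel's code. `round_f32`, `vmul`, `vsub`, `vfma` and `exp2_f32` are hypothetical helpers that model one correctly rounded fp32 operation each. `exp2_f32` goes through a double-precision `pow`, which is only adequate for the inputs used here.

```python
from fractions import Fraction
import math

def round_f32(x: Fraction) -> float:
    """Round an exact rational to the nearest IEEE-754 binary32 value (ties to even).

    Returns a Python float (every binary32 value is exactly representable as a
    double), or +/-inf when the rounded magnitude reaches 2**128.
    """
    if x == 0:
        return 0.0
    sign = -1.0 if x < 0 else 1.0
    a = abs(x)
    # Find e with 2**e <= a < 2**(e+1).
    e = a.numerator.bit_length() - a.denominator.bit_length()
    if Fraction(2) ** e > a:
        e -= 1
    # binary32 has a 24-bit significand; below 2**-126 the spacing stays 2**-149.
    ulp = Fraction(2) ** (max(e, -126) - 23)
    q = math.floor(a / ulp)
    r = a - q * ulp
    if 2 * r > ulp or (2 * r == ulp and q % 2 == 1):
        q += 1  # round to nearest, ties to even
    value = q * ulp
    if value >= 2 ** 128:
        return sign * math.inf  # overflow
    return sign * float(value)

def vmul(a: float, b: float) -> float:
    """Unfused fp32 multiply: exact product, rounded once."""
    return round_f32(Fraction(a) * Fraction(b))

def vsub(a: float, b: float) -> float:
    """Unfused fp32 subtract: exact difference, rounded once."""
    return round_f32(Fraction(a) - Fraction(b))

def vfma(a: float, b: float, c: float) -> float:
    """Fused multiply-add: a*b + c computed exactly, then rounded once."""
    return round_f32(Fraction(a) * Fraction(b) + Fraction(c))

def exp2_f32(x: float) -> float:
    """Approximate fp32 exp2: double-precision pow, then rounded to fp32."""
    if x >= 128:
        return math.inf  # 2**x >= 2**128 is outside the fp32 range
    return round_f32(Fraction(2.0 ** x))

SM_SCALE = 0.25                                      # softmax scale from the example
LOG2E_F32 = round_f32(Fraction(math.log2(math.e)))   # 1.44269502162933349609375
qk_scale = vmul(SM_SCALE, LOG2E_F32)                 # 0.3606737554073333740234375
qk = vmul(133120.0, 133120.0)                        # 17720934400.0 (exact)

m_ij = vmul(qk_scale, qk)                # row max, product rounded on its own
unfused = vsub(vmul(qk, qk_scale), m_ij) # what separate mul + sub would give
fused = vfma(qk, qk_scale, -m_ij)        # what the contracted fma gives

print(m_ij, unfused, fused, exp2_f32(unfused), exp2_f32(fused))
# → 6391475712.0 0.0 247.375 1.0 inf
```

Exact rationals are used instead of NumPy's `float32` so that every step is rounded exactly once, as the hardware does, with no intermediate double rounding.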

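Therefore p in the code above evaluates to exp2(247.375) = inf, because fp32 cannot represent values of 2^128 or more. That inf then turns into NaN in the following steps (e.g. inf / inf when the output is normalized by the row sum, or inf - inf when p is multiplied with V).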
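To illustrate how a single inf becomes NaN, here is a hypothetical, heavily simplified epilogue for one query row, one key block and one output column. It computes the probabilities p, their row sum l_i, the p·V accumulation, and the final division acc / l_i. The names and structure are illustrative assumptions, not this repository's kernel, and `exp2_f32` models only fp32 overflow.

```python
import math

def exp2_f32(x: float) -> float:
    """Crude stand-in for fp32 exp2: only overflow is modelled.
    Results >= 2**128 do not fit in binary32 and become inf."""
    return math.inf if x >= 128 else 2.0 ** x

def attention_row(exp_args, v):
    """Hypothetical single-row, single-block FA epilogue.

    exp_args[j] stands for qk[j] * qk_scale - m_ij for key j;
    v[j] is the matching value (one output column).
    """
    p = [exp2_f32(x) for x in exp_args]          # unnormalized probabilities
    l_i = sum(p)                                 # row sum
    acc = sum(pj * vj for pj, vj in zip(p, v))   # p @ V for this column
    return acc / l_i                             # final normalization

# Healthy case: the max element contributes exp2(0) = 1.
print(attention_row([0.0], [1.0]))                     # 1.0
# The issue's case: the max element's argument is 247.375 instead of 0.
print(attention_row([247.375], [1.0]))                 # nan (inf / inf)
# With mixed-sign values the accumulator itself becomes inf - inf = nan.
print(attention_row([247.375, 247.375], [1.0, -1.0]))  # nan
```

Python floats follow IEEE-754 here: inf / inf and inf - inf both yield nan rather than raising, which mirrors what happens on the GPU.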

Operating System

N/A

CPU

N/A

GPU

MI300X

ROCm Version

ROCm 6.2.3

ROCm Component

No response

Steps to Reproduce

No response

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response
