[NVIDIA] Update XLA-fp8 docs for ninja user guide #478

Merged 6 commits on Jun 3, 2024
37 changes: 36 additions & 1 deletion rosetta/docs/NATIVE_FP8.md
@@ -128,7 +128,42 @@ python -m paxml.main \

Please ensure you include the first two flags, `--xla_gpu_enable_reduction_epilogue_fusion=false` and `--xla_gpu_enable_triton_gemm=false`, as they are essential for enabling the FP8 functionality. The additional flags primarily focus on performance enhancement and should also prove beneficial for non-FP8 executions.
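If your setup does not pass the flags on the `paxml` command line, the same options can typically be supplied through the `XLA_FLAGS` environment variable. A minimal sketch, assuming the variable is set before JAX initializes its backend:

```python
import os

# Set the XLA options before JAX initializes its GPU backend.
os.environ["XLA_FLAGS"] = (
    "--xla_gpu_enable_reduction_epilogue_fusion=false "
    "--xla_gpu_enable_triton_gemm=false"
)

import jax  # imported after XLA_FLAGS is set so the flags take effect
```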

## Guide for Ninja Users

### Exact pattern that XLA can match for FP8 MatMul
The specific graph pattern that XLA supports for FP8 matmul is illustrated below:

```
convert -> multiply -> (x) -> dot
bcast ->
```

XLA will pattern match the above and rewrite it to an FP8 matmul when the following hold (a JAX-level sketch of code that lowers to this pattern is given after the list):
1. `convert`: converts `f8` to [`bf16`|`f16`|`f32`].
2. `bcast`: broadcasts a [`bf16`|`f16`|`f32`] scalar.
3. `(x)`: an arbitrary number of allowed ops (e.g., all-gather, copy, bitcast...), the full list of which can be found [here](https://github.com/openxla/xla/blob/e4fc3298eefa91702f86068af45340cad78d0335/xla/service/gpu/gemm_rewriter.cc#L259-L265).
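For reference, here is a minimal JAX-level sketch of a dequantize-then-matmul computation that lowers to the pattern above; the fp8 dtype choice, scaling scheme, and names are illustrative rather than a prescribed API:

```python
import jax
import jax.numpy as jnp

def dequantized_matmul(x_fp8, w_fp8, x_scale, w_scale):
    # convert: f8 -> bf16 (f16 or f32 also match)
    x = x_fp8.astype(jnp.bfloat16)
    w = w_fp8.astype(jnp.bfloat16)
    # multiply: rescale each operand by a broadcast scalar (the `bcast ->` input)
    x = x * x_scale.astype(jnp.bfloat16)
    w = w * w_scale.astype(jnp.bfloat16)
    # dot: with the flags above, XLA's GEMM rewriter can fuse the whole
    # convert -> multiply -> dot chain into a single FP8 matmul
    return jnp.dot(x, w)

x_fp8 = jnp.ones((16, 32), dtype=jnp.float8_e4m3fn)
w_fp8 = jnp.ones((32, 8), dtype=jnp.float8_e4m3fn)
scale = jnp.asarray(0.5, dtype=jnp.float32)
y = jax.jit(dequantized_matmul)(x_fp8, w_fp8, scale, scale)
```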

### Gradient accumulation of FP8 params
FP8 params, also known as OWG params (or FP8 meta), may be shared across different iterations of a loop in the context of pipeline parallelism. During backpropagation, the autograd system accumulates their gradients from each iteration through the default addition operation. This is undesirable as addition is meaningless for FP8 params.
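To make the default behavior concrete, here is a small self-contained sketch (plain `f32`, no FP8 involved) showing that when a value is reused by every `jax.lax.scan` iteration, autograd sums the per-iteration gradient contributions:

```python
import jax
import jax.numpy as jnp

def loss(scale):
    def body_fn(carry, _):
        # `scale` is closed over and reused in every iteration
        return carry * scale, None

    out, _ = jax.lax.scan(body_fn, jnp.float32(1.0), None, length=3)
    return out

# Each of the 3 iterations contributes to d(loss)/d(scale), and the
# contributions are combined with addition: the result here is 3.0.
print(jax.grad(loss)(jnp.float32(1.0)))
```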

To address this, we introduce a custom dtype wrapper, `fm32` (meaning "fp8 meta with a 32-bit physical size"). It tells the autograd system to use the max operation for gradient accumulation, which is the behavior we want for FP8 params. The basic usage is demonstrated below:

```python
import jax  # `fp8_ops` below provides the custom `fm32` dtype

def outer_fn(scale_f32, ...):
    # Convert the fp8 meta (e.g. the scale) from f32 to fm32 before the scan
    # loop so that autograd accumulates its gradient with max instead of add.
    scale_fm32 = jax.lax.convert_element_type(scale_f32, fp8_ops.fm32)

    def body_fn(carry, _):
        # Can temporarily convert scale_fm32 back to f32 for general math
        # operations, then convert back to fm32 before returning.
        return carry, None

    jax.lax.scan(body_fn, ..., length=3)
    return ...
```

In the example, we need to convert the FP8 params (e.g. the scale) from the original `f32` to `fm32` before launching the scan loop so that autograd can apply the correct gradient accumulation between loop iterations. Inside each iteration (i.e. `body_fn`), we can convert them from `fm32` to `f32` for general math operations (e.g. `mul`, `div`, etc.) and convert them back to `fm32` on exit.
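As a concrete sketch of that convention: the carry layout and the `compute_new_scale` helper below are hypothetical, and only the `fm32 <-> f32` conversions mirror the pattern described above (`jax`, `jnp`, and `fp8_ops` as in the earlier snippet):

```python
def body_fn(carry, _):
    scale_fm32, amax_history = carry  # hypothetical carry layout
    # Temporarily view the fp8 meta as f32 so ordinary math ops can be applied.
    scale_f32 = jax.lax.convert_element_type(scale_fm32, jnp.float32)
    new_scale_f32 = compute_new_scale(scale_f32, amax_history)  # hypothetical update rule
    # Convert back to fm32 before returning it in the carry so that gradient
    # accumulation across iterations uses max rather than add.
    new_scale_fm32 = jax.lax.convert_element_type(new_scale_f32, fp8_ops.fm32)
    return (new_scale_fm32, amax_history), None
```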

## Transformer Engine vs Native FP8 Support
Native XLA-FP8 specifically targets matrix multiplication operations. In contrast, the Transformer Engine focuses on enhancing the overall performance of the entire transformer layer. This encompasses not only the FP8 matrix multiplication but also attention mechanisms, layer normalizations, and other components.

In practical terms, XLA-FP8 performs pattern matching and rewrites the matrix multiplication operations in the operation graph to utilize FP8 matrix multiplication. On the other hand, with TE, the [entire Praxis transformer](https://github.com/google/praxis/blob/main/praxis/layers/transformers.py) layer will be substituted with our [Transformer Engine