
Commit: Minor change

kaixih committed May 30, 2024
1 parent 54add36 commit 320cb1c
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions rosetta/docs/NATIVE_FP8.md
@@ -144,7 +144,7 @@ XLA will pattern match the above and rewrite it to FP8 matmul when:
3. `(x)`: an arbitrary number of allowed ops (e.g., all-gather, copy, bitcast...), the full list of which can be found [here](https://github.com/openxla/xla/blob/e4fc3298eefa91702f86068af45340cad78d0335/xla/service/gpu/gemm_rewriter.cc#L259-L265).
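
When these conditions hold, the rewrite fires on a dequantize-then-matmul chain. As a rough, minimal sketch (function and variable names are illustrative, not taken from this document), such a computation in JAX might look like:

```python
import jax
import jax.numpy as jnp

# Both operands live in FP8; each is upcast (convert) and rescaled (multiply)
# before feeding a dot. XLA's GEMM rewriter can pattern match this chain and
# emit a single FP8 matmul, tolerating the allowed ops (copy, bitcast,
# all-gather, ...) between the dequantization and the dot.
@jax.jit
def dequant_matmul(a_fp8, b_fp8, a_scale, b_scale):
  a = a_fp8.astype(jnp.float32) * a_scale  # convert + multiply = dequantize
  b = b_fp8.astype(jnp.float32) * b_scale
  return jnp.dot(a, b)                     # dot consuming the dequantized operands

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (16, 64)).astype(jnp.float8_e4m3fn)
b = jax.random.normal(key, (64, 16)).astype(jnp.float8_e4m3fn)
out = dequant_matmul(a, b, jnp.float32(1.0), jnp.float32(1.0))
```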

### Gradient accumulation of FP8 params
-FP8 params, also known as OWG params (or FP8 meta), may be shared across different iterations of a loop in the context of pipeline parallelism. Consequently, the autograd system accumulates their gradients through addition operations. This is undesirable as addition is meaningless for FP8 params.
+FP8 params, also known as OWG params (or FP8 meta), may be shared across different iterations of a loop in the context of pipeline parallelism. During backpropagation, the autograd system accumulates their gradients from each iteration through the default addition operation. This is undesirable as addition is meaningless for FP8 params.

To address this, we introduce a custom dtype wrapper `fm32` (i.e., FP8 meta with a 32-bit physical size). It tells the autograd system to use the max operation for gradient accumulation, which matches the expected behavior for FP8 params. The basic usage is demonstrated below:

@@ -154,15 +154,14 @@ def outer_fn(scale_f32, ...):
  scale_fm32 = jax.lax.convert_element_type(scale_f32, fp8_ops.fm32)

  def body_fn(carry, _):
-    # use scale_fm32; can temperarily convert it back to f32 for general
-    # math operations
+    # Can temporarily convert scale_fm32 back to f32 for general math operations
    return carry, None

  jax.lax.scan(body_fn, ..., length=3)
  return ...
```

-The main point is that we need to convert the FP8 params (e.g. the scale) from the original `f32` to `fm32` before launching the scan loop so that the autograd can apply the correct grad accumulation between loop iterations. Inside each iteration (i.e. `body_fn`), we can convert them from `fm32` to `f32` for general math operations (e.g. `mul`, `div`, etc.) and convert back to `fm32` at exit.
+In the example, we need to convert the FP8 params (e.g. the scale) from the original `f32` to `fm32` before launching the scan loop so that the autograd system can apply the correct gradient accumulation between loop iterations. Inside each iteration (i.e. `body_fn`), we can convert them from `fm32` to `f32` for general math operations (e.g. `mul`, `div`, etc.) and convert back to `fm32` at exit.
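
For illustration, a hedged sketch of what the body of such a loop could look like follows; the `fp8_ops` import path, the carry layout, and the scale update are assumptions made for the sake of the example, not part of the change above:

```python
import jax
import jax.numpy as jnp
from flax.linen import fp8_ops  # assumed import path providing fm32

def body_fn(carry, _):
  scale_fm32, x = carry
  # fm32 -> f32: do ordinary math on the f32 view of the scale
  scale_f32 = jax.lax.convert_element_type(scale_fm32, jnp.float32)
  y = x * scale_f32                 # example use of the scale
  new_scale_f32 = scale_f32 * 0.5   # placeholder scale update for illustration
  # f32 -> fm32 at exit, so gradients accumulated across scan iterations use max
  new_scale_fm32 = jax.lax.convert_element_type(new_scale_f32, fp8_ops.fm32)
  return (new_scale_fm32, y), None

# Inside outer_fn it would be driven by something like:
#   (scale_fm32, y), _ = jax.lax.scan(body_fn, (scale_fm32, x0), None, length=3)
```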

## Transformer Engine vs Native FP8 Support
Native XLA-FP8 specifically targets matrix multiplication operations. In contrast, the Transformer Engine focuses on enhancing the overall performance of the entire transformer layer. This encompasses not only the FP8 matrix multiplication but also attention mechanisms, layer normalizations, and other components.
