diff --git a/rosetta/docs/NATIVE_FP8.md b/rosetta/docs/NATIVE_FP8.md index 72057b411..cb26ea2df 100644 --- a/rosetta/docs/NATIVE_FP8.md +++ b/rosetta/docs/NATIVE_FP8.md @@ -128,9 +128,68 @@ python -m paxml.main \ Please ensure you include the first two flags, `--xla_gpu_enable_reduction_epilogue_fusion=false` and `--xla_gpu_enable_triton_gemm=false`, as they are essential for enabling the FP8 functionality. The additional flags primarily focus on performance enhancement and should also prove beneficial for non-FP8 executions. -### Transformer Engine vs Native FP8 Support + +## Transformer Engine vs Native FP8 Support Native XLA-FP8 specifically targets matrix multiplication operations. In contrast, the Transformer Engine focuses on enhancing the overall performance of the entire transformer layer. This encompasses not only the FP8 matrix multiplication but also attention mechanisms, layer normalizations, and other components. In practical terms, XLA-FP8 performs pattern matching and rewrites the matrix multiplication operations in the operation graph to utilize FP8 matrix multiplication. On the other hand, with TE, the [entire Praxis transformer](https://github.com/google/praxis/blob/main/praxis/layers/transformers.py) layer will be substituted with our [Transformer Engine layer](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/jax.html#transformer_engine.jax.flax.TransformerLayer), offering a comprehensive performance enhancement. + +## Guide for Ninja Users + +### Exact pattern that XLA can match for FP8 MatMul +The specific graph pattern that XLA supports for FP8 matmul is illustrated below: + +``` + convert -> multiply -> (x) -> dot -> (y) +broadcast -> + +# or + + convert -> (x) -> dot -> (y) +``` + +XLA will pattern match the above and rewrite it to FP8 matmul when: +1. `convert`: converts `f8` inputs to [`bf16`|`f16`|`f32`]. +2. `broadcast`: broadcasts a [`bf16`|`f16`|`f32`] scalar. The scalar will be used as the scaling factor of the inputs. Note, the `convert` and `broadcast` need to have the same output dtype. If `broadcast` (and `multiply`) is not provided, the scaling factor will be set to `1.`. +3. `(x)`: an arbitrary number of these allowed ops: +``` +Bitcast, Broadcast, Copy, DynamicSlice, Pad, Reshape, Select, Slice, Transpose, +AllGather, AllToAll, CollectivePermute +``` +4. `dot`: the inputs and outputs are both [`bf16`|`f16`|`f32`]. +5. `y`: supports epilog fusions, like `add` (optional). + +### Gradient accumulation of FP8 params +FP8 params, also known as `OverwriteWithGrad` params (or FP8 meta), may be shared across different iterations of a loop in the context of pipeline parallelism. During backpropagation, the autograd system accumulates their gradients from each iteration through the default addition operation. This is undesirable as addition is meaningless for FP8 params. + +To address this, we introduce a custom dtype wrapper `fp32_max_grad`. It tells the autograd system to perform the max operation for gradient accumulation. This aligns with our expectations for FP8 params. The basic usage is demonstrated below: + +```python +from flax.linen import fp8_ops +f32 = jnp.float32 +fmax32 = fp8_ops.fp32_max_grad +def outer(x, ah_f32, sf_f32): + # ah and sf are FP8 params and short for amax history and scaling factor + # respectively. + ah_fmax32 = jax.lax.convert_element_type(ah_f32, fmax32) + sf_fmax32 = jax.lax.convert_element_type(sf_f32, fmax32) + array_x = jnp.array([x], f32) + def body_fn(carry, _): + carry = fp8_ops.in_qdq(f32, carry, sf_fmax32, ah_fmax32) + return carry, None + array_x, _ = jax.lax.scan(body_fn, array_x, None, length=3) + return array_x[0] + +outer_fn = jax.grad(outer, (0, 1, 2)) +outer_fn = jax.jit(outer_fn) + +ah = jnp.array([0., 0., 0.], f32) +sf = jnp.array([1.], f32) +grads, new_ah, new_sf = outer_fn(2.0, ah, sf) +``` + +In the example, we convert the FP8 params from the original `f32` to `fp32_max_grad` before launching the scan loop so that the autograd can apply the correct grad accumulation between loop iterations. Inside each iteration (i.e. `body_fn`), we can operate them by, for example, calling `fp8_ops.in_qdq()` where internally they will be converted back to `f32` for general math operations (e.g. `mul`, `div`, etc.) and convert to `fp32_max_grad` at exit. + +