[PyTorch] Add heuristics for initializing FP8 params #1300
Description
#1127 changed the default behavior of `Float8Tensor` so that the FP8 transpose is no longer memoized, but rather treated as part of the data layout. This helps us avoid headaches where the cache is invalidated (since the FP8 data and transpose are updated together in `Float8Tensor.quantize_`) and reduces the FP8-specific logic in the TE modules (e.g. whether to update the transpose in the first microbatch). However, it also means that `Float8Tensor` needs to know whether to store the FP8 transpose at construction time.

This PR adds a `heuristic` kwarg to the `fp8_model_init` context. Right now the only supported value is `"memory"`, which will cause TE modules to initialize FP8 params without FP8 transposes. In the future, I can imagine extending this logic to `fp8_autocast` so it can handle FP8 data tensors. We could also support other heuristics (e.g. `"performance"`, `"communication"`, etc.) and add logic for other quantization schemes. However, I've purposely kept this PR small to avoid interfering with other changes in the quantization infrastructure (see #1251).

Type of change
Changes
`fp8_model_init` context

Checklist: