[pyTorch] Infrastructure for C++ QuantizedTensor #1251
base: main
Conversation
def general_gemm(
    A: Union[torch.Tensor, Float8Tensor],
    B: Union[torch.Tensor, Float8Tensor],
    workspace: torch.Tensor,
    gelu: bool = False,
    accumulate: bool = False,
    out: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    use_split_accumulator: bool = False,
    D_dtype: Optional[tex.DType] = None,
    ub_algo: tex.UbufOverlapAlgo = None,
    ub: Union[tex.UbufCommOverlap, tex.UbufP2PCommOverlap] = None,
    extra_output_tensor: Optional[torch.Tensor] = None,
) -> torch.Tensor:
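For context, a minimal sketch of how this signature could be invoked; the tensor shapes, device, workspace size, and argument choices below are illustrative assumptions, not values taken from this PR:

import torch

# Illustrative call only; workspace sizing and tensor shapes are assumptions.
A = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
B = torch.randn(256, 64, device="cuda", dtype=torch.bfloat16)
workspace = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device="cuda")

out = general_gemm(
    A,
    B,
    workspace,
    accumulate=False,
    use_split_accumulator=True,  # often enabled for FP8 inputs to improve accumulation accuracy
)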
While we're reworking this API, perhaps we should call it matmul, since it's less ambiguous (e.g. with column-major/row-major order). We should also keep the core API simple, like torch.matmul and np.matmul, and leave our non-standard options as kwargs:
def matmul(
    A: torch.Tensor,  # maybe QuantizedTensor
    B: torch.Tensor,  # maybe QuantizedTensor
    /,
    out: Optional[torch.Tensor] = None,  # maybe QuantizedTensor
    *,
    transa: bool = False,
    transb: bool = False,
    out_dtype: Optional[tex.DType] = None,
    accumulate_out: bool = False,  # alternatively: alpha and beta
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,  # more general than gelu
    workspace: torch.Tensor,  # maybe allocate in C++ if not provided
    use_split_accumulator: bool = False,  # maybe hide within cublas_options kwarg?
    userbuffers_options: Optional[dict] = None,  # minimize impact of unstable UB API
) -> torch.Tensor:  # maybe QuantizedTensor
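As a usage follow-up, a hypothetical call against this proposed signature might look like the following; the tensor names and argument values are assumptions for illustration, not part of the proposal:

# Hypothetical usage of the proposed matmul(); names and values are illustrative.
out = matmul(
    A,                    # possibly a QuantizedTensor
    B,                    # possibly a QuantizedTensor
    transb=True,          # treat B as transposed, instead of a separate transpose kernel
    bias=bias,
    activation="gelu",    # replaces the boolean gelu flag of general_gemm
    workspace=workspace,  # caller-provided workspace, as in the current API
    use_split_accumulator=True,
)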
I think we could clean up the API for the quantization params. As I understand it, we currently have three classes:
- QuantizationParams: Holds the state needed for quantization (scale, amax). It is used in C++ functions.
- QMeta: Builder class for QuantizedTensor that holds an FP8 recipe. It mainly constructs a default QuantizationParams and calls QuantizedTensor.quantize.
- QuantizationParamsProxy: Holds fp8_meta and converts it to QuantizationParams.
I propose unifying these APIs into a builder class:
class Quantizer(abc.ABC):
    @abc.abstractmethod
    def quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
        ...

class Float8Quantizer(Quantizer):
    def __init__(self, scale, amax, fp8_dtype, dtype, roles, ...):
        self.scale = scale
        ...
    def quantize(self, tensor):
        return tex.cast(tensor, self)
    # Maybe a direction for the future
    # def update_(self, recipe):
    #     tex.fused_amax_and_scale_update_after_reduction(...)

class FP8MetaQuantizer(Quantizer):
    """Proxy of Float8Quantizer that supports fp8_meta dicts"""
    def __init__(self, fp8_meta, fp8_meta_index, ...):
        ...
    def quantize(self, tensor):
        quantizer = Float8Quantizer(fp8_meta[...].scale, ...)
        return quantizer.quantize(tensor)
My thought is for Quantizer to replace fp8_meta over time. Instead of modules holding a complicated dict with random tensors, they can just hold a few Quantizers (for input, params, grad output, etc.). Instead of fp8_autocast needing to dig through the fp8_metas to update amax_history and scale, we can encapsulate that in an update_ function. If we go in this direction, we'll eventually be able to get rid of the fp8_metas and FP8MetaQuantizer entirely.
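To make that direction concrete, here is a hedged sketch of what a module holding Quantizers (instead of an fp8_meta dict) could look like; the class and attribute names (LinearLike, input_quantizer, weight_quantizer) are hypothetical and only illustrate the idea:

import torch

# Sketch only: LinearLike and the quantizer attribute names are hypothetical.
class LinearLike(torch.nn.Module):
    def __init__(self, weight, input_quantizer: Quantizer, weight_quantizer: Quantizer):
        super().__init__()
        self.weight = weight
        self.input_quantizer = input_quantizer    # e.g. a Float8Quantizer
        self.weight_quantizer = weight_quantizer  # e.g. a Float8Quantizer

    def forward(self, inp):
        q_inp = self.input_quantizer.quantize(inp)
        q_weight = self.weight_quantizer.quantize(self.weight)
        # A GEMM on the quantized tensors would follow here (see general_gemm above).
        ...

# fp8_autocast could then call something like quantizer.update_(recipe) on each
# held Quantizer instead of digging through fp8_meta dicts.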
Description
The goal of this PR is to bring the QuantizedTensor (e.g. Float8Tensor) closer to the C++ level in order to minimize overheads, while keeping the same functionality and Pythonic nature. It also includes changes to the GEMM call in order to test the functionality (once closer to completion, these will most likely be split out into a separate PR).
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: