
[pyTorch] Infrastructure for C++ QuantizedTensor #1251

Draft · ptrendx wants to merge 6 commits into main

Conversation

@ptrendx (Member) commented Oct 14, 2024

Description

The goal of this PR is to bring the QuantizedTensor (e.g. Float8Tensor) closer to the C++ level in order to minimize overheads, while keeping its functionality and Pythonic nature intact. It also includes changes to the GEMM call (which, once closer to completion, will most likely be split out into a separate PR) in order to test the functionality.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Comment on lines 30 to 43
def general_gemm(
    A: Union[torch.Tensor, Float8Tensor],
    B: Union[torch.Tensor, Float8Tensor],
    workspace: torch.Tensor,
    gelu: bool = False,
    accumulate: bool = False,
    out: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    use_split_accumulator: bool = False,
    D_dtype: Optional[tex.DType] = None,
    ub_algo: tex.UbufOverlapAlgo = None,
    ub: Union[tex.UbufCommOverlap, tex.UbufP2PCommOverlap] = None,
    extra_output_tensor: Optional[torch.Tensor] = None,
) -> torch.Tensor:
@timmoon10 (Collaborator) commented Oct 16, 2024

While we're reworking this API, perhaps we should call it matmul since it's less ambiguous (e.g. with column-major/row-major order). We should also keep the core API simple like torch.matmul and np.matmul, and leave our non-standard options as kwargs:

def matmul(
    A: torch.Tensor,  # maybe QuantizedTensor
    B: torch.Tensor,  # maybe QuantizedTensor
    /,
    out: Optional[torch.Tensor] = None,  # maybe QuantizedTensor
    *,
    transa: bool = False,
    transb: bool = False,
    out_dtype: Optional[tex.DType] = None,
    accumulate_out: bool = False,  # alternatively: alpha and beta
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,  # more general than gelu
    workspace: torch.Tensor,  # maybe allocate in C++ if not provided
    use_split_accumulator: bool = False,  # maybe hide within cublas_options kwarg?
    userbuffers_options: Optional[dict] = None,  # minimize impact of unstable UB API
) -> torch.Tensor:  # maybe QuantizedTensor
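
For concreteness, a hedged sketch of what a call to this proposed signature might look like; the matmul name, the transa/out_dtype/activation/userbuffers_options arguments, and the placeholder variables (weight, inp, bias, workspace) all come from the proposal above and are not an existing API:

# Hypothetical call to the proposed matmul wrapper; argument names follow the
# signature sketched above and do not exist in the current codebase.
out = matmul(
    weight,                    # may be a plain torch.Tensor or a QuantizedTensor
    inp,                       # likewise
    transa=True,               # cuBLAS-style transpose flag for the first operand
    out_dtype=tex.DType.kBFloat16,
    bias=bias,
    activation="gelu",         # replaces the boolean gelu flag
    workspace=workspace,
    userbuffers_options=None,  # keeps the unstable UB API out of the core signature
)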

@timmoon10 (Collaborator) commented Oct 23, 2024

I think we could clean up the API for the quantization params. As I understand it, we currently have three classes:

  • QuantizationParams: Holds state needed for quantization (scale, amax). It is used in C++ functions.
  • QMeta: Builder class for QuantizedTensor that holds an FP8 recipe. It mainly constructs a default QuantizationParams and calls QuantizedTensor.quantize.
  • QuantizationParamsProxy: Holds fp8_meta and converts it to QuantizationParams.

I propose unifying these APIs into a builder class:

class Quantizer(abc.ABC):
    @abc.abstractmethod
    def quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
        ...

class Float8Quantizer(Quantizer):
    def __init__(self, scale, amax, fp8_dtype, dtype, roles, ...):
        self.scale = scale
        ...

    def quantize(self, tensor):
        return tex.cast(tensor, self)

    # Maybe a direction for the future
    # def update_(self, recipe):
    #     tex.fused_amax_and_scale_update_after_reduction(...)

class FP8MetaQuantizer(Quantizer):
    """Proxy of Float8Quantizer that support fp8_meta dicts"""
    def __init__(self, fp8_meta, fp8_meta_index, ...):
        ...

    def quantize(self, tensor):
        quantizer = Float8Quantizer(self.fp8_meta[...].scale, ...)
        return quantizer.quantize(tensor)

My thought is for Quantizer to replace fp8_meta over time. Instead of modules holding a complicated dict with random tensors, it can just hold a few Quantizers (for input, params, grad output, etc). Instead of fp8_autocast needing to dig through the fp8_metas to update amax_history and scale, we can encapsulate it in an update_ function. If we go in this direction, we'll eventually be able to get rid of the fp8_metas and FP8MetaQuantizer entirely.
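
To make that direction concrete, a minimal sketch of a module holding Quantizer objects instead of an fp8_meta dict, assuming the Quantizer/Float8Quantizer classes proposed above; the LinearLike module, its attribute names, and update_ are hypothetical, not existing APIs:

# Hypothetical sketch of a module holding Quantizer objects instead of an
# fp8_meta dict, following the proposal above. LinearLike, input_quantizer,
# weight_quantizer, and update_ are illustrative names only.
import torch

class LinearLike(torch.nn.Module):
    def __init__(self, in_features, out_features, input_quantizer, weight_quantizer):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features))
        # Each tensor role gets its own Quantizer instead of an fp8_meta entry.
        self.input_quantizer = input_quantizer
        self.weight_quantizer = weight_quantizer

    def forward(self, inp):
        # Quantization state (scale, amax) lives inside the Quantizer objects,
        # so the module never touches raw scaling tensors directly.
        inp_q = self.input_quantizer.quantize(inp)
        weight_q = self.weight_quantizer.quantize(self.weight)
        # The actual GEMM would go through the reworked matmul/general_gemm API;
        # plain matmul stands in here to keep the sketch self-contained.
        return torch.matmul(inp_q, weight_q.t())

# fp8_autocast could then update scaling factors without digging through dicts:
#     for quantizer in (module.input_quantizer, module.weight_quantizer):
#         quantizer.update_(recipe)  # hypothetical future direction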
