Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JAX] Collective GEMM custom op with nvte_cublas_gemm (no comm. overlap) #1307

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

denera
Copy link
Collaborator

@denera denera commented Nov 2, 2024

Description

Implements both old-style and new FFI-based XLA custom calls in C++, and the corresponding JAX primitive including custom partitioning rules.

Custom partitioning rules for a LHS:([B,] M, K) x RHS:([B,] K, N) = OUT:([B,] M, N) batched mat-mul operation where [B] is the batch dimension:

  • Preserve the partitioning of the [B] dimension for all operands.
  • Always all-gather LHS along the M dimension.
  • Error out if RHS is partitioned in both K and N dimensions.
  • Force the K dimension of LHS to match the partitioning of the K dimension of RHS.
  • If K dimension is partitioned but M dimension is not, jax.lax.psum (all-reduce) the output over the TP mesh resource.
  • If both the M and K dimensions are partitioned, jax.lax.psum_scatter (reduce-scatter) the output over the TP mesh resource.

In practice, the RHS matrix (typically the weight tensor) should be allocated with transposed contracting dimensions ([B,] N, K) for optimal GEMM heuristics in cuBlasLt. This layout is also mandatory for FP8 inputs.

This PR does NOT update fused ops or Flax/Praxis modules to use the new GEMM custom op over the existing XLA pattern matching approach.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refractor

Changes

  • Added XLA custom calls for nvte_cublas_gemm.
  • Added JAX primitive for the new XLA custom call.
  • Added new serial unit test.
  • Add distributed unit test.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@denera denera added the jax label Nov 2, 2024
@denera denera requested review from nouiz and phu0ngng November 2, 2024 02:30
@denera denera self-assigned this Nov 2, 2024
@denera denera changed the title [JAX] Collective GEMM custom op with nvte_cublas_gemm [JAX] Collective GEMM custom op with nvte_cublas_gemm (no comm. overlap) Nov 2, 2024
@nouiz
Copy link
Collaborator

nouiz commented Nov 4, 2024

Why? Normal JAX behavior is to do some gathering.

@huanghua1994
Copy link
Collaborator

It seems that currently the batch size is not handled in the C++ code. Since JAX is using row-major storage for tensor by default, probably the batch dimension should be combined with the m dimension for LHS or the n dimension for RHS?

Signed-off-by: Alp Dener <[email protected]>

Added XLA FFI custom op for TE GEMM

Signed-off-by: Alp Dener <[email protected]>

finished GEMM custom op primitive and serial unit test

Signed-off-by: Alp Dener <[email protected]>

fixed GEMM custom op batcher

Signed-off-by: Alp Dener <[email protected]>

fixed output dtype error and contracting dimensions options

Signed-off-by: Alp Dener <[email protected]>

AG overlap working but executes scatter to match outer LHS dim

Signed-off-by: Alp Dener <[email protected]>

both all-gather and all-reduce are now working

Signed-off-by: Alp Dener <[email protected]>

code style

Signed-off-by: Alp Dener <[email protected]>

changed kwargs in abstract to be explicit

Signed-off-by: Alp Dener <[email protected]>

added fwd/bwd implementation for non-fp8 gemm

Signed-off-by: Alp Dener <[email protected]>
Copy link

@abhinavgoel95 abhinavgoel95 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@denera I have some questions about the PR.


# Validate operand layouts
lhs_inner_dim, rhs_inner_dim = map(
lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@denera should be ndims + inner_dim when inner_dim is negative, right?

rhs_trans = contracting_dims[1] == rhs.ndim - 1
lhs = jnp.matrix_transpose(lhs) if lhs_trans and jax_dtype_is_fp8(lhs.dtype) else lhs
rhs = jnp.matrix_transpose(rhs) if not rhs_trans and jax_dtype_is_fp8(rhs.dtype) else rhs
contracting_dims = (1, 1)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@denera is there a need to hard-code this?

Copy link
Collaborator Author

@denera denera Nov 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cuBlasLt GEMM requires non-transposed LHS and transposed RHS for FP8 GEMM, but the batcher is not the right place to check/force that. Also, leaving contracting_dims=(1, 1) out of the conditional for FP8 type is a mistake. Thanks for catching it!

grad=grad,
accumulate=accumulate,
use_split_accumulator=use_split_accumulator,
)(lhs_bdims, out_amax_bdims, out_scale_bdims, gelu_input_bdims, bias_bdims)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants