[JAX] Collective GEMM custom op with nvte_cublas_gemm (no comm. overlap) #1307
base: main
Conversation
Why? Normal JAX behavior is to do some gathering.
It seems that the batch size is currently not handled in the C++ code. Since JAX uses row-major storage for tensors by default, the batch dimension should probably be combined with the …
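For context, here is a minimal sketch (not the PR's code) of what folding a leading batch dimension into the row dimension of a row-major GEMM could look like; it assumes the RHS is unbatched, and all names are illustrative.

```python
import jax.numpy as jnp

def fold_batch_into_rows(lhs, rhs):
    """Collapse LHS:(B, M, K) into (B*M, K) so one 2D GEMM covers the batch.

    Works only when RHS:(K, N) is shared across the batch; with row-major
    storage the (B, M) axes are already contiguous, so the reshape is cheap.
    """
    b, m, k = lhs.shape
    out_2d = jnp.matmul(lhs.reshape(b * m, k), rhs)  # (B*M, N)
    return out_2d.reshape(b, m, -1)                  # back to (B, M, N)
```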
Force-pushed from bb2be56 to fea0728
Commit messages (each signed off by Alp Dener <[email protected]>):
- Added XLA FFI custom op for TE GEMM
- finished GEMM custom op primitive and serial unit test
- fixed GEMM custom op batcher
- fixed output dtype error and contracting dimensions options
- AG overlap working but executes scatter to match outer LHS dim
- both all-gather and all-reduce are now working
- code style
- changed kwargs in abstract to be explicit
- added fwd/bwd implementation for non-fp8 gemm
Force-pushed from 3ec3eca to 941f5bb
Signed-off-by: Alp Dener <[email protected]>
Force-pushed from 6444211 to f440094
for more information, see https://pre-commit.ci
@denera I have some questions about the PR.
# Validate operand layouts
lhs_inner_dim, rhs_inner_dim = map(
    lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim,
@denera It should be ndims + inner_dim when inner_dim is negative, right?
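For illustration, a small sketch of the suggested correction (not the PR's exact code): a negative axis index is normalized by adding the rank, not subtracting from it.

```python
# Sketch of the suggested fix; the helper name is illustrative.
def normalize_inner_dim(inner_dim: int, ndims: int) -> int:
    # ndims + inner_dim recovers the non-negative index for negative inputs,
    # e.g. inner_dim=-1 with ndims=3 gives 2 (ndims - inner_dim would give 4).
    return ndims + inner_dim if inner_dim < 0 else inner_dim

assert normalize_inner_dim(-1, 3) == 2
assert normalize_inner_dim(1, 3) == 1
```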
rhs_trans = contracting_dims[1] == rhs.ndim - 1
lhs = jnp.matrix_transpose(lhs) if lhs_trans and jax_dtype_is_fp8(lhs.dtype) else lhs
rhs = jnp.matrix_transpose(rhs) if not rhs_trans and jax_dtype_is_fp8(rhs.dtype) else rhs
contracting_dims = (1, 1)
@denera is there a need to hard-code this?
cuBlasLt GEMM requires a non-transposed LHS and a transposed RHS for FP8 GEMM, but the batcher is not the right place to check/force that. Also, leaving contracting_dims=(1, 1) out of the conditional for the FP8 type is a mistake. Thanks for catching it!
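As a rough, hedged sketch of that requirement (this is not the PR's implementation, and the helper name is made up), a frontend could put FP8 operands into the layout cuBLASLt expects, i.e. both operands contracting over their last axis, so the LHS stays non-transposed (M, K) and the RHS is stored transposed (N, K).

```python
import jax.numpy as jnp

def to_fp8_gemm_layout(lhs, rhs, contracting_dims):
    """Hypothetical helper: force LHS:(..., M, K) and RHS:(..., N, K) so both
    operands contract over their last axis, as cuBLASLt requires for FP8.
    Assumes the contracting axis is one of the trailing two axes."""
    lhs_cdim, rhs_cdim = contracting_dims
    if lhs_cdim != lhs.ndim - 1:
        lhs = jnp.matrix_transpose(lhs)  # move the contracting axis last
    if rhs_cdim != rhs.ndim - 1:
        rhs = jnp.matrix_transpose(rhs)  # store RHS as (..., N, K)
    return lhs, rhs, (lhs.ndim - 1, rhs.ndim - 1)
```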
    grad=grad,
    accumulate=accumulate,
    use_split_accumulator=use_split_accumulator,
)(lhs_bdims, out_amax_bdims, out_scale_bdims, gelu_input_bdims, bias_bdims)
This gives me an error. The lines at https://github.com/NVIDIA/TransformerEngine/pull/1307/files#diff-f5b74ca3c5a70acb3d764e9b8adea40b8bab554fe4d2362f3052b7b932c0464dR187-R194 return a tuple.
TypeError: 'list' object is not callable
cc @denera
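A minimal, hypothetical repro of that failure mode (names are illustrative, not the PR's code): if the helper referenced above returns a list of batch dims rather than a callable, applying it as in the quoted snippet raises exactly this error.

```python
def resolve_bdims(*bdims):
    # Returns a list of resolved batch dims -- it is data, not a callable.
    return [0 if b is None else b for b in bdims]

result = resolve_bdims(0, None, None)
# result(0, 1, 2)  # would raise: TypeError: 'list' object is not callable
```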
… passing test Signed-off-by: Alp Dener <[email protected]>
Force-pushed from 30b7b06 to b989641
for more information, see https://pre-commit.ci
Description
Implements both old-style and new FFI-based XLA custom calls in C++, and the corresponding JAX primitive including custom partitioning rules.
Custom partitioning rules for a LHS:([B,] M, K) x RHS:([B,] K, N) = OUT:([B,] M, N) batched mat-mul operation, where [B] is the batch dimension:
- the [B] dimension for all operands;
- the M dimension;
- the K and N dimensions;
- the K dimension of LHS to match the partitioning of the K dimension of RHS;
- if the K dimension is partitioned but the M dimension is not, jax.lax.psum (all-reduce) the output over the TP mesh resource;
- if both the M and K dimensions are partitioned, jax.lax.psum_scatter (reduce-scatter) the output over the TP mesh resource (see the sketch below).

In practice, the RHS matrix (typically the weight tensor) should be allocated with transposed contracting dimensions ([B,] N, K) for optimal GEMM heuristics in cuBlasLt. This layout is also mandatory for FP8 inputs.

This PR does NOT update fused ops or Flax/Praxis modules to use the new GEMM custom op over the existing XLA pattern-matching approach.
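For illustration only, here is a rough sketch (not the PR's implementation) of the all-reduce rule above using shard_map; the mesh and the "tp" axis name are made up for the example. With LHS and RHS both sharded along K, each device computes a partial (M, N) product that is then psum-ed over the tensor-parallel axis; sharding M as well would turn the psum into a psum_scatter.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Hypothetical 1-D tensor-parallel mesh; the "tp" axis name is illustrative.
mesh = Mesh(jax.devices(), axis_names=("tp",))

def k_partitioned_gemm(lhs_local, rhs_local):
    # Each device holds LHS:(M, K/tp) and RHS:(K/tp, N); the local matmul
    # yields a partial (M, N) result that must be reduced across "tp".
    partial_out = jnp.matmul(lhs_local, rhs_local)
    return jax.lax.psum(partial_out, axis_name="tp")

sharded_gemm = shard_map(
    k_partitioned_gemm,
    mesh=mesh,
    in_specs=(P(None, "tp"), P("tp", None)),  # shard only the K dimension
    out_specs=P(None, None),                  # output replicated after psum
)
```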
Type of change
Changes
nvte_cublas_gemm.
Checklist: