Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[C/JAX] Comm+GEMM Overlap API for TE/JAX #1337

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from

Conversation

denera
Copy link
Collaborator

@denera denera commented Nov 15, 2024

Description

>>> Depends on PR #1307 <<<

This PR implements JAX/XLA custom ops and primitives for comm+GEMM overlap kernels in TE/common, and the pure Python/JAX infrastructure required to bootstrap the functionality.

Current limitations and considerations:

  • Requires a distributed launch with 1 process per GPU and execution with jax.distributed.initialize(). JAX does not have its own distributed launch utility like torchrun, so this is typically done with mpirun launch + mpi4py in Python.
  • TE has to be compiled with NVTE_UB_WITH_MPI=1 and Userbuffers has to be bootstrapped with MPI because XLA custom ops cannot execute XLA collectives. Unlike PyTorch, this does not introduce a new dependency because distributed launch with JAX already depends on MPI.
  • Userbuffers communication buffers are allocated outside of the XLA memory pool. Since XLA has no knowledge of these allocations, its memory allowance as a % of total device memory needs to be decreased to avoid OOM issues.

To do:
[x] Implement XLA custom ops w/ both old API and new FFI interfaces.
[x] Extend JAX CollectiveGemmPrimitive to support comm+GEMM overlap.
[x] Implement bootstrapping and utility functions with PyBind11 bindings.
[x] Verify that comm+GEMM overlap extensions do not break non-overlap collective GEMM functionality.
[ ] Add new unit tests for comm+GEMM overlap.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refractor

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

denera and others added 9 commits November 14, 2024 09:30
Signed-off-by: Alp Dener <[email protected]>

Added XLA FFI custom op for TE GEMM

Signed-off-by: Alp Dener <[email protected]>

finished GEMM custom op primitive and serial unit test

Signed-off-by: Alp Dener <[email protected]>

fixed GEMM custom op batcher

Signed-off-by: Alp Dener <[email protected]>

fixed output dtype error and contracting dimensions options

Signed-off-by: Alp Dener <[email protected]>

AG overlap working but executes scatter to match outer LHS dim

Signed-off-by: Alp Dener <[email protected]>

both all-gather and all-reduce are now working

Signed-off-by: Alp Dener <[email protected]>

code style

Signed-off-by: Alp Dener <[email protected]>

changed kwargs in abstract to be explicit

Signed-off-by: Alp Dener <[email protected]>

added fwd/bwd implementation for non-fp8 gemm

Signed-off-by: Alp Dener <[email protected]>
@denera denera added enhancement New feature or request jax labels Nov 15, 2024
@denera denera self-assigned this Nov 15, 2024
@huanghua1994 huanghua1994 self-requested a review November 15, 2024 17:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request jax
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant