[C/JAX] Comm+GEMM Overlap API for TE/JAX #1337

Draft: wants to merge 10 commits into main

Commits on Nov 14, 2024

  1. added XLA custom op defs for TE GEMM

    Signed-off-by: Alp Dener <[email protected]>
    
    Added XLA FFI custom op for TE GEMM
    
    Signed-off-by: Alp Dener <[email protected]>
    
    finished GEMM custom op primitive and serial unit test
    
    Signed-off-by: Alp Dener <[email protected]>
    
    fixed GEMM custom op batcher
    
    Signed-off-by: Alp Dener <[email protected]>
    
    fixed output dtype error and contracting dimensions options
    
    Signed-off-by: Alp Dener <[email protected]>
    
    AG overlap working but executes scatter to match outer LHS dim
    
    Signed-off-by: Alp Dener <[email protected]>
    
    both all-gather and all-reduce are now working
    
    Signed-off-by: Alp Dener <[email protected]>
    
    code style
    
    Signed-off-by: Alp Dener <[email protected]>
    
    changed kwargs in abstract to be explicit
    
    Signed-off-by: Alp Dener <[email protected]>
    
    added fwd/bwd implementation for non-fp8 gemm
    
    Signed-off-by: Alp Dener <[email protected]>
    denera committed Nov 14, 2024
    941f5bb
  2. f440094
  3. 52af237

Commits on Nov 15, 2024

  1. fixed batching for collective GEMM FWD and BWD

    Signed-off-by: Alp Dener <[email protected]>
    denera committed Nov 15, 2024
    378721c
  2. 46693fb
  3. propagated batching fixes to fp8_gemm backward pass

    Signed-off-by: Alp Dener <[email protected]>
    denera committed Nov 15, 2024
    aa1600f
  4. Merge branch 'jax-collective-gemm' of github.com:denera/TransformerEngine into jax-collective-gemm
    denera committed Nov 15, 2024
    30b7b06
  5. added XLA custom ops and C++ infrastructure for comm+GEMM overlap in TE/JAX
    
    Signed-off-by: Alp Dener <[email protected]>
    denera committed Nov 15, 2024
    ad31fbc
  6. comm+GEMM overlap API for TE/JAX compiles, untested, but did not break collective GEMM op
    
    Signed-off-by: Alp Dener <[email protected]>
    denera committed Nov 15, 2024
    cf1dfa4
  7. c8c94e6