-
Notifications
You must be signed in to change notification settings - Fork 327
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[C/JAX] Comm+GEMM Overlap API for TE/JAX #1337
base: main
Are you sure you want to change the base?
Commits on Nov 14, 2024
-
added XLA custom op defs for TE GEMM
Signed-off-by: Alp Dener <[email protected]> Added XLA FFI custom op for TE GEMM Signed-off-by: Alp Dener <[email protected]> finished GEMM custom op primitive and serial unit test Signed-off-by: Alp Dener <[email protected]> fixed GEMM custom op batcher Signed-off-by: Alp Dener <[email protected]> fixed output dtype error and contracting dimensions options Signed-off-by: Alp Dener <[email protected]> AG overlap working but executes scatter to match outer LHS dim Signed-off-by: Alp Dener <[email protected]> both all-gather and all-reduce are now working Signed-off-by: Alp Dener <[email protected]> code style Signed-off-by: Alp Dener <[email protected]> changed kwargs in abstract to be explicit Signed-off-by: Alp Dener <[email protected]> added fwd/bwd implementation for non-fp8 gemm Signed-off-by: Alp Dener <[email protected]>
Configuration menu - View commit details
-
Copy full SHA for 941f5bb - Browse repository at this point
Copy the full SHA 941f5bbView commit details -
fixed batching rules to accommodated batched RHS operand for GEMM
Signed-off-by: Alp Dener <[email protected]>
Configuration menu - View commit details
-
Copy full SHA for f440094 - Browse repository at this point
Copy the full SHA f440094View commit details -
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
Configuration menu - View commit details
-
Copy full SHA for 52af237 - Browse repository at this point
Copy the full SHA 52af237View commit details
Commits on Nov 15, 2024
-
fixed batching for collective GEMM FWD and BWD
Signed-off-by: Alp Dener <[email protected]>
Configuration menu - View commit details
-
Copy full SHA for 378721c - Browse repository at this point
Copy the full SHA 378721cView commit details -
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
Configuration menu - View commit details
-
Copy full SHA for 46693fb - Browse repository at this point
Copy the full SHA 46693fbView commit details -
propagated batching fixes to fp8_gemm backward pass
Signed-off-by: Alp Dener <[email protected]>
Configuration menu - View commit details
-
Copy full SHA for aa1600f - Browse repository at this point
Copy the full SHA aa1600fView commit details -
Merge branch 'jax-collective-gemm' of github.com:denera/TransformerEn…
…gine into jax-collective-gemm
Configuration menu - View commit details
-
Copy full SHA for 30b7b06 - Browse repository at this point
Copy the full SHA 30b7b06View commit details -
added XLA custom ops and C++ infrastructure for comm+GEMM overlap in …
…TE/JAX Signed-off-by: Alp Dener <[email protected]>
Configuration menu - View commit details
-
Copy full SHA for ad31fbc - Browse repository at this point
Copy the full SHA ad31fbcView commit details -
comm+GEMM overlap API for TE/JAX compiles, untested, but did not brea…
…k collective GEMM op Signed-off-by: Alp Dener <[email protected]>
Configuration menu - View commit details
-
Copy full SHA for cf1dfa4 - Browse repository at this point
Copy the full SHA cf1dfa4View commit details -
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
Configuration menu - View commit details
-
Copy full SHA for c8c94e6 - Browse repository at this point
Copy the full SHA c8c94e6View commit details