Skip to content

Commit

Permalink
Merge branch 'jax-collective-gemm' of github.com:denera/TransformerEngine into jax-collective-gemm
Browse files Browse the repository at this point in the history
  • Loading branch information
denera committed Nov 15, 2024
2 parents aa1600f + 46693fb commit 30b7b06
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 8 deletions.
6 changes: 2 additions & 4 deletions transformer_engine/jax/cpp_extensions/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
"gemm_impl",
]


def sanitize_dims(dim, ndims):
    """Normalize a possibly-negative axis index to its non-negative form.

    A negative ``dim`` counts from the end (NumPy/JAX convention) and is
    offset by ``ndims``; a non-negative ``dim`` is returned unchanged.
    """
    if dim < 0:
        return ndims + dim
    return dim

Expand Down Expand Up @@ -397,10 +398,7 @@ def batcher(
use_split_accumulator=use_split_accumulator,
)

return (
outputs,
(lhs_bdims, out_amax_bdims, out_scale_bdims, gelu_input_bdims, bias_bdims)
)
return (outputs, (lhs_bdims, out_amax_bdims, out_scale_bdims, gelu_input_bdims, bias_bdims))

@staticmethod
def infer_sharding_from_operands(
Expand Down
6 changes: 2 additions & 4 deletions transformer_engine/jax/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,7 @@ def _gemm_bwd_rule(
x = jnp.matrix_transpose(x)
batch_size = reduce(operator.mul, x.shape[:-2], 1)
x = jnp.reshape(x, (batch_size * x.shape[-2], x.shape[-1]))
wgrad_rhs = jnp.reshape(
wgrad_rhs, (batch_size * wgrad_rhs.shape[-2], wgrad_rhs.shape[-1])
)
wgrad_rhs = jnp.reshape(wgrad_rhs, (batch_size * wgrad_rhs.shape[-2], wgrad_rhs.shape[-1]))

# WGRAD: ([B], M, K)^T x ([B], M, N) = (K, N)
wgrad, _, bgrad = gemm_impl(
Expand Down Expand Up @@ -233,7 +231,7 @@ def _fp8_gemm_fwd_rule(
x_scale = scale_list[FP8MetaPackage.INPUT_IDX]
x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX]
if x.dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]:
casted_x, casted_x_t, updated_x_amax = cast_transpose(
casted_x, casted_x_t, updated_x_amax = cast_transpose(
x,
x_amax,
x_scale,
Expand Down

0 comments on commit 30b7b06

Please sign in to comment.