fixed GEMM custom op batcher
Signed-off-by: Alp Dener <[email protected]>
denera committed Nov 6, 2024
1 parent 6277d22 commit fea0728
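This commit replaces the GEMM custom op's string layout argument ('NN'/'NT'/'TN') with dimension_numbers, a (contracting_dims, batch_dims) pair in the style of jax.lax.dot_general, and rebuilds the batching rule around it. As a minimal sketch of how the public wrappers in this file derive dimension_numbers from a user-supplied contracting_dims (the example values below are illustrative, not taken from the diff):

# Mirrors the logic added to gemm()/fp8_gemm() in this diff.
lhs_ndim, rhs_ndim = 2, 2
contracting_dims = ((1,), (0,))  # default: plain (M, K) x (K, N) matmul

lhs_batch_dims = tuple(dim for dim in range(lhs_ndim) if dim not in contracting_dims[0])
rhs_batch_dims = tuple(dim for dim in range(rhs_ndim) if dim not in contracting_dims[1])
dimension_numbers = (contracting_dims, (lhs_batch_dims, rhs_batch_dims))
# -> (((1,), (0,)), ((0,), (1,)))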
Showing 1 changed file with 114 additions and 100 deletions.
214 changes: 114 additions & 100 deletions transformer_engine/jax/cpp_extensions/gemm.py
@@ -4,8 +4,8 @@
"""JAX/TE custom ops for cuBlasLt GEMM"""
import warnings
import operator
from functools import reduce
from typing import Optional, Tuple
import math
from typing import Optional, Tuple, Sequence

import jax
import jax.numpy as jnp
@@ -20,7 +20,6 @@
from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import (
check_valid_batch_dims,
jax_dtype_to_te_dtype,
get_padded_spec,
is_ffi_enabled,
Expand Down Expand Up @@ -66,7 +65,7 @@ def abstract(
out_amax_aval: ArrayLike,
out_scale_aval: ArrayLike,
out_dtype: jnp.dtype,
layout: str,
dimension_numbers: Tuple[Tuple[Tuple[Sequence[int], Sequence[int]], ...], ...],
do_gelu: bool,
use_bias: bool,
grad: bool,
@@ -96,23 +95,20 @@ def abstract(
), "Missing FP8 meta!"

# Validate input layouts
contracting_dims, batched_dims = dimension_numbers
(lhs_inner_dim, ), (rhs_inner_dim, ) = contracting_dims
lhs_bdims, _ = batched_dims
if is_fp8:
assert layout == 'NT', "FP8 GEMM custom op only supports 'NT' layout!"
assert lhs_inner_dim == lhs_aval.ndim - 1, "FP8 GEMM does not support transposed LHS."
assert rhs_inner_dim == -1, "FP8 GEMM requires transposed RHS."
rhs_trans = True
else:
assert layout in ['NN', 'NT', 'TN'], "Invalid GEMM layout!"
lhs_trans = layout[0] == 'T'
rhs_trans = layout[1] == 'T'
lhs_outer_idx = -2 if lhs_trans else -1
lhs_inner_idx = -1 if lhs_trans else -2
rhs_outer_idx = -1 if rhs_trans else -2
rhs_inner_idx = -2 if rhs_trans else -1
rhs_trans = (rhs_inner_dim == -1)

assert (
lhs_aval.shape[lhs_inner_idx] == rhs_aval.shape[rhs_inner_idx]
lhs_aval.shape[lhs_inner_dim] == rhs_aval.shape[rhs_inner_dim]
), "Incompatible operand sizes!"
assert all([
lhs_batch == rhs_batch for lhs_batch, rhs_batch \
in zip(lhs_aval.shape[:-2], rhs_aval.shape[:-2])
]), "Incompatible batch sizes!"
assert rhs_aval.ndim == 2, "TE/JAX GEMM does not support batched RHS operand."

# Validate output dtype
out_dtype = dtypes.canonicalize_dtype(out_dtype)
@@ -133,12 +129,10 @@ def abstract(
out_dtype = lhs_dtype

# Infer output size and create abstract arrays
out_shape = (
*lhs_aval.shape[:-2],
lhs_aval.shape[lhs_outer_idx],
rhs_aval.shape[rhs_outer_idx]
)
out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
batched_shape = [lhs_aval.shape[bdim] for bdim in lhs_bdims]
projected_size = rhs_aval.shape[0 if rhs_trans else -1]
out_shape = (*batched_shape, projected_size)
out_aval = lhs_aval.update(shape=out_shape, dtype=out_dtype)
out_amax_updated_aval = out_amax_aval.update(shape=out_amax_aval.shape, dtype=jnp.float32)
out_scale_updated_aval = out_scale_aval.update(shape=out_scale_aval.shape, dtype=jnp.float32)

@@ -181,7 +175,7 @@ def lowering(
out_amax: ArrayLike,
out_scale: ArrayLike,
out_dtype: jnp.dtype,
layout: str,
dimension_numbers: Tuple[Tuple[Sequence[int], Sequence[int]], ...],
do_gelu: bool,
use_bias: bool,
grad: bool,
@@ -191,12 +185,13 @@ def lowering(
"""
GEMM fwd lowering rules
"""
del do_gelu, use_bias
del do_gelu, use_bias, dimension_numbers

lhs_trans = layout[0] == 'T'
rhs_trans = layout[1] == 'T'
workspace_size = get_cublas_workspace_size_bytes()
# Batched always reshapes into LHS:([B], M, K) x RHS^T:(N, K) = OUT:([B], M, N)
lhs_trans = False
rhs_trans = True

# Call underlying custom op with flipped LHS/RHS to account for cuBlasLt column-major format
if is_ffi_enabled():
name = "te_gemm_ffi"
return ffi.ffi_lowering(name, operand_output_aliases={5: 1, 6: 2})(
@@ -231,20 +226,16 @@ def lowering(
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

# LHS:([B], M, K) x RHS:([B], K, N) = OUT:([B], M, N)
lhs_aval, _, rhs_aval, _, bias_aval, *_ = ctx.avals_in
lhs_outer_idx = -2 if lhs_trans else -1
lhs_inner_idx = -1 if lhs_trans else -2
rhs_outer_idx = -1 if rhs_trans else -2
batch = reduce(lhs_aval.shape[:-2], operator.mul, 1.) if lhs_aval.ndim > 2 else 1
m = lhs_aval.shape[lhs_outer_idx]
n = lhs_aval.shape[lhs_inner_idx]
k = rhs_aval.shape[rhs_outer_idx]
m = lhs_aval.shape[0]
n = rhs_aval.shape[0]
k = rhs_aval.shape[-1]
workspace_size = get_cublas_workspace_size_bytes()
operand_dtype = jax_dtype_to_te_dtype(lhs_aval.dtype)
bias_dtype = jax_dtype_to_te_dtype(bias_aval.dtype)
opaque = tex.pack_gemm_descriptor(batch, m, n, k, workspace_size, operand_dtype,
opaque = tex.pack_gemm_descriptor(m, n, k, workspace_size, operand_dtype,
jax_dtype_to_te_dtype(out_dtype), bias_dtype,
lhs_trans, rhs_trans, grad, accumulate,
lhs_trans, rhs_trans, grad, accumulate,
use_split_accumulator)

return custom_caller(GemmPrimitive.name, args, opaque, has_side_effect=False)
@@ -259,7 +250,7 @@ def impl(
out_amax: ArrayLike,
out_scale: ArrayLike,
out_dtype: jnp.dtype,
layout: str,
dimension_numbers: Tuple[Tuple[Sequence[int], Sequence[int]], ...],
do_gelu: bool,
use_bias: bool,
grad: bool,
@@ -278,7 +269,7 @@ def impl(
out_amax,
out_scale,
out_dtype,
layout,
dimension_numbers,
do_gelu,
use_bias,
grad,
@@ -289,79 +280,89 @@ def impl(
return output, out_amax_updated, out_scale_updated, pre_gelu_out

@staticmethod
def batcher(batched_args, batch_dims, out_dtype, layout, do_gelu, use_bias, grad, accumulate,
use_split_accumulator):
def batcher(batched_args, batch_dims, out_dtype, dimension_numbers, do_gelu, use_bias, grad,
accumulate, use_split_accumulator):
assert GemmPrimitive.outer_primitive is not None
check_valid_batch_dims(batch_dims)
_, _, b_bdim, _, amax_bdim, scale_bdim, _ = batch_dims

out_bdims = b_bdim, amax_bdim, scale_bdim, b_bdim
return (
GemmPrimitive.outer_primitive.bind(*batched_args, out_dtype, layout, do_gelu, use_bias,
grad, accumulate, use_split_accumulator),
out_bdims,
)
lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, out_amax, out_scale = batched_args
assert rhs.ndim == 2, "TE/JAX GEMM custom op does not support batching RHS operands."

# Get contracting and batch dimensions out
contracting_dims, _ = dimension_numbers
(lhs_inner_dim, ), (rhs_inner_dim, ) = contracting_dims
lhs_bdims, _, _, _, _, amax_bdims, scale_bdims = batch_dims

# Move the contracting dimension to the last position (if necessary)
# This means we always execute `nvte_cublas_gemm` with lhs_trans = False and
# rhs_trans = True
if lhs_inner_dim != lhs.ndim - 1:
lhs = jnp.moveaxis(lhs, lhs_inner_dim, -1)
if rhs_inner_dim != rhs.ndim - 1:
rhs = jnp.moveaxis(rhs, rhs_inner_dim, -1)

# Collapse all non-contracting dimensions
lhs_batch_shape = lhs.shape[:-1]
lhs = jnp.reshape(lhs, (math.prod(lhs_batch_shape), lhs.shape[-1]))

outputs = GemmPrimitive.outer_primitive.bind(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias,
out_amax, out_scale, out_dtype, do_gelu,
use_bias, grad, accumulate,
use_split_accumulator)

# Reshape output to recover original LHS batch shape
outputs[0] = jnp.reshape(outputs[0], (*lhs_batch_shape, rhs.shape[0]))
if outputs[3].size > 0:
outputs[3] = jnp.reshape(outputs[3], outputs[0].shape)
return outputs, (lhs_bdims, amax_bdims, scale_bdims, lhs_bdims)

@staticmethod
def infer_sharding_from_operands(out_dtype, layout, do_gelu, use_bias, grad, accumulate,
use_split_accumulator, mesh, arg_infos, result_infos):
def infer_sharding_from_operands(out_dtype, dimension_numbers, do_gelu, use_bias, grad,
accumulate, use_split_accumulator, mesh, arg_infos,
result_infos):
del out_dtype, do_gelu, use_bias, grad, accumulate, use_split_accumulator, result_infos
contracting_dims, batched_dims = dimension_numbers
(lhs_inner_dim, ), (rhs_inner_dim, ) = contracting_dims
lhs_bdims, _ = batched_dims

lhs_spec = get_padded_spec(arg_infos[0])
batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims]
rhs_spec = get_padded_spec(arg_infos[2])
lhs_trans = layout[0] == 'T'
rhs_trans = layout[1] == 'T'

lhs_inner_idx = -1 if lhs_trans else -2
rhs_inner_idx = -1 if rhs_trans else -2
rhs_outer_idx = -2 if rhs_trans else -1
projected_spec = rhs_spec[0 if rhs_inner_dim == 1 else 1]

if lhs_spec[lhs_inner_idx] != rhs_spec[rhs_inner_idx]:
if lhs_spec[lhs_inner_dim] != rhs_spec[rhs_inner_dim]:
warnings.warn("Forcing the inner dimension of A to match the sharding of inner "
+ "dimension of B. This can trigger additional communication of A is "
+ "not partitioned correctly.")
if rhs_spec[rhs_inner_idx] is not None and rhs_spec[rhs_outer_idx] is not None:
if rhs_spec[rhs_inner_dim] is not None and projected_spec is not None:
raise RuntimeError("Both inner and outer dimensions of B cannot be sharded!")

out_spec = [lhs_spec[rhs_outer_idx], rhs_spec[rhs_outer_idx]]
if len(lhs_spec) > 2:
out_spec = lhs_spec[:-2] + out_spec
out_spec = [*batch_specs, projected_spec]
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec))
fp8_meta_sharding = NamedSharding(mesh, PartitionSpec(None))

return (out_sharding, fp8_meta_sharding, fp8_meta_sharding, out_sharding)

@staticmethod
def partition(out_dtype, layout, do_gelu, use_bias, grad, accumulate, use_split_accumulator,
mesh, arg_infos, result_infos):
def partition(out_dtype, dimension_numbers, do_gelu, use_bias, grad, accumulate,
use_split_accumulator, mesh, arg_infos, result_infos):
del out_dtype, do_gelu, use_bias, grad, accumulate, use_split_accumulator, result_infos
contracting_dims, batched_dims = dimension_numbers
(lhs_inner_dim, ), (rhs_inner_dim, ) = contracting_dims
lhs_bdims, _ = batched_dims

lhs_spec = get_padded_spec(arg_infos[0])
batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims]
rhs_spec = get_padded_spec(arg_infos[2])
lhs_trans = layout[0] == 'T'
rhs_trans = layout[1] == 'T'

# LHS:([B], M, K) x RHS:([B], K, N) = OUT:([B], M, N)
lhs_outer_idx = -2 if lhs_trans else -1
lhs_inner_idx = -1 if lhs_trans else -2
rhs_inner_idx = -1 if rhs_trans else -2
rhs_outer_idx = -2 if rhs_trans else -1

if lhs_spec[lhs_inner_idx] != rhs_spec[rhs_inner_idx]:
warnings.warn("Forcing the inner dimension of A to match the sharding of inner "
+ "dimension of B. This can trigger additional communication of A is "
+ "not partitioned correctly.")
if rhs_spec[rhs_inner_idx] is not None and rhs_spec[rhs_outer_idx] is not None:
raise RuntimeError("Both inner and outer dimensions of B cannot be sharded!")
projected_spec = rhs_spec[0 if rhs_inner_dim == 1 else 1]

lhs_spec_new = [None, rhs_spec[rhs_inner_idx]]
if len(lhs_spec) > 2:
lhs_spec_new = lhs_spec[:-2] + lhs_spec_new
lhs_spec_new = [None, lhs_spec[lhs_inner_dim]]
if len(batch_specs) > 1:
lhs_spec_new = batch_specs[:-1] + lhs_spec_new
lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_spec_new))
rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_spec))
fp8_meta_sharding = NamedSharding(mesh, PartitionSpec(None))

out_spec = [lhs_spec[rhs_outer_idx], rhs_spec[rhs_outer_idx]]
if len(lhs_spec) > 2:
out_spec = lhs_spec[:-2] + out_spec
out_spec = [*batch_specs, projected_spec]
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec))
bias_sharding = NamedSharding(mesh, PartitionSpec(out_spec[-1]))

@@ -370,7 +371,7 @@ def partition(out_dtype, layout, do_gelu, use_bias, grad, accumulate, use_split_
out_shardings = (out_sharding, fp8_meta_sharding, fp8_meta_sharding, out_sharding)

def impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, out_amax, out_scale, out_dtype,
layout, do_gelu, use_bias, grad, accumulate, use_split_accumulator):
dimension_numbers, do_gelu, use_bias, grad, accumulate, use_split_accumulator):

assert GemmPrimitive.inner_primitive is not None

Expand All @@ -383,20 +384,20 @@ def impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, out_amax, out_scale, out_
out_amax,
out_scale,
out_dtype,
layout,
dimension_numbers,
do_gelu,
use_bias,
grad,
accumulate,
use_split_accumulator
)

if rhs_spec[rhs_inner_idx] is not None:
if rhs_spec[rhs_inner_dim] is not None:
# If the inner dimensions of LHS and RHS are sharded, we all-reduce the GEMM output.
# If the outer dimension of LHS is also sharded, we reduce-scatter the GEMM output.
par_op = (
jax.lax.psum_scatter
if lhs_spec[lhs_outer_idx] is not None
if batch_specs[-1] is not None
else jax.lax.psum
)
output = lax_paral_op(output, par_op, global_mesh_resource().tp_resource, mesh)
@@ -421,11 +422,18 @@ def fp8_gemm(
out_amax: Optional[ArrayLike] = None,
out_scale: Optional[ArrayLike] = None,
out_dtype: Optional[jnp.dtype] = None,
contracting_dims: Optional[Tuple[Sequence[int], Sequence[int]]] = None,
do_gelu: bool = False,
accumulate: bool = False,
use_split_accumulator: bool = False,
) -> Tuple[ArrayLike, ...]:
"""NT layout GEMM with FP8 inputs"""
"""GEMM with FP8 inputs"""
if contracting_dims is None:
contracting_dims = ((1,), (0,))
lhs_batch_dims = tuple([dim for dim in range(lhs.ndim) if dim not in contracting_dims[0]])
rhs_batch_dims = tuple([dim for dim in range(rhs.ndim) if dim not in contracting_dims[1]])
dimension_numbers = (contracting_dims, (lhs_batch_dims, rhs_batch_dims))

if out_dtype is not None and is_fp8_dtype(out_dtype):
assert out_amax is not None and out_scale is not None, "Missing output amax and scale!"
else:
@@ -437,8 +445,8 @@ def fp8_gemm(
bias = jnp.empty((0, ), dtype=jnp.bfloat16)

out, out_amax, out_scale, pre_gelu_out = GemmPrimitive.outer_primitive.bind(
lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, out_amax, out_scale, out_dtype, 'NT', do_gelu,
use_bias, False, accumulate, use_split_accumulator
lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, out_amax, out_scale, out_dtype,
dimension_numbers, do_gelu, use_bias, False, accumulate, use_split_accumulator
)

outputs = (out, )
@@ -451,28 +459,34 @@ def fp8_gemm(


def gemm(
A: ArrayLike,
B: ArrayLike,
lhs: ArrayLike,
rhs: ArrayLike,
bias: Optional[ArrayLike] = None,
layout: str = 'NT',
contracting_dims: Optional[Tuple[Sequence[int], Sequence[int]]] = None,
do_gelu: bool = False,
grad: bool = False,
accumulate: bool = False,
use_split_accumulator: bool = False,
) -> Tuple[ArrayLike, ...]:
"""Non-FP8 GEMM"""
if contracting_dims is None:
contracting_dims = ((1,), (0,))
lhs_batch_dims = tuple([dim for dim in range(lhs.ndim) if dim not in contracting_dims[0]])
rhs_batch_dims = tuple([dim for dim in range(rhs.ndim) if dim not in contracting_dims[1]])
dimension_numbers = (contracting_dims, (lhs_batch_dims, rhs_batch_dims))

use_bias = bias is not None
if not use_bias:
bias = jnp.empty((0, ), dtype=jnp.bfloat16)

dummy_fp8_meta = jnp.empty((0, ), dtype=jnp.float32)

D, _, _, pre_gelu_out = GemmPrimitive.outer_primitive.bind(
A, dummy_fp8_meta, B, dummy_fp8_meta, bias, dummy_fp8_meta, dummy_fp8_meta, A.dtype,
layout, do_gelu, use_bias, grad, accumulate, use_split_accumulator
out, _, _, pre_gelu_out = GemmPrimitive.outer_primitive.bind(
lhs, dummy_fp8_meta, rhs, dummy_fp8_meta, bias, dummy_fp8_meta, dummy_fp8_meta, lhs.dtype,
dimension_numbers, do_gelu, use_bias, grad, accumulate, use_split_accumulator
)

outputs = (D, )
outputs = (out, )
if do_gelu:
outputs += (pre_gelu_out, )


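For reference, the reshape strategy the fixed batcher applies to batched LHS operands can be illustrated with plain jax.numpy. This is a sketch of the algorithm only, not a call into the custom op; batched_gemm_reference and the shapes used are hypothetical.

import math
import jax.numpy as jnp

def batched_gemm_reference(lhs, rhs, lhs_inner_dim, rhs_inner_dim):
    """Reference for the batcher's reshape strategy:
    LHS:([B...], M, K) x RHS:(N, K) -> OUT:([B...], M, N)."""
    # Move the contracting dimensions to the last position, as the batcher does.
    if lhs_inner_dim != lhs.ndim - 1:
        lhs = jnp.moveaxis(lhs, lhs_inner_dim, -1)
    if rhs_inner_dim != rhs.ndim - 1:
        rhs = jnp.moveaxis(rhs, rhs_inner_dim, -1)
    # Collapse all non-contracting LHS dimensions into a single leading dimension.
    lhs_batch_shape = lhs.shape[:-1]
    lhs_2d = jnp.reshape(lhs, (math.prod(lhs_batch_shape), lhs.shape[-1]))
    # 2D GEMM with an effectively transposed RHS (lhs_trans=False, rhs_trans=True).
    out_2d = lhs_2d @ rhs.T
    # Restore the original leading (batch) shape of the LHS.
    return jnp.reshape(out_2d, (*lhs_batch_shape, rhs.shape[0]))

out = batched_gemm_reference(
    jnp.zeros((4, 128, 256)), jnp.zeros((512, 256)), lhs_inner_dim=2, rhs_inner_dim=1
)
assert out.shape == (4, 128, 512)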