diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py
index a857a95307..8dd7a36ae8 100644
--- a/transformer_engine/jax/cpp_extensions/gemm.py
+++ b/transformer_engine/jax/cpp_extensions/gemm.py
@@ -4,8 +4,8 @@
 """JAX/TE custom ops for cuBlasLt GEMM"""
 import warnings
 import operator
-from functools import reduce
-from typing import Optional, Tuple
+import math
+from typing import Optional, Tuple, Sequence
 
 import jax
 import jax.numpy as jnp
@@ -20,7 +20,6 @@ from .base import BasePrimitive, register_primitive
 from .custom_call import custom_caller, CustomCallArgsWrapper
 from .misc import (
-    check_valid_batch_dims,
     jax_dtype_to_te_dtype,
     get_padded_spec,
     is_ffi_enabled,
@@ -66,7 +65,7 @@ def abstract(
         out_amax_aval: ArrayLike,
         out_scale_aval: ArrayLike,
         out_dtype: jnp.dtype,
-        layout: str,
+        dimension_numbers: Tuple[Tuple[Tuple[Sequence[int], Sequence[int]], ...], ...],
         do_gelu: bool,
         use_bias: bool,
         grad: bool,
@@ -96,23 +95,20 @@
         ), "Missing FP8 meta!"
 
         # Validate input layouts
+        contracting_dims, batched_dims = dimension_numbers
+        (lhs_inner_dim, ), (rhs_inner_dim, ) = contracting_dims
+        lhs_bdims, _ = batched_dims
         if is_fp8:
-            assert layout == 'NT', "FP8 GEMM custom op only supports 'NT' layout!"
+            assert lhs_inner_dim == lhs_aval.ndim - 1, "FP8 GEMM does not support transposed LHS."
+            assert rhs_inner_dim == -1, "FP8 GEMM requires transposed RHS."
+            rhs_trans = True
         else:
-            assert layout in ['NN', 'NT', 'TN'], "Invalid GEMM layout!"
-        lhs_trans = layout[0] == 'T'
-        rhs_trans = layout[1] == 'T'
-        lhs_outer_idx = -2 if lhs_trans else -1
-        lhs_inner_idx = -1 if lhs_trans else -2
-        rhs_outer_idx = -1 if rhs_trans else -2
-        rhs_inner_idx = -2 if rhs_trans else -1
+            rhs_trans = (rhs_inner_dim == -1)
+
         assert (
-            lhs_aval.shape[lhs_inner_idx] == rhs_aval.shape[rhs_inner_idx]
+            lhs_aval.shape[lhs_inner_dim] == rhs_aval.shape[rhs_inner_dim]
         ), "Incompatible operand sizes!"
-        assert all([
-            lhs_batch == rhs_batch for lhs_batch, rhs_batch \
-                in zip(lhs_aval.shape[:-2], rhs_aval.shape[:-2])
-        ]), "Incompatible batch sizes!"
+        assert rhs_aval.ndim == 2, "TE/JAX GEMM does not support batched RHS operand."
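
For context, a minimal sketch (not part of the patch; shapes and names are hypothetical) of the
`dimension_numbers` convention the new shape rule assumes: the caller supplies the contracting
dimensions, every non-contracting LHS dimension is treated as a leading/batch dimension, and the
RHS must be 2D. The FP8 path additionally requires the RHS contracting dimension to be given as
-1, i.e. an (N, K) operand.

    import jax.numpy as jnp

    lhs = jnp.zeros((4, 128, 256))     # (B, S, K)
    rhs = jnp.zeros((256, 512))        # (K, N), non-transposed 2D RHS
    contracting_dims = ((2, ), (0, ))  # contract over K on both operands
    lhs_bdims = tuple(i for i in range(lhs.ndim) if i not in contracting_dims[0])  # (0, 1)
    rhs_bdims = tuple(i for i in range(rhs.ndim) if i not in contracting_dims[1])  # (1, )
    dimension_numbers = (contracting_dims, (lhs_bdims, rhs_bdims))
    # Expected output shape per the abstract rule above: (*batched_shape, projected_size)
    #   -> (4, 128, 512), matching jnp.einsum("bsk,kn->bsn", lhs, rhs).shape
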
 
         # Validate output dtype
         out_dtype = dtypes.canonicalize_dtype(out_dtype)
@@ -133,12 +129,10 @@ def abstract(
             out_dtype = lhs_dtype
 
         # Infer output size and create abstract arrays
-        out_shape = (
-            *lhs_aval.shape[:-2],
-            lhs_aval.shape[lhs_outer_idx],
-            rhs_aval.shape[rhs_outer_idx]
-        )
-        out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
+        batched_shape = [lhs_aval.shape[bdim] for bdim in lhs_bdims]
+        projected_size = rhs_aval.shape[0 if rhs_trans else -1]
+        out_shape = (*batched_shape, projected_size)
+        out_aval = lhs_aval.update(shape=out_shape, dtype=out_dtype)
         out_amax_updated_aval = out_amax_aval.update(shape=out_amax_aval.shape, dtype=jnp.float32)
         out_scale_updated_aval = out_scale_aval.update(shape=out_scale_aval.shape, dtype=jnp.float32)
 
@@ -181,7 +175,7 @@ def lowering(
         out_amax: ArrayLike,
         out_scale: ArrayLike,
         out_dtype: jnp.dtype,
-        layout: str,
+        dimension_numbers: Tuple[Tuple[Sequence[int], Sequence[int]], ...],
         do_gelu: bool,
         use_bias: bool,
         grad: bool,
@@ -191,12 +185,13 @@ def lowering(
         """
         Fused attention fwd lowering rules
         """
-        del do_gelu, use_bias
+        del do_gelu, use_bias, dimension_numbers
 
-        lhs_trans = layout[0] == 'T'
-        rhs_trans = layout[1] == 'T'
-        workspace_size = get_cublas_workspace_size_bytes()
+        # Batched always reshapes into LHS:([B], M, K) x RHS^T:(N, K) = OUT:([B], M, N)
+        lhs_trans = False
+        rhs_trans = True
 
+        # Call underlying custom op with flipped LHS/RHS to account for cuBlasLt column-major format
         if is_ffi_enabled():
             name = "te_gemm_ffi"
             return ffi.ffi_lowering(name, operand_output_aliases={5: 1, 6: 2})(
@@ -231,20 +226,16 @@ def lowering(
         ]
         args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
 
-        # LHS:([B], M, K) x RHS:([B], K, N) = OUT:([B], M, N)
         lhs_aval, _, rhs_aval, _, bias_aval, *_ = ctx.avals_in
-        lhs_outer_idx = -2 if lhs_trans else -1
-        lhs_inner_idx = -1 if lhs_trans else -2
-        rhs_outer_idx = -1 if rhs_trans else -2
-        batch = reduce(lhs_aval.shape[:-2], operator.mul, 1.) if lhs_aval.ndim > 2 else 1
-        m = lhs_aval.shape[lhs_outer_idx]
-        n = lhs_aval.shape[lhs_inner_idx]
-        k = rhs_aval.shape[rhs_outer_idx]
+        m = lhs_aval.shape[0]
+        n = rhs_aval.shape[0]
+        k = rhs_aval.shape[-1]
+        workspace_size = get_cublas_workspace_size_bytes()
         operand_dtype = jax_dtype_to_te_dtype(lhs_aval.dtype)
         bias_dtype = jax_dtype_to_te_dtype(bias_aval.dtype)
-        opaque = tex.pack_gemm_descriptor(batch, m, n, k, workspace_size, operand_dtype,
+        opaque = tex.pack_gemm_descriptor(m, n, k, workspace_size, operand_dtype,
                                           jax_dtype_to_te_dtype(out_dtype), bias_dtype,
-                                          lhs_trans, rhs_trans, grad, accumulate,
+                                          lhs_trans, rhs_trans, grad, accumulate,
                                           use_split_accumulator)
 
         return custom_caller(GemmPrimitive.name, args, opaque, has_side_effect=False)
 
@@ -259,7 +250,7 @@ def impl(
         out_amax: ArrayLike,
         out_scale: ArrayLike,
         out_dtype: jnp.dtype,
-        layout: str,
+        dimension_numbers: Tuple[Tuple[Sequence[int], Sequence[int]], ...],
         do_gelu: bool,
         use_bias: bool,
         grad: bool,
@@ -278,7 +269,7 @@ def impl(
             out_amax,
             out_scale,
             out_dtype,
-            layout,
+            dimension_numbers,
             do_gelu,
             use_bias,
             grad,
@@ -289,79 +280,89 @@ def impl(
         return output, out_amax_updated, out_scale_updated, pre_gelu_out
 
     @staticmethod
-    def batcher(batched_args, batch_dims, out_dtype, layout, do_gelu, use_bias, grad, accumulate,
-                use_split_accumulator):
+    def batcher(batched_args, batch_dims, out_dtype, dimension_numbers, do_gelu, use_bias, grad,
+                accumulate, use_split_accumulator):
         assert GemmPrimitive.outer_primitive is not None
-        check_valid_batch_dims(batch_dims)
-        _, _, b_bdim, _, amax_bdim, scale_bdim, _ = batch_dims
-        out_bdims = b_bdim, amax_bdim, scale_bdim, b_bdim
-        return (
-            GemmPrimitive.outer_primitive.bind(*batched_args, out_dtype, layout, do_gelu, use_bias,
-                                               grad, accumulate, use_split_accumulator),
-            out_bdims,
-        )
+        lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, out_amax, out_scale = batched_args
+        assert rhs.ndim == 2, "TE/JAX GEMM custom op does not support batching RHS operands."
+
+        # Get contracting and batch dimensions out
+        contracting_dims, _ = dimension_numbers
+        (lhs_inner_dim, ), (rhs_inner_dim, ) = contracting_dims
+        lhs_bdims, _, _, _, _, amax_bdims, scale_bdims = batch_dims
+
+        # Put the contracting dimension to the last dimension (if necessary)
+        # This means we always execute `nvte_cublas_gemm` with lhs_trans = False and
+        # rhs_trans = True
+        if lhs_inner_dim != lhs.ndim - 1:
+            lhs = jnp.moveaxis(lhs, lhs_inner_dim, -1)
+        if rhs_inner_dim != rhs.ndim - 1:
+            rhs = jnp.moveaxis(rhs, rhs_inner_dim, -1)
+
+        # Collapse all non-contracting dimensions
+        lhs_batch_shape = lhs.shape[:-1]
+        lhs = jnp.reshape(lhs, (math.prod(lhs_batch_shape), lhs.shape[-1]))
+
+        # Operands are now canonicalized to LHS:(B*M, K) x RHS^T:(N, K)
+        canonical_dims = (((1, ), (-1, )), ((0, ), ()))
+        outputs = GemmPrimitive.outer_primitive.bind(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias,
+                                                     out_amax, out_scale, out_dtype, canonical_dims,
+                                                     do_gelu, use_bias, grad, accumulate,
+                                                     use_split_accumulator)
+
+        # Reshape output to recover original LHS batch shape
+        outputs[0] = jnp.reshape(outputs[0], (*lhs_batch_shape, rhs.shape[0]))
+        if outputs[3].size > 0:
+            outputs[3] = jnp.reshape(outputs[3], outputs[0].shape)
+        return outputs, (lhs_bdims, amax_bdims, scale_bdims, lhs_bdims)
 
     @staticmethod
-    def infer_sharding_from_operands(out_dtype, layout, do_gelu, use_bias, grad, accumulate,
-                                     use_split_accumulator, mesh, arg_infos, result_infos):
+    def infer_sharding_from_operands(out_dtype, dimension_numbers, do_gelu, use_bias, grad,
+                                     accumulate, use_split_accumulator, mesh, arg_infos,
+                                     result_infos):
         del out_dtype, do_gelu, use_bias, grad, accumulate, use_split_accumulator, result_infos
 
+        contracting_dims, batched_dims = dimension_numbers
+        (lhs_inner_dim, ), (rhs_inner_dim, ) = contracting_dims
+        lhs_bdims, _ = batched_dims
+
         lhs_spec = get_padded_spec(arg_infos[0])
+        batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims]
         rhs_spec = get_padded_spec(arg_infos[2])
-        lhs_trans = layout[0] == 'T'
-        rhs_trans = layout[1] == 'T'
-
-        lhs_inner_idx = -1 if lhs_trans else -2
-        rhs_inner_idx = -1 if rhs_trans else -2
-        rhs_outer_idx = -2 if rhs_trans else -1
+        projected_spec = rhs_spec[0 if rhs_inner_dim == 1 else 1]
 
-        if lhs_spec[lhs_inner_idx] != rhs_spec[rhs_inner_idx]:
+        if lhs_spec[lhs_inner_dim] != rhs_spec[rhs_inner_dim]:
             warnings.warn("Forcing the inner dimension of A to match the sharding of inner "
                           + "dimension of B. This can trigger additional communication of A is "
                           + "not partitioned correctly.")
-        if rhs_spec[rhs_inner_idx] is not None and rhs_spec[rhs_outer_idx] is not None:
+        if rhs_spec[rhs_inner_dim] is not None and projected_spec is not None:
             raise RuntimeError("Both inner and outer dimensions of B cannot be sharded!")
 
-        out_spec = [lhs_spec[rhs_outer_idx], rhs_spec[rhs_outer_idx]]
-        if len(lhs_spec) > 2:
-            out_spec = lhs_spec[:-2] + out_spec
+        out_spec = [*batch_specs, projected_spec]
         out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec))
 
         fp8_meta_sharding = NamedSharding(mesh, PartitionSpec(None))
 
         return (out_sharding, fp8_meta_sharding, fp8_meta_sharding, out_sharding)
 
     @staticmethod
-    def partition(out_dtype, layout, do_gelu, use_bias, grad, accumulate, use_split_accumulator,
-                  mesh, arg_infos, result_infos):
+    def partition(out_dtype, dimension_numbers, do_gelu, use_bias, grad, accumulate,
+                  use_split_accumulator, mesh, arg_infos, result_infos):
         del out_dtype, do_gelu, use_bias, grad, accumulate, use_split_accumulator, result_infos
 
+        contracting_dims, batched_dims = dimension_numbers
+        (lhs_inner_dim, ), (rhs_inner_dim, ) = contracting_dims
+        lhs_bdims, _ = batched_dims
+
         lhs_spec = get_padded_spec(arg_infos[0])
+        batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims]
         rhs_spec = get_padded_spec(arg_infos[2])
-        lhs_trans = layout[0] == 'T'
-        rhs_trans = layout[1] == 'T'
-
-        # LHS:([B], M, K) x RHS:([B], K, N) = OUT:([B], M, N)
-        lhs_outer_idx = -2 if lhs_trans else -1
-        lhs_inner_idx = -1 if lhs_trans else -2
-        rhs_inner_idx = -1 if rhs_trans else -2
-        rhs_outer_idx = -2 if rhs_trans else -1
-
-        if lhs_spec[lhs_inner_idx] != rhs_spec[rhs_inner_idx]:
-            warnings.warn("Forcing the inner dimension of A to match the sharding of inner "
-                          + "dimension of B. This can trigger additional communication of A is "
-                          + "not partitioned correctly.")
-        if rhs_spec[rhs_inner_idx] is not None and rhs_spec[rhs_outer_idx] is not None:
-            raise RuntimeError("Both inner and outer dimensions of B cannot be sharded!")
+        projected_spec = rhs_spec[0 if rhs_inner_dim == 1 else 1]
 
-        lhs_spec_new = [None, rhs_spec[rhs_inner_idx]]
-        if len(lhs_spec) > 2:
-            lhs_spec_new = lhs_spec[:-2] + lhs_spec_new
+        lhs_spec_new = [None, lhs_spec[lhs_inner_dim]]
+        if len(batch_specs) > 1:
+            lhs_spec_new = batch_specs[:-1] + lhs_spec_new
         lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_spec_new))
         rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_spec))
         fp8_meta_sharding = NamedSharding(mesh, PartitionSpec(None))
-        out_spec = [lhs_spec[rhs_outer_idx], rhs_spec[rhs_outer_idx]]
-        if len(lhs_spec) > 2:
-            out_spec = lhs_spec[:-2] + out_spec
+        out_spec = [*batch_specs, projected_spec]
         out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec))
 
         bias_sharding = NamedSharding(mesh, PartitionSpec(out_spec[-1]))
@@ -370,7 +371,7 @@ def partition(out_dtype, layout, do_gelu, use_bias, grad, accumulate, use_split_
         out_shardings = (out_sharding, fp8_meta_sharding, fp8_meta_sharding, out_sharding)
 
         def impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, out_amax, out_scale, out_dtype,
-                 layout, do_gelu, use_bias, grad, accumulate, use_split_accumulator):
+                 dimension_numbers, do_gelu, use_bias, grad, accumulate, use_split_accumulator):
             assert GemmPrimitive.inner_primitive is not None
 
@@ -383,7 +384,7 @@ def impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, out_amax, out_scale, out_
                 out_amax,
                 out_scale,
                 out_dtype,
-                layout,
+                dimension_numbers,
                 do_gelu,
                 use_bias,
                 grad,
@@ -391,12 +392,12 @@ def impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, out_amax, out_scale, out_
                 use_split_accumulator
             )
 
-            if rhs_spec[rhs_inner_idx] is not None:
+            if rhs_spec[rhs_inner_dim] is not None:
                 # If the inner dimensions of LHS and RHS are sharded, we all-reduce the GEMM output.
                 # If the outer dimension of LHS is also sharded, we reduce-scatter the GEMM output.
                 par_op = (
                     jax.lax.psum_scatter
-                    if lhs_spec[lhs_outer_idx] is not None
+                    if batch_specs[-1] is not None
                     else jax.lax.psum
                 )
                 output = lax_paral_op(output, par_op, global_mesh_resource().tp_resource, mesh)
@@ -421,11 +422,18 @@ def fp8_gemm(
     out_amax: Optional[ArrayLike] = None,
     out_scale: Optional[ArrayLike] = None,
     out_dtype: Optional[jnp.dtype] = None,
+    contracting_dims: Tuple[Sequence[int], Sequence[int]] = None,
     do_gelu: bool = False,
     accumulate: bool = False,
     use_split_accumulator: bool = False,
 ) -> Tuple[ArrayLike, ...]:
-    """NT layout GEMM with FP8 inputs"""
+    """GEMM with FP8 inputs"""
+    if contracting_dims is None:
+        contracting_dims = ((1,), (0,))
+    lhs_batch_dims = tuple([dim for dim in range(lhs.ndim) if dim not in contracting_dims[0]])
+    rhs_batch_dims = tuple([dim for dim in range(rhs.ndim) if dim not in contracting_dims[1]])
+    dimension_numbers = (contracting_dims, (lhs_batch_dims, rhs_batch_dims))
+
     if out_dtype is not None and is_fp8_dtype(out_dtype):
         assert out_amax is not None and out_scale is not None, "Missing output amax and scale!"
     else:
@@ -437,8 +445,8 @@ def fp8_gemm(
         bias = jnp.empty((0, ), dtype=jnp.bfloat16)
 
     out, out_amax, out_scale, pre_gelu_out = GemmPrimitive.outer_primitive.bind(
-        lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, out_amax, out_scale, out_dtype, 'NT', do_gelu,
-        use_bias, False, accumulate, use_split_accumulator
+        lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, out_amax, out_scale, out_dtype,
+        dimension_numbers, do_gelu, use_bias, False, accumulate, use_split_accumulator
     )
 
     outputs = (out, )
@@ -451,28 +459,34 @@ def gemm(
-    A: ArrayLike,
-    B: ArrayLike,
+    lhs: ArrayLike,
+    rhs: ArrayLike,
     bias: Optional[ArrayLike] = None,
-    layout: str = 'NT',
+    contracting_dims: Tuple[Sequence[int], Sequence[int]] = None,
     do_gelu: bool = False,
     grad: bool = False,
     accumulate: bool = False,
     use_split_accumulator: bool = False,
 ) -> Tuple[ArrayLike, ...]:
     """Non-FP8 GEMM"""
+    if contracting_dims is None:
+        contracting_dims = ((1,), (0,))
+    lhs_batch_dims = tuple([dim for dim in range(lhs.ndim) if dim not in contracting_dims[0]])
+    rhs_batch_dims = tuple([dim for dim in range(rhs.ndim) if dim not in contracting_dims[1]])
+    dimension_numbers = (contracting_dims, (lhs_batch_dims, rhs_batch_dims))
+
     use_bias = bias is not None
     if not use_bias:
         bias = jnp.empty((0, ), dtype=jnp.bfloat16)
 
     dummy_fp8_meta = jnp.empty((0, ), dtype=jnp.float32)
-    D, _, _, pre_gelu_out = GemmPrimitive.outer_primitive.bind(
-        A, dummy_fp8_meta, B, dummy_fp8_meta, bias, dummy_fp8_meta, dummy_fp8_meta, A.dtype,
-        layout, do_gelu, use_bias, grad, accumulate, use_split_accumulator
+    out, _, _, pre_gelu_out = GemmPrimitive.outer_primitive.bind(
+        lhs, dummy_fp8_meta, rhs, dummy_fp8_meta, bias, dummy_fp8_meta, dummy_fp8_meta, lhs.dtype,
+        dimension_numbers, do_gelu, use_bias, grad, accumulate, use_split_accumulator
     )
-    outputs = (D, )
+    outputs = (out, )
 
     if do_gelu:
         outputs += (pre_gelu_out, )
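
A possible call pattern for the reworked wrappers, with hypothetical shapes (not taken from the
patch). The 2D call relies on the default `contracting_dims=((1,), (0,))`; the vmapped call goes
through the new batching rule, which moves the contracting dimension last, collapses the remaining
LHS dimensions, and restores them on the output.

    import jax
    import jax.numpy as jnp
    from transformer_engine.jax.cpp_extensions.gemm import gemm

    x = jnp.zeros((1024, 4096), dtype=jnp.bfloat16)   # (M, K) activations
    w = jnp.zeros((4096, 3072), dtype=jnp.bfloat16)   # (K, N) weights
    out = gemm(x, w)[0]                               # (M, N) = (1024, 3072)

    xb = jnp.zeros((8, 1024, 4096), dtype=jnp.bfloat16)       # (B, M, K)
    outb = jax.vmap(lambda a: gemm(a, w)[0], in_axes=0)(xb)   # (B, M, N) via the batching rule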