Commit 6277d22

finished GEMM custom op primitive and serial unit test

Signed-off-by: Alp Dener <[email protected]>
denera committed Nov 2, 2024
1 parent 985e0b1 commit 6277d22
Showing 8 changed files with 675 additions and 95 deletions.
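The commit introduces high-level `gemm` and `fp8_gemm` entry points, exercised by the new unit tests below. A minimal usage sketch follows; the shapes and the reading of `layout='NT'` (second operand stored as (n, k) and transposed inside the GEMM) are inferred from how the tests call the function, not from documented API semantics.

import jax
import jax.numpy as jnp
from transformer_engine.jax.gemm import gemm

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
a = jax.random.normal(k1, (8, 128, 64), jnp.bfloat16)   # (batch, m, k)
b = jax.random.normal(k2, (256, 64), jnp.bfloat16)      # (n, k): 'NT' layout
bias = jax.random.normal(k3, (256,), jnp.bfloat16)

# Fused equivalent of jax.nn.gelu(a @ b.T + bias) via the custom op.
out = gemm(a, b, bias=bias, layout='NT', do_gelu=True)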
55 changes: 55 additions & 0 deletions tests/jax/test_custom_call_compute.py
@@ -24,6 +24,7 @@
_jax_cast_transpose,
)
from transformer_engine.jax.cpp_extensions.quantization import _jax_cast_fp8
from transformer_engine.jax.gemm import fp8_gemm, gemm
from transformer_engine.jax import cpp_extensions as tex


@@ -414,6 +415,60 @@ def ref_func(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_
)


class TestGemm:

    @staticmethod
    def _generate_inputs(b, m, n, k, dtype):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 3)
        a = jax.random.normal(subkeys[0], (b, m, k), dtype)  # batched LHS, shape (batch, m, k)
        b = jax.random.normal(subkeys[1], (n, k), dtype)     # RHS stored transposed for the 'NT' layout
        # The bias stays in bfloat16 when the operands are FP8.
        bias_dtype = dtype if dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2] else jnp.bfloat16
        bias = jax.random.normal(subkeys[2], (n,), bias_dtype)
        return a, b, bias

    @staticmethod
    def _generate_fp8_inputs(b, m, n, k, fp8_dtype):
        a, b, bias = TestGemm._generate_inputs(b, m, n, k, jnp.bfloat16)
        # Per-tensor scales derived from each operand's absolute maximum (amax / 127).
        a_scale, b_scale = map(
            lambda x: (jnp.max(jnp.abs(x)) / 127.).astype(jnp.float32),
            [a, b]
        )
        # `map` over two iterables hands the lambda (x, x_scale) pairs element-wise.
        a_q, b_q = map(
            lambda x, x_scale: jnp.round(x / x_scale).astype(fp8_dtype),
            [a, b],
            [a_scale, b_scale],
        )
        return a, a_q, jnp.reciprocal(a_scale), b, b_q, jnp.reciprocal(b_scale), bias

@pytest.mark.parametrize("m,n,k", GEMM_CASES)
@pytest.mark.parametrize("use_bias", (False, True))
@pytest.mark.parametrize("do_gelu", (False, True))
def test_gemm(self, b, m, n, k, use_bias, do_gelu):
a, b, bias = self._generate_inputs(b, m, n, k, jnp.bfloat16)

primitive_out = gemm(a, b, bias=bias if use_bias else None, layout='NT', do_gelu=do_gelu)
ref_out = jnp.dot(a, b)
if use_bias:
ref_out += bias
if do_gelu:
ref_out = jax.nn.gelu(ref_out)

assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("b", (1, 8))  # batch sizes for the leading dimension
    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    @pytest.mark.parametrize("fp8_dtype", FP8_COMPUTE_TYPE)
    def test_fp8_gemm(self, b, m, n, k, fp8_dtype):
        a, a_q, a_scale_inv, b, b_q, b_scale_inv, _ = self._generate_fp8_inputs(
            b, m, n, k, fp8_dtype
        )

        primitive_out = fp8_gemm(a_q, a_scale_inv, b_q, b_scale_inv, out_dtype=jnp.bfloat16)
        # Compare the FP8 result against the unquantized bfloat16 reference.
        ref_out = jnp.dot(a, b.T)

        assert_allclose(primitive_out, ref_out, dtype=fp8_dtype)


@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
key = jax.random.PRNGKey(0)
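As an aside on the quantization in `_generate_fp8_inputs` above: each operand is scaled per tensor by `amax / 127` before rounding into FP8, a conservative bound given that E4M3 can represent magnitudes up to 448. A minimal round-trip sketch under those assumptions:

import jax.numpy as jnp

x = jnp.linspace(-3.0, 3.0, 16, dtype=jnp.bfloat16)
scale = (jnp.max(jnp.abs(x)) / 127.).astype(jnp.float32)  # per-tensor scale
x_q = jnp.round(x / scale).astype(jnp.float8_e4m3fn)      # quantize
x_dq = x_q.astype(jnp.float32) * scale                    # dequantize; x_dq approximates x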
1 change: 1 addition & 0 deletions transformer_engine/jax/cpp_extensions/__init__.py
@@ -4,6 +4,7 @@
"""Python interface for c++ extensions"""
from .activation import *
from .attention import *
from .gemm import *
from .normalization import *
from .quantization import *
from .softmax import *
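The wildcard import added above re-exports the new low-level GEMM bindings from transformer_engine/jax/cpp_extensions/gemm.py through the `cpp_extensions` package (the exact symbol names are not shown in this hunk). Note that this is a separate module from the high-level `transformer_engine.jax.gemm` API used by the tests; the two import paths the tests rely on are:

# High-level functional API exercised by the new unit tests:
from transformer_engine.jax.gemm import fp8_gemm, gemm

# Low-level custom-op namespace; `from .gemm import *` makes the new
# primitive bindings available here alongside the other extensions:
from transformer_engine.jax import cpp_extensions as tex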
(Diffs for the remaining 6 changed files are not shown.)