Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 14, 2024
1 parent 941f5bb commit 6444211
Show file tree
Hide file tree
Showing 5 changed files with 243 additions and 124 deletions.
15 changes: 5 additions & 10 deletions tests/jax/test_custom_call_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,19 +425,16 @@ def _generate_inputs(b, m, n, k, dtype):
a = jax.random.normal(subkeys[0], (b, m, k), dtype)
b = jax.random.normal(subkeys[1], (n, k), dtype)
bias_dtype = dtype if dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2] else jnp.bfloat16
bias = jax.random.normal(subkeys[2], (n, ), bias_dtype)
bias = jax.random.normal(subkeys[2], (n,), bias_dtype)
return a, b, bias

@staticmethod
def _generate_fp8_inputs(b, m, n, k, fp8_dtype):
    """Build bf16 GEMM inputs and their per-tensor FP8 quantizations.

    Args:
        b, m, n, k: batch and GEMM problem dimensions, forwarded to
            ``_generate_inputs``.
        fp8_dtype: target FP8 dtype for the quantized operands
            (e.g. ``jnp.float8_e4m3fn``).

    Returns:
        ``(a, a_q, a_scale_inv, b, b_q, b_scale_inv, bias)`` — the bf16
        reference tensors, their quantized counterparts, and the inverse
        scales needed to dequantize.
    """
    a, b, bias = TestGemm._generate_inputs(b, m, n, k, jnp.bfloat16)
    # Per-tensor scale: max(|x|) / 127 maps each tensor's dynamic range
    # onto the quantization grid.
    # NOTE(review): 127 reads like an int8 bound rather than an FP8 max —
    # confirm against the FP8 recipe used by the code under test.
    a_scale, b_scale = map(lambda x: (jnp.max(jnp.abs(x)) / 127.0).astype(jnp.float32), [a, b])
    # BUG FIX: the original passed one list of (tensor, scale) tuples to a
    # two-argument lambda; map() then calls the lambda with a single tuple
    # and raises TypeError.  map() accepts parallel iterables, so feed the
    # tensors and their scales as two separate sequences.
    a_q, b_q = map(
        lambda x, x_scale: jnp.round(x / x_scale).astype(fp8_dtype),
        [a, b],
        [a_scale, b_scale],
    )
    return a, a_q, jnp.reciprocal(a_scale), b, b_q, jnp.reciprocal(b_scale), bias

Expand All @@ -447,7 +444,7 @@ def _generate_fp8_inputs(b, m, n, k, fp8_dtype):
def test_gemm(self, b, m, n, k, use_bias, do_gelu):
a, b, bias = self._generate_inputs(b, m, n, k, jnp.bfloat16)

primitive_out = gemm(a, b, bias=bias if use_bias else None, layout='NT', do_gelu=do_gelu)
primitive_out = gemm(a, b, bias=bias if use_bias else None, layout="NT", do_gelu=do_gelu)
ref_out = jnp.dot(a, b)
if use_bias:
ref_out += bias
Expand All @@ -460,9 +457,7 @@ def test_gemm(self, b, m, n, k, use_bias, do_gelu):
@pytest.mark.parametrize("m,n,k", GEMM_CASES)
@pytest.mark.parametrize("fp8_dtype", FP8_COMPUTE_TYPE)
def test_fp8_gemm(self, m, n, k, fp8_dtype):
a, a_q, a_scale_inv, b, b_q, b_scale_inv, _ = self._generate_fp8_inputs(
m, n, k, fp8_dtype
)
a, a_q, a_scale_inv, b, b_q, b_scale_inv, _ = self._generate_fp8_inputs(m, n, k, fp8_dtype)

primitive_out = fp8_gemm(a_q, a_scale_inv, b_q, b_scale_inv, out_dtype=jnp.bfloat16)
ref_out = jnp.dot(a, b)
Expand Down
Loading

0 comments on commit 6444211

Please sign in to comment.