restored old test_custom_call_compute.py to remove erroneous changes
Signed-off-by: Alp Dener <[email protected]>
denera committed Dec 5, 2024
1 parent 39bd494 commit 07a2fb3
Showing 1 changed file with 0 additions and 50 deletions.
tests/jax/test_custom_call_compute.py: 50 changes (0 additions & 50 deletions)
@@ -25,7 +25,6 @@
     _jax_dbias_cast_transpose,
 )
 from transformer_engine.jax.cpp_extensions.quantization import _jax_cast_fp8
-from transformer_engine.jax.gemm import fp8_gemm, gemm
 from transformer_engine.jax import cpp_extensions as tex
 
 
@@ -416,55 +415,6 @@ def ref_func(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_
 )
 
 
-class TestGemm:
-
-    @staticmethod
-    def _generate_inputs(b, m, n, k, dtype):
-        key = jax.random.PRNGKey(0)
-        subkeys = jax.random.split(key, 3)
-        a = jax.random.normal(subkeys[0], (b, m, k), dtype)
-        b = jax.random.normal(subkeys[1], (n, k), dtype)
-        bias_dtype = dtype if dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2] else jnp.bfloat16
-        bias = jax.random.normal(subkeys[2], (n,), bias_dtype)
-        return a, b, bias
-
-    @staticmethod
-    def _generate_fp8_inputs(b, m, n, k, fp8_dtype):
-        a, b, bias = TestGemm._generate_inputs(b, m, n, k, jnp.bfloat16)
-        a_scale, b_scale = map(lambda x: (jnp.max(jnp.abs(x)) / 127.0).astype(jnp.float32), [a, b])
-        a_q, b_q = map(
-            lambda x, x_scale: jnp.round(x / x_scale).astype(fp8_dtype),
-            [(a, a_scale), (b, b_scale)],
-        )
-        return a, a_q, jnp.reciprocal(a_scale), b, b_q, jnp.reciprocal(b_scale), bias
-
-    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
-    @pytest.mark.parametrize("use_bias", (False, True))
-    @pytest.mark.parametrize("do_gelu", (False, True))
-    def test_gemm(self, b, m, n, k, use_bias, do_gelu):
-        a, b, bias = self._generate_inputs(b, m, n, k, jnp.bfloat16)
-
-        primitive_out = gemm(a, b, bias=bias if use_bias else None, layout="NT", do_gelu=do_gelu)
-        ref_out = jnp.dot(a, b)
-        if use_bias:
-            ref_out += bias
-        if do_gelu:
-            ref_out = jax.nn.gelu(ref_out)
-
-        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
-
-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
-    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
-    @pytest.mark.parametrize("fp8_dtype", FP8_COMPUTE_TYPE)
-    def test_fp8_gemm(self, m, n, k, fp8_dtype):
-        a, a_q, a_scale_inv, b, b_q, b_scale_inv, _ = self._generate_fp8_inputs(m, n, k, fp8_dtype)
-
-        primitive_out = fp8_gemm(a_q, a_scale_inv, b_q, b_scale_inv, out_dtype=jnp.bfloat16)
-        ref_out = jnp.dot(a, b)
-
-        assert_allclose(primitive_out, ref_out, dtype=fp8_dtype)
-
-
 @pytest.fixture(name="random_inputs")
 def random_inputs_fixture(shape):
     key = jax.random.PRNGKey(0)
