Skip to content

Commit

Permalink
Set precision as an static arg
Browse files Browse the repository at this point in the history
Signed-off-by: Ming Huang <[email protected]>
  • Loading branch information
mingxu1067 committed Dec 15, 2023
1 parent ca3f0a5 commit 2ce5724
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion transformer_engine/jax/dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def dequantize(x, dq_dtype, scale_inv):


# Apply jit to guarantee correctness of FP8 GEMM.
@partial(jax.jit, static_argnums=(4, 5))
@partial(jax.jit, static_argnums=(4, 5, 6))
def fp8_dot_impl(
q_lhs: jnp.ndarray,
q_rhs: jnp.ndarray,
Expand Down

0 comments on commit 2ce5724

Please sign in to comment.