Skip to content

Commit

Permalink
updated FWD/BWD wrappers for non-FP8 and FP8 gemm
Browse files Browse the repository at this point in the history
Signed-off-by: Alp Dener <[email protected]>
  • Loading branch information
denera committed Dec 5, 2024
1 parent 13a8cd4 commit b1b51c3
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 117 deletions.
1 change: 1 addition & 0 deletions transformer_engine/jax/cpp_extensions/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,6 +1074,7 @@ def fp8_gemm_impl(
bias: Optional[ArrayLike] = None,
gelu_input: Optional[ArrayLike] = None,
out: Optional[ArrayLike] = None,
extra_out: Optional[ArrayLike] = None,
out_amax: Optional[ArrayLike] = None,
out_scale: Optional[ArrayLike] = None,
out_dtype: jnp.dtype = jnp.bfloat16,
Expand Down
Loading

0 comments on commit b1b51c3

Please sign in to comment.