We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 8fe3942 commit adf4046Copy full SHA for adf4046
transformer_engine/jax/cpp_extensions/gemm.py
@@ -1074,6 +1074,7 @@ def fp8_gemm_impl(
1074
bias: Optional[ArrayLike] = None,
1075
gelu_input: Optional[ArrayLike] = None,
1076
out: Optional[ArrayLike] = None,
1077
+ extra_out: Optional[ArrayLike] = None,
1078
out_amax: Optional[ArrayLike] = None,
1079
out_scale: Optional[ArrayLike] = None,
1080
out_dtype: jnp.dtype = jnp.bfloat16,
0 commit comments