diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py
index 9bf3f9fa91..355f587265 100644
--- a/tests/jax/test_custom_call_compute.py
+++ b/tests/jax/test_custom_call_compute.py
@@ -425,19 +425,16 @@ def _generate_inputs(b, m, n, k, dtype):
         a = jax.random.normal(subkeys[0], (b, m, k), dtype)
         b = jax.random.normal(subkeys[1], (n, k), dtype)
         bias_dtype = dtype if dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2] else jnp.bfloat16
-        bias = jax.random.normal(subkeys[2], (n, ), bias_dtype)
+        bias = jax.random.normal(subkeys[2], (n,), bias_dtype)
         return a, b, bias
 
     @staticmethod
     def _generate_fp8_inputs(b, m, n, k, fp8_dtype):
         a, b, bias = TestGemm._generate_inputs(b, m, n, k, jnp.bfloat16)
-        a_scale, b_scale = map(
-            lambda x: (jnp.max(jnp.abs(x)) / 127.).astype(jnp.float32),
-            [a, b]
-        )
+        a_scale, b_scale = map(lambda x: (jnp.max(jnp.abs(x)) / 127.0).astype(jnp.float32), [a, b])
         a_q, b_q = map(
             lambda x, x_scale: jnp.round(x / x_scale).astype(fp8_dtype),
-            [(a, a_scale), (b, b_scale)]
+            [(a, a_scale), (b, b_scale)],
         )
         return a, a_q, jnp.reciprocal(a_scale), b, b_q, jnp.reciprocal(b_scale), bias
 
@@ -447,7 +444,7 @@ def _generate_fp8_inputs(b, m, n, k, fp8_dtype):
     def test_gemm(self, b, m, n, k, use_bias, do_gelu):
         a, b, bias = self._generate_inputs(b, m, n, k, jnp.bfloat16)
 
-        primitive_out = gemm(a, b, bias=bias if use_bias else None, layout='NT', do_gelu=do_gelu)
+        primitive_out = gemm(a, b, bias=bias if use_bias else None, layout="NT", do_gelu=do_gelu)
         ref_out = jnp.dot(a, b)
         if use_bias:
             ref_out += bias
@@ -460,9 +457,7 @@ def test_gemm(self, b, m, n, k, use_bias, do_gelu):
     @pytest.mark.parametrize("m,n,k", GEMM_CASES)
     @pytest.mark.parametrize("fp8_dtype", FP8_COMPUTE_TYPE)
     def test_fp8_gemm(self, m, n, k, fp8_dtype):
-        a, a_q, a_scale_inv, b, b_q, b_scale_inv, _ = self._generate_fp8_inputs(
-            m, n, k, fp8_dtype
-        )
+        a, a_q, a_scale_inv, b, b_q, b_scale_inv, _ = self._generate_fp8_inputs(m, n, k, fp8_dtype)
 
         primitive_out = fp8_gemm(a_q, a_scale_inv, b_q, b_scale_inv, out_dtype=jnp.bfloat16)
         ref_out = jnp.dot(a, b)
diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py
index 677fabca59..7bdbb84f84 100644
--- a/transformer_engine/jax/cpp_extensions/gemm.py
+++ b/transformer_engine/jax/cpp_extensions/gemm.py
@@ -58,9 +58,23 @@ class CollectiveGemmPrimitive(BasePrimitive):
     outer_primitive = None
 
     @staticmethod
-    def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_aval,
-                 gelu_input_aval, out_amax_aval, out_scale_aval, out_dtype, contracting_dims,
-                 fuse_gelu, fuse_bias, grad, accumulate, use_split_accumulator):
+    def abstract(
+        lhs_aval,
+        lhs_scale_inv_aval,
+        rhs_aval,
+        rhs_scale_inv_aval,
+        bias_aval,
+        gelu_input_aval,
+        out_amax_aval,
+        out_scale_aval,
+        out_dtype,
+        contracting_dims,
+        fuse_gelu,
+        fuse_bias,
+        grad,
+        accumulate,
+        use_split_accumulator,
+    ):
         """
         cuBlasLt GEMM abstract
         """
@@ -90,7 +104,7 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av
         lhs_inner_dim, rhs_inner_dim = map(
             lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim,
             contracting_dims,
-            (lhs_aval.ndim, rhs_aval.ndim)
+            (lhs_aval.ndim, rhs_aval.ndim),
         )
         assert (
             lhs_aval.shape[lhs_inner_dim] == rhs_aval.shape[rhs_inner_dim]
@@ -98,8 +112,8 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av
 
         lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1
         rhs_trans = rhs_inner_dim == 1
-        assert (
-            not (lhs_trans and rhs_trans)
+        assert not (
+            lhs_trans and rhs_trans
         ), "GEMM does not support transposed LHS and transposed RHS at the same time."
         if is_fp8:
             assert lhs_trans, "FP8 GEMM does not support transposed LHS."
@@ -107,8 +121,8 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av
 
         # Validate output dtype
         if jax_dtype_is_fp8(out_dtype):
-            assert (
-                jax_dtype_is_fp8(lhs_dtype) and jax_dtype_is_fp8(rhs_dtype)
+            assert jax_dtype_is_fp8(lhs_dtype) and jax_dtype_is_fp8(
+                rhs_dtype
             ), "FP8 GEMM output requires FP8 inputs."
             assert (
                 out_amax_aval.size == out_scale_aval.size == 1
@@ -126,8 +140,9 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av
 
         # Infer output shape
         rhs_outer_dim = 0 if rhs_trans else 1
         lhs_outer_dim = lhs_aval.ndim - 1 if lhs_trans else lhs_aval.ndim - 2
-        lhs_bdims = [dim for dim in range(lhs_aval.ndim)
-                     if dim not in [lhs_outer_dim, lhs_inner_dim]]
+        lhs_bdims = [
+            dim for dim in range(lhs_aval.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]
+        ]
         lhs_batch_shape = [lhs_aval.shape[dim] for dim in lhs_bdims]
         out_shape = (*lhs_batch_shape, lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim])
@@ -135,9 +150,7 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av
         bias_dtype = jnp.bfloat16 if jax_dtype_is_fp8(out_dtype) else out_dtype
         if fuse_bias:
             assert (
-                bias_aval.size > 0
-                and bias_aval.ndim == 1
-                and bias_aval.shape[0] == out_shape[-1]
+                bias_aval.size > 0 and bias_aval.ndim == 1 and bias_aval.shape[0] == out_shape[-1]
             ), "Incorrect bias shape."
             bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
         else:
@@ -145,8 +158,8 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av
 
         # Validate GELU input/output
         if fuse_gelu:
-            assert (
-                all([gelu_input_aval.shape[i] == out_shape[i] for i in len(out_shape)])
+            assert all(
+                [gelu_input_aval.shape[i] == out_shape[i] for i in len(out_shape)]
             ), "Invalid GELU input shape."
             assert gelu_input_aval.dtype == bias_dtype, "Invalid GELU dtype."
         else:
@@ -154,14 +167,17 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av
 
         # Create abstract arrays for all outputs
         out_aval = lhs_aval.update(shape=out_shape, dtype=out_dtype)
-        out_amax_updated_aval = out_amax_aval.update(shape=out_amax_aval.shape,
-                                                     dtype=out_amax_updated_dtype)
-        out_scale_updated_aval = out_scale_aval.update(shape=out_scale_aval.shape,
-                                                       dtype=out_scale_updated_dtype)
+        out_amax_updated_aval = out_amax_aval.update(
+            shape=out_amax_aval.shape, dtype=out_amax_updated_dtype
+        )
+        out_scale_updated_aval = out_scale_aval.update(
+            shape=out_scale_aval.shape, dtype=out_scale_updated_dtype
+        )
         pre_gelu_out_aval = gelu_input_aval.update(shape=gelu_input_aval.shape, dtype=bias_dtype)
         bias_grad_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
-        workspace_aval = jax.core.ShapedArray(shape=(get_cublas_workspace_size_bytes(), ),
-                                              dtype=jnp.uint8)
+        workspace_aval = jax.core.ShapedArray(
+            shape=(get_cublas_workspace_size_bytes(),), dtype=jnp.uint8
+        )
 
         return (
             out_aval,
@@ -169,7 +185,7 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av
             out_scale_updated_aval,
             pre_gelu_out_aval,
             bias_grad_aval,
-            workspace_aval
+            workspace_aval,
         )
 
     @staticmethod
@@ -177,20 +193,31 @@ def outer_abstract(*args, **kwargs):
         """
         cuBlasLt GEMM outer abstract
         """
-        (
-            out_aval,
-            out_amax_aval,
-            out_scale_aval,
-            pre_gelu_out_aval,
-            bias_grad_aval,
-            _
-        ) = CollectiveGemmPrimitive.abstract(*args, **kwargs)
+        (out_aval, out_amax_aval, out_scale_aval, pre_gelu_out_aval, bias_grad_aval, _) = (
+            CollectiveGemmPrimitive.abstract(*args, **kwargs)
+        )
         return out_aval, out_amax_aval, out_scale_aval, pre_gelu_out_aval, bias_grad_aval
 
     @staticmethod
-    def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out_scale,
-                 *, out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulate,
-                 use_split_accumulator):
+    def lowering(
+        ctx,
+        lhs,
+        lhs_scale_inv,
+        rhs,
+        rhs_scale_inv,
+        bias,
+        gelu_input,
+        out_amax,
+        out_scale,
+        *,
+        out_dtype,
+        contracting_dims,
+        fuse_gelu,
+        fuse_bias,
+        grad,
+        accumulate,
+        use_split_accumulator,
+    ):
         """
         Fused attention fwd lowering rules
         """
@@ -198,7 +225,7 @@ def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_
         lhs_inner_dim, rhs_inner_dim = map(
             lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim,
             contracting_dims,
-            (lhs_aval.ndim, rhs_aval.ndim)
+            (lhs_aval.ndim, rhs_aval.ndim),
         )
         lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1
         rhs_trans = rhs_inner_dim == 1
@@ -228,7 +255,7 @@ def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_
                 fuse_bias=fuse_bias,
                 grad=grad,
                 accumulate=accumulate,
-                use_split_accumulator=use_split_accumulator
+                use_split_accumulator=use_split_accumulator,
             )
         else:
             operands = [
@@ -250,8 +277,9 @@ def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_
 
             rhs_outer_dim = 0 if rhs_trans else 1
             lhs_outer_dim = lhs_aval.ndim - 1 if lhs_trans else lhs_aval.ndim - 2
-            lhs_bdims = [dim for dim in range(lhs_aval.ndim)
-                         if dim not in [lhs_outer_dim, lhs_inner_dim]]
+            lhs_bdims = [
+                dim for dim in range(lhs_aval.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]
+            ]
             lhs_batch_shape = [lhs_aval.shape[dim] for dim in lhs_bdims]
             m = reduce(operator.mul, lhs_batch_shape, 1) * lhs_aval.shape[lhs_outer_dim]
             k = rhs_aval.shape[rhs_inner_dim]
@@ -259,10 +287,22 @@ def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_
             workspace_size = get_cublas_workspace_size_bytes()
             operand_dtype = jax_dtype_to_te_dtype(lhs_aval.dtype)
             bias_dtype = jax_dtype_to_te_dtype(bias_aval.dtype)
-            opaque = tex.pack_gemm_descriptor(m, n, k, workspace_size, operand_dtype,
-                                              jax_dtype_to_te_dtype(out_dtype), bias_dtype,
-                                              lhs_trans, rhs_trans, fuse_gelu, fuse_bias, grad,
-                                              accumulate, use_split_accumulator)
+            opaque = tex.pack_gemm_descriptor(
+                m,
+                n,
+                k,
+                workspace_size,
+                operand_dtype,
+                jax_dtype_to_te_dtype(out_dtype),
+                bias_dtype,
+                lhs_trans,
+                rhs_trans,
+                fuse_gelu,
+                fuse_bias,
+                grad,
+                accumulate,
+                use_split_accumulator,
+            )
 
             return custom_caller(
                 CollectiveGemmPrimitive.name,
@@ -273,9 +313,23 @@ def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_
         )
 
     @staticmethod
-    def impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out_scale,
-             out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulate,
-             use_split_accumulator):
+    def impl(
+        lhs,
+        lhs_scale_inv,
+        rhs,
+        rhs_scale_inv,
+        bias,
+        gelu_input,
+        out_amax,
+        out_scale,
+        out_dtype,
+        contracting_dims,
+        fuse_gelu,
+        fuse_bias,
+        grad,
+        accumulate,
+        use_split_accumulator,
+    ):
         assert CollectiveGemmPrimitive.inner_primitive is not None
 
         (
@@ -305,8 +359,18 @@ def impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out
         return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad
 
     @staticmethod
-    def batcher(batched_args, batch_dims, *, out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad,
-                accumulate, use_split_accumulator):
+    def batcher(
+        batched_args,
+        batch_dims,
+        *,
+        out_dtype,
+        contracting_dims,
+        fuse_gelu,
+        fuse_bias,
+        grad,
+        accumulate,
+        use_split_accumulator,
+    ):
         assert CollectiveGemmPrimitive.outer_primitive is not None
 
         lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out_scale = batched_args
@@ -316,7 +380,7 @@ def batcher(batched_args, batch_dims, *, out_dtype, contracting_dims, fuse_gelu,
         lhs_inner_dim, rhs_inner_dim = map(
             lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim,
             contracting_dims,
-            (lhs.ndim, rhs.ndim)
+            (lhs.ndim, rhs.ndim),
         )
         lhs_trans = lhs_inner_dim != lhs.ndim - 1
         rhs_trans = rhs_inner_dim == 1
@@ -368,23 +432,28 @@ def batcher(batched_args, batch_dims, *, out_dtype, contracting_dims, fuse_gelu,
 
         # Reshape output to recover original LHS batch shape
         outputs[0] = jnp.reshape(
-            outputs[0],
-            (*batch_shape, lhs_outer_size, rhs.shape[rhs_outer_dim])
+            outputs[0], (*batch_shape, lhs_outer_size, rhs.shape[rhs_outer_dim])
         )
         gelu_bdims = batch_dims[3]
         if fuse_gelu:
             outputs[3] = jnp.reshape(outputs[3], outputs[0].shape)
             gelu_bdims = lhs_bdims
 
-        return (
-            outputs,
-            (lhs_bdims, batch_dims[1], batch_dims[2], gelu_bdims, batch_dims[4])
-        )
+        return (outputs, (lhs_bdims, batch_dims[1], batch_dims[2], gelu_bdims, batch_dims[4]))
 
     @staticmethod
-    def infer_sharding_from_operands(out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad,
-                                     accumulate, use_split_accumulator, mesh, arg_infos,
-                                     result_infos):
+    def infer_sharding_from_operands(
+        out_dtype,
+        contracting_dims,
+        fuse_gelu,
+        fuse_bias,
+        grad,
+        accumulate,
+        use_split_accumulator,
+        mesh,
+        arg_infos,
+        result_infos,
+    ):
         del out_dtype, accumulate, use_split_accumulator, result_infos
         lhs, _, rhs, *_ = arg_infos
         lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs])
@@ -392,12 +461,14 @@ def infer_sharding_from_operands(out_dtype, contracting_dims, fuse_gelu, fuse_bi
         lhs_inner_dim, rhs_inner_dim = map(
             lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim,
             contracting_dims,
-            (lhs.ndim, rhs.ndim)
+            (lhs.ndim, rhs.ndim),
         )
         if lhs_spec[lhs_inner_dim] != rhs_spec[rhs_inner_dim] and not grad:
-            warnings.warn("Forcing the inner dimension of LHS to match the sharding of inner "
-                          + "dimension of RHS. This can trigger additional communication if LHS is "
-                          + "not already partitioned correctly.")
+            warnings.warn(
+                "Forcing the inner dimension of LHS to match the sharding of inner "
+                + "dimension of RHS. This can trigger additional communication if LHS is "
+                + "not already partitioned correctly."
+            )
 
         lhs_trans = lhs_inner_dim != lhs.ndim - 1
         rhs_trans = rhs_inner_dim == 1
@@ -427,8 +498,18 @@ def infer_sharding_from_operands(out_dtype, contracting_dims, fuse_gelu, fuse_bi
         return (out_sharding, fp8_meta_sharding, fp8_meta_sharding, gelu_sharding, bias_sharding)
 
     @staticmethod
-    def partition(out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulate,
-                  use_split_accumulator, mesh, arg_infos, result_infos):
+    def partition(
+        out_dtype,
+        contracting_dims,
+        fuse_gelu,
+        fuse_bias,
+        grad,
+        accumulate,
+        use_split_accumulator,
+        mesh,
+        arg_infos,
+        result_infos,
+    ):
         del result_infos
         lhs, _, rhs, *_ = arg_infos
         lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs])
@@ -436,7 +517,7 @@ def partition(out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulat
         lhs_inner_dim, rhs_inner_dim = map(
             lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim,
             contracting_dims,
-            (lhs.ndim, rhs.ndim)
+            (lhs.ndim, rhs.ndim),
        )
 
         lhs_trans = lhs_inner_dim != lhs.ndim - 1
@@ -470,13 +551,27 @@ def partition(out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulat
         gelu_spec = out_spec if fuse_gelu else [None]
         gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec))
 
-        arg_shardings = (lhs_sharding, fp8_meta_sharding, rhs_sharding, fp8_meta_sharding,
-                         bias_sharding, gelu_sharding, fp8_meta_sharding, fp8_meta_sharding)
-        out_shardings = (out_sharding, fp8_meta_sharding, fp8_meta_sharding, gelu_sharding,
-                         bias_sharding)
+        arg_shardings = (
+            lhs_sharding,
+            fp8_meta_sharding,
+            rhs_sharding,
+            fp8_meta_sharding,
+            bias_sharding,
+            gelu_sharding,
+            fp8_meta_sharding,
+            fp8_meta_sharding,
+        )
+        out_shardings = (
+            out_sharding,
+            fp8_meta_sharding,
+            fp8_meta_sharding,
+            gelu_sharding,
+            bias_sharding,
+        )
 
-        def sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax,
-                         out_scale):
+        def sharded_impl(
+            lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out_scale
+        ):
             (
                 out,
                 out_amax_updated,
@@ -509,8 +604,7 @@ def sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_a
                 # GEMM output needs to be all-reduced when the contracting dimension is sharded.
                 # If the layer is sequence-parallel, we also need to scatter the output, which we
                 # can combine into a reduce-scatter here.
-                out = lax_paral_op(out, jax.lax.psum, global_mesh_resource().cp_resource,
-                                   mesh)
+                out = lax_paral_op(out, jax.lax.psum, global_mesh_resource().cp_resource, mesh)
                 if fuse_gelu:
                     pre_gelu_out = lax_paral_op(
                         pre_gelu_out, jax.lax.psum, global_mesh_resource().cp_resource, mesh
                     )
@@ -529,10 +623,10 @@ def fp8_gemm_impl(
     lhs_scale_inv: ArrayLike,
     rhs: ArrayLike,
     rhs_scale_inv: ArrayLike,
-    bias: Optional[ArrayLike] = None,
+    bias: Optional[ArrayLike] = None,
     gelu_input: Optional[ArrayLike] = None,
-    out_amax: Optional[ArrayLike] = None,
-    out_scale: Optional[ArrayLike] = None,
+    out_amax: Optional[ArrayLike] = None,
+    out_scale: Optional[ArrayLike] = None,
     out_dtype: jnp.dtype = jnp.bfloat16,
     contracting_dims: Tuple[int, int] = (1, 1),
     fuse_gelu: bool = False,
@@ -550,9 +644,7 @@ def fp8_gemm_impl(
     if not fuse_bias:
         bias = jnp.zeros(0, dtype=jnp.bfloat16)
     else:
-        assert (
-            bias is not None
-        ), "Missing bias in forward GEMM when bias epilogue is enabled."
+        assert bias is not None, "Missing bias in forward GEMM when bias epilogue is enabled."
 
     if not fuse_gelu:
         gelu_input = jnp.zeros(0, dtype=bias.dtype)
@@ -586,8 +678,8 @@ def fp8_gemm_impl(
 def gemm_impl(
     lhs: ArrayLike,
     rhs: ArrayLike,
-    bias: Optional[ArrayLike] = None,
-    gelu_input: Optional[ArrayLike] = None,
+    bias: Optional[ArrayLike] = None,
+    gelu_input: Optional[ArrayLike] = None,
     contracting_dims: Tuple[int, int] = (1, 0),
     fuse_gelu: bool = False,
     fuse_bias: bool = False,
@@ -607,9 +699,7 @@ def gemm_impl(
     elif grad:
         bias = jnp.zeros(out_shape[-1], dtype=lhs.dtype)
     else:
-        assert (
-            bias is not None
-        ), "Missing bias in forward GEMM when bias epilogue is enabled."
+        assert bias is not None, "Missing bias in forward GEMM when bias epilogue is enabled."
 
     if not fuse_gelu:
         gelu_input = jnp.zeros(0, dtype=lhs.dtype)
diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp
index f60ae510df..5dae9d6757 100644
--- a/transformer_engine/jax/csrc/extensions/gemm.cpp
+++ b/transformer_engine/jax/csrc/extensions/gemm.cpp
@@ -29,8 +29,8 @@ void GemmImpl(cudaStream_t stream, void *lhs, const std::vector<size_t> &lhs_sha
   auto out_ = TensorWrapper(out, out_shape, out_dtype, out_amax, out_scale, nullptr);
 
   void *bias_ptr = (fuse_bias) ? bias : nullptr;
-  std::vector<size_t> bias_shape = (fuse_bias) ? std::vector<size_t>{out_shape[1]}
-                                               : std::vector<size_t>{0};
+  std::vector<size_t> bias_shape =
+      (fuse_bias) ? std::vector<size_t>{out_shape[1]} : std::vector<size_t>{0};
   auto bias_ = TensorWrapper(bias_ptr, bias_shape, bias_dtype);
 
   void *pre_gelu_ptr = (fuse_gelu) ? pre_gelu_out : nullptr;
@@ -65,12 +65,9 @@ void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque
   auto *workspace = buffers[13];
 
   // Operand aliasing
-  NVTE_CHECK(bias == bias_grad,
-             "bias not bound to bias_grad in TE/JAX GEMM");
-  NVTE_CHECK(gelu_input == pre_gelu_out,
-             "gelu_input not bound to pre_gelu_out in TE/JAX GEMM");
-  NVTE_CHECK(out_amax == out_amax_updated,
-             "out_amax not bound to out_amax_updated in TE/JAX GEMM");
+  NVTE_CHECK(bias == bias_grad, "bias not bound to bias_grad in TE/JAX GEMM");
+  NVTE_CHECK(gelu_input == pre_gelu_out, "gelu_input not bound to pre_gelu_out in TE/JAX GEMM");
+  NVTE_CHECK(out_amax == out_amax_updated, "out_amax not bound to out_amax_updated in TE/JAX GEMM");
   NVTE_CHECK(out_scale == out_scale_updated,
              "out_scale not bound to out_scale_updated in TE/JAX GEMM");
 
@@ -117,8 +114,7 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i
   auto workspace_size = workspace->dimensions().back();
 
   // Operand aliasing
-  NVTE_CHECK(bias_ptr == bias_grad_ptr,
-             "bias not bound to bias_grad in TE/JAX GEMM");
+  NVTE_CHECK(bias_ptr == bias_grad_ptr, "bias not bound to bias_grad in TE/JAX GEMM");
   NVTE_CHECK(gelu_input_ptr == pre_gelu_out_ptr,
              "gelu_input not bound to pre_gelu_out in TE/JAX GEMM");
   NVTE_CHECK(out_amax_ptr == out_amax_updated_ptr,
diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py
index 7312aa8295..abe23fdf8b 100644
--- a/transformer_engine/jax/flax/module.py
+++ b/transformer_engine/jax/flax/module.py
@@ -362,8 +362,16 @@ def generate_a_set(target_postfix):
         grad_amax, grad_scale = generate_a_set(grad_name_post_fix)
         output_amax, output_scale = generate_a_set(output_name_post_fix)
 
-        return FP8MetaPackage(input_amax, input_scale, weight_amax, weight_scale, grad_amax,
-                              grad_scale, output_amax, output_scale)
+        return FP8MetaPackage(
+            input_amax,
+            input_scale,
+            weight_amax,
+            weight_scale,
+            grad_amax,
+            grad_scale,
+            output_amax,
+            output_scale,
+        )
 
 
 class DenseGeneral(TransformerEngineBase):
diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py
index ccd109e095..79499725b7 100644
--- a/transformer_engine/jax/gemm.py
+++ b/transformer_engine/jax/gemm.py
@@ -21,7 +21,6 @@
 )
 
 
-
 __all__ = [
     "gemm",
     "fp8_gemm",
@@ -52,8 +51,9 @@ def _gemm(
     accumulate: bool,
     use_split_accumulator: bool,
 ) -> ArrayLike:
-    out, _ = _gemm_fwd_rule(x, kernel, bias, contracting_dims, fuse_gelu, accumulate,
-                            use_split_accumulator)
+    out, _ = _gemm_fwd_rule(
+        x, kernel, bias, contracting_dims, fuse_gelu, accumulate, use_split_accumulator
+    )
     return out
 
 
@@ -76,7 +76,7 @@ def _gemm_fwd_rule(
         fuse_gelu=fuse_gelu,
         fuse_bias=fuse_bias,
         accumulate=accumulate,
-        use_split_accumulator=use_split_accumulator
+        use_split_accumulator=use_split_accumulator,
     )
 
     ctx = (
@@ -145,8 +145,18 @@ def fp8_gemm(
     accumulate: bool = False,
     use_split_accumulator: bool = False,
 ) -> ArrayLike:
-    return _fp8_gemm(x, kernel, bias, fp8_meta.amax_list, fp8_meta.scale_list, out_dtype,
-                     contracting_dims, fuse_gelu, accumulate, use_split_accumulator)
+    return _fp8_gemm(
+        x,
+        kernel,
+        bias,
+        fp8_meta.amax_list,
+        fp8_meta.scale_list,
+        out_dtype,
+        contracting_dims,
+        fuse_gelu,
+        accumulate,
+        use_split_accumulator,
+    )
 
 
 @partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9))
@@ -163,8 +173,18 @@ def _fp8_gemm(
     x: ArrayLike,
     kernel: ArrayLike,
     bias: ArrayLike,
     amax_list: ArrayLike,
     scale_list: ArrayLike,
     out_dtype: jnp.dtype,
     contracting_dims: Tuple[int, int],
     fuse_gelu: bool,
     accumulate: bool,
     use_split_accumulator: bool,
 ) -> ArrayLike:
     """Non-FP8 `nvte_cublas_gemm()` with optional GELU and bias-add fusions."""
-    out, _ = _fp8_gemm_fwd_rule(x, kernel, bias, amax_list, scale_list, out_dtype,
-                                contracting_dims, fuse_gelu, accumulate, use_split_accumulator)
+    out, _ = _fp8_gemm_fwd_rule(
+        x,
+        kernel,
+        bias,
+        amax_list,
+        scale_list,
+        out_dtype,
+        contracting_dims,
+        fuse_gelu,
+        accumulate,
+        use_split_accumulator,
+    )
     return out
@@ -183,7 +203,8 @@ def _fp8_gemm_fwd_rule(
     fuse_bias = bias is not None
 
     maybe_fm32_to_fp32, maybe_fp32_to_fm32 = FP8Helper.generate_fp8_meta_dtype_converter_pair(
-        *amax_list, *scale_list,
+        *amax_list,
+        *scale_list,
     )
     amax_list = maybe_fm32_to_fp32(*amax_list)
     scale_list = maybe_fm32_to_fp32(*scale_list)
@@ -272,7 +293,7 @@ def _fp8_gemm_fwd_rule(
         fuse_gelu=fuse_gelu,
         fuse_bias=fuse_bias,
         accumulate=accumulate,
-        use_split_accumulator=use_split_accumulator
+        use_split_accumulator=use_split_accumulator,
     )
     if out_dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]:
         updated_out_amax = None
@@ -288,7 +309,7 @@ def _fp8_gemm_fwd_rule(
         updated_kernel_amax,
         pre_gelu_out if fuse_gelu else None,
         fuse_bias,
-        maybe_fp32_to_fm32
+        maybe_fp32_to_fm32,
     )
 
     return (out, updated_out_amax, updated_out_scale), ctx
@@ -313,7 +334,7 @@ def _fp8_gemm_bwd_rule(
         updated_kernel_amax,
         pre_gelu_out,
         fuse_bias,
-        maybe_fp32_to_fm32
+        maybe_fp32_to_fm32,
     ) = ctx
 
     fwd_dtype = FP8Helper.FWD_DTYPE
@@ -347,8 +368,6 @@ def _fp8_gemm_bwd_rule(
         )
         bgrad = None
 
-
-
     x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX]
     wgrad, *_ = fp8_gemm_impl(
         casted_x,
@@ -370,11 +389,11 @@ def _fp8_gemm_bwd_rule(
             bwd_dtype,
             static_axis_boundary=-1,
             transpose_axis_boundary=-1,
-            activation_type=("gelu", ),
+            activation_type=("gelu",),
         )
     elif fuse_gelu:
         # No bias to fuse so we just do dGELU.
-        casted_dgelu, casted_dgelu_t, updated_dgelu_amax = dact_lu(grad, pre_gelu_out, ("gelu", ))
+        casted_dgelu, casted_dgelu_t, updated_dgelu_amax = dact_lu(grad, pre_gelu_out, ("gelu",))
         bgrad = None
 
     kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX]
@@ -414,12 +433,23 @@ def type_safe_gemm(
     accumulate: bool = False,
     use_split_accumulator: bool = False,
 ) -> ArrayLike:
-    if (x.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2]
-            or kernel.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2]):
+    if x.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] or kernel.dtype in [
+        jnp.float8_e4m3fn,
+        jnp.float8_e5m2,
+    ]:
         assert fp8_meta is not None, "GEMM operands have FP8 dtypes but FP8MetaPackage is None."
 
     if fp8_meta is not None:
-        return fp8_gemm(x, kernel, bias, fp8_meta, out_dtype, contracting_dims, fuse_gelu,
-                        accumulate, use_split_accumulator)
+        return fp8_gemm(
+            x,
+            kernel,
+            bias,
+            fp8_meta,
+            out_dtype,
+            contracting_dims,
+            fuse_gelu,
+            accumulate,
+            use_split_accumulator,
+        )
     else:
         return gemm(x, kernel, bias, contracting_dims, fuse_gelu, accumulate, use_split_accumulator)
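
Note (not part of the patch): a minimal usage sketch of the non-FP8 entry point, written only to mirror the positional gemm() call that type_safe_gemm() makes above. The array shapes and the contracting_dims=(1, 0) value are illustrative assumptions, not values taken from this diff.

    # Illustrative sketch only -- calls gemm() the same way type_safe_gemm() does in
    # transformer_engine/jax/gemm.py; shapes below are assumptions for the example.
    import jax
    import jax.numpy as jnp
    from transformer_engine.jax.gemm import gemm

    key = jax.random.PRNGKey(0)
    k1, k2, k3 = jax.random.split(key, 3)
    x = jax.random.normal(k1, (64, 32), jnp.bfloat16)         # (m, k)
    kernel = jax.random.normal(k2, (32, 128), jnp.bfloat16)   # (k, n)
    bias = jax.random.normal(k3, (128,), jnp.bfloat16)        # (n,)

    # gemm(x, kernel, bias, contracting_dims, fuse_gelu, accumulate, use_split_accumulator)
    out = gemm(x, kernel, bias, (1, 0), False, False, False)
    ref = jnp.dot(x, kernel) + bias                           # plain-JAX reference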