diff --git a/README.rst b/README.rst index ba6bd0e112..65d7aba519 100644 --- a/README.rst +++ b/README.rst @@ -126,8 +126,6 @@ Flax for _ in range(10): loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp) - # Update FP8 metas - other_variables = te.update_fp8_metas(other_grads) .. overview-end-marker-do-not-remove diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 08eb75643a..9bcfc1b98c 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -58,20 +58,18 @@ def __call__(self, x, mask, disable_dropout=False): x = te_flax.DenseGeneral(features=256, kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS), bias_axes=(NAMED_TP_AXIS,), - sharding_type=te.ShardingType.DP_TP_COL, dtype=jnp.bfloat16)(x) x = te_flax.DenseGeneral(features=256, kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS), bias_axes=(NAMED_BROADCAST_AXIS,), - sharding_type=te.ShardingType.DP_TP_ROW, dtype=jnp.bfloat16)(x) x = nn.Dense(features=2, dtype=jnp.bfloat16)(x) return x -def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8): +def train_step(state, inputs, masks, labels, var_collect, rngs): """Computes gradients, loss and accuracy for a single batch.""" def loss_fn(var_collect, disable_dropout=False): @@ -87,13 +85,11 @@ def loss_fn(var_collect, disable_dropout=False): var_collect, grads = flax.core.pop(grads, PARAMS_KEY) state = state.apply_gradients(grads=grads) - if use_fp8: - var_collect = te.update_fp8_metas(var_collect) return state, loss, accuracy, var_collect -def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8, train_fn): +def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn): """Train for a single epoch.""" train_ds_size = len(train_ds['sentence']) steps_per_epoch = train_ds_size // batch_size @@ -108,7 +104,7 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8, train_f batch_masks = train_ds['mask'][perm, ...] batch_labels = train_ds['label'][perm, ...] state, loss, accuracy, var_collect = train_fn(state, batch_inputs, batch_masks, - batch_labels, var_collect, rngs, use_fp8) + batch_labels, var_collect, rngs) epoch_loss.append(loss) epoch_accuracy.append(accuracy) @@ -206,9 +202,8 @@ def get_datasets(max_seq_len): def check_fp8(state, var_collect, inputs, masks, labels): "Check if model includes FP8." rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)} - assert "Float8" in str( - jax.make_jaxpr(train_step, static_argnums=6)(state, inputs, masks, labels, var_collect, - rngs, True)) + assert "fp8_" in str( + jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)) def get_params_pspec(sharding_rules, abs_var_collect): @@ -269,7 +264,8 @@ def train_and_evaluate(args): label_shape = [args.batch_size] with te.fp8_autocast(args.use_fp8, - sharding_resource=te.ShardingResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS)): + mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, + None)): encoder = Net(num_embed) inputs = jnp.zeros(input_shape, dtype=jnp.int32) masks = jnp.zeros(mask_shape, dtype=jnp.uint8) @@ -297,7 +293,7 @@ def train_and_evaluate(args): in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None) out_shardings = (state_pspec, None, None, None) - pjit_train_step = pjit(train_step, in_shardings, out_shardings, static_argnums=(6,)) + pjit_train_step = pjit(train_step, in_shardings, out_shardings) in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None) out_shardings = (None, None) @@ -310,7 +306,7 @@ def train_and_evaluate(args): if args.dry_run: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) rngs = {DROPOUT_KEY: dropout_rng} - pjit_train_step(state, inputs, masks, labels, var_collect, rngs, args.use_fp8) + pjit_train_step(state, inputs, masks, labels, var_collect, rngs) print("PASSED") return None @@ -320,8 +316,7 @@ def train_and_evaluate(args): rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} state, train_loss, train_accuracy, var_collect = train_epoch( - state, train_ds, args.batch_size, rngs, var_collect, args.use_fp8, - pjit_train_step) + state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step) test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect, pjit_eval_step) diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 73879ad10a..a9bef4a99f 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -52,17 +52,15 @@ def __call__(self, x, mask, disable_dropout=False): x = x.reshape(x.shape[0], -1) - x = te_flax.DenseGeneral(features=256, sharding_type=te.ShardingType.DP, - dtype=jnp.bfloat16)(x) + x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x) - x = te_flax.DenseGeneral(features=256, sharding_type=te.ShardingType.DP, - dtype=jnp.bfloat16)(x) + x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x) x = nn.Dense(features=2, dtype=jnp.bfloat16)(x) return x -def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8): +def train_step(state, inputs, masks, labels, var_collect, rngs): """Computes gradients, loss and accuracy for a single batch.""" def loss_fn(var_collect, disable_dropout=False): @@ -78,13 +76,11 @@ def loss_fn(var_collect, disable_dropout=False): var_collect, grads = flax.core.pop(grads, PARAMS_KEY) state = state.apply_gradients(grads=grads) - if use_fp8: - var_collect = te.update_fp8_metas(var_collect) return state, loss, accuracy, var_collect -def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8, train_fn): +def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn): """Train for a single epoch.""" train_ds_size = len(train_ds['sentence']) steps_per_epoch = train_ds_size // batch_size @@ -99,7 +95,7 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8, train_f batch_masks = train_ds['mask'][perm, ...] batch_labels = train_ds['label'][perm, ...] state, loss, accuracy, var_collect = train_fn(state, batch_inputs, batch_masks, - batch_labels, var_collect, rngs, use_fp8) + batch_labels, var_collect, rngs) epoch_loss.append(loss) epoch_accuracy.append(accuracy) @@ -197,9 +193,8 @@ def get_datasets(max_seq_len): def check_fp8(state, var_collect, inputs, masks, labels): "Check if model includes FP8." rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)} - assert "Float8" in str( - jax.make_jaxpr(train_step, static_argnums=6)(state, inputs, masks, labels, var_collect, - rngs, True)) + assert "fp8_" in str( + jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)) def get_params_pspec(sharding_rules, abs_var_collect): @@ -252,7 +247,8 @@ def train_and_evaluate(args): mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] label_shape = [args.batch_size] - with te.fp8_autocast(args.use_fp8, sharding_resource=te.ShardingResource(DEVICE_DP_AXIS)): + with te.fp8_autocast(args.use_fp8, + mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None)): encoder = Net(num_embed) inputs = jnp.zeros(input_shape, dtype=jnp.int32) masks = jnp.zeros(mask_shape, dtype=jnp.uint8) @@ -279,7 +275,7 @@ def train_and_evaluate(args): in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None) out_shardings = (state_pspec, None, None, None) - pjit_train_step = pjit(train_step, in_shardings, out_shardings, static_argnums=(6,)) + pjit_train_step = pjit(train_step, in_shardings, out_shardings) in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None) out_shardings = (None, None) @@ -292,7 +288,7 @@ def train_and_evaluate(args): if args.dry_run: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) rngs = {DROPOUT_KEY: dropout_rng} - pjit_train_step(state, inputs, masks, labels, var_collect, rngs, args.use_fp8) + pjit_train_step(state, inputs, masks, labels, var_collect, rngs) print("PASSED") return None @@ -302,8 +298,7 @@ def train_and_evaluate(args): rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} state, train_loss, train_accuracy, var_collect = train_epoch( - state, train_ds, args.batch_size, rngs, var_collect, args.use_fp8, - pjit_train_step) + state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step) test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect, pjit_eval_step) diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index e94d5e20f5..08390d9253 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -61,13 +61,11 @@ def __call__(self, x, mask, disable_dropout=False): x = te_flax.DenseGeneral(features=256, kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS), bias_axes=(NAMED_TP_AXIS,), - sharding_type=te.ShardingType.DP_TP_COL, dtype=jnp.bfloat16)(x) x = te_flax.DenseGeneral(features=256, kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS), bias_axes=(NAMED_BROADCAST_AXIS,), - sharding_type=te.ShardingType.DP_TP_ROW, dtype=jnp.bfloat16)(x) x = nn.Dense(features=2, dtype=jnp.bfloat16)(x) @@ -106,7 +104,7 @@ def shard_array_wrapper(dataset, batch_size, mesh, pspec, enable_partition=False return global_input_shape, named_sharding, inputs -def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8): +def train_step(state, inputs, masks, labels, var_collect, rngs): """Computes gradients, loss and accuracy for a single batch.""" def loss_fn(var_collect, disable_dropout=False): @@ -122,14 +120,12 @@ def loss_fn(var_collect, disable_dropout=False): var_collect, grads = flax.core.pop(grads, PARAMS_KEY) state = state.apply_gradients(grads=grads) - if use_fp8: - var_collect = te.update_fp8_metas(var_collect) return state, loss, accuracy, var_collect -def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8, train_fn, mesh, - inputs_pspec, masks_pspec, labels_pspec): +def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn, mesh, inputs_pspec, + masks_pspec, labels_pspec): """Train for a single epoch.""" total_batch_size = len(train_ds['sentence']) @@ -164,7 +160,7 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8, train_f label_named_sharding, [batch_label]) state, loss, accuracy, var_collect = train_fn(state, shard_input, shard_mask, shard_label, - var_collect, rngs, use_fp8) + var_collect, rngs) epoch_loss.append(loss) epoch_accuracy.append(accuracy) @@ -280,9 +276,8 @@ def get_datasets(max_seq_len): def check_fp8(state, var_collect, inputs, masks, labels): "Check if model includes FP8." rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)} - assert "Float8" in str( - jax.make_jaxpr(train_step, static_argnums=6)(state, inputs, masks, labels, var_collect, - rngs, True)) + assert "fp8_" in str( + jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)) def get_params_pspec(sharding_rules, abs_var_collect): @@ -350,7 +345,8 @@ def train_and_evaluate(args): label_shape = [args.batch_size] with te.fp8_autocast(args.use_fp8, - sharding_resource=te.ShardingResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS)): + mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, + None)): encoder = Net(num_embed) inputs = jnp.zeros(input_shape, dtype=jnp.int32) masks = jnp.zeros(mask_shape, dtype=jnp.uint8) @@ -378,7 +374,7 @@ def train_and_evaluate(args): in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None) out_shardings = (state_pspec, None, None, None) - pjit_train_step = pjit(train_step, in_shardings, out_shardings, static_argnums=(6,)) + pjit_train_step = pjit(train_step, in_shardings, out_shardings) in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None) out_shardings = (None, None) @@ -391,7 +387,7 @@ def train_and_evaluate(args): if args.dry_run: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) rngs = {DROPOUT_KEY: dropout_rng} - pjit_train_step(state, inputs, masks, labels, var_collect, rngs, args.use_fp8) + pjit_train_step(state, inputs, masks, labels, var_collect, rngs) print("PASSED") else: for epoch in range(1, args.epochs + 1): @@ -400,8 +396,8 @@ def train_and_evaluate(args): rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} state, train_loss, train_accuracy, var_collect = train_epoch( - state, train_ds, args.batch_size, rngs, var_collect, args.use_fp8, - pjit_train_step, shard_mesh, inputs_pspec, masks_pspec, labels_pspec) + state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step, + shard_mesh, inputs_pspec, masks_pspec, labels_pspec) test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect, pjit_eval_step, shard_mesh, diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index c5d472658b..f752d88659 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -56,7 +56,7 @@ def __call__(self, x, mask, disable_dropout=False): @partial(jax.jit, static_argnums=6) -def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8): +def train_step(state, inputs, masks, labels, var_collect, rngs): """Computes gradients, loss and accuracy for a single batch.""" def loss_fn(var_collect, disable_dropout=False): @@ -72,13 +72,11 @@ def loss_fn(var_collect, disable_dropout=False): var_collect, grads = flax.core.pop(grads, PARAMS_KEY) state = state.apply_gradients(grads=grads) - if use_fp8: - var_collect = te.update_fp8_metas(var_collect) return state, loss, accuracy, var_collect -def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8): +def train_epoch(state, train_ds, batch_size, rngs, var_collect): """Train for a single epoch.""" train_ds_size = len(train_ds['sentence']) steps_per_epoch = train_ds_size // batch_size @@ -93,7 +91,7 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8): batch_masks = train_ds['mask'][perm, ...] batch_labels = train_ds['label'][perm, ...] state, loss, accuracy, var_collect = train_step(state, batch_inputs, batch_masks, - batch_labels, var_collect, rngs, use_fp8) + batch_labels, var_collect, rngs) epoch_loss.append(loss) epoch_accuracy.append(accuracy) @@ -192,9 +190,8 @@ def get_datasets(max_seq_len): def check_fp8(state, var_collect, inputs, masks, labels): "Check if model includes FP8." rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)} - assert "Float8" in str( - jax.make_jaxpr(train_step, static_argnums=6)(state, inputs, masks, labels, var_collect, - rngs, True)) + assert "fp8_" in str( + jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)) def train_and_evaluate(args): @@ -228,7 +225,7 @@ def train_and_evaluate(args): if args.dry_run: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) rngs = {DROPOUT_KEY: dropout_rng} - train_step(state, inputs, masks, labels, var_collect, rngs, args.use_fp8) + train_step(state, inputs, masks, labels, var_collect, rngs) print("PASSED") return None @@ -238,7 +235,7 @@ def train_and_evaluate(args): rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} state, train_loss, train_accuracy, var_collect = train_epoch( - state, train_ds, args.batch_size, rngs, var_collect, args.use_fp8) + state, train_ds, args.batch_size, rngs, var_collect) test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect) diff --git a/examples/jax/mnist/test_single_gpu_mnist.py b/examples/jax/mnist/test_single_gpu_mnist.py index df27f3965e..45b674549f 100644 --- a/examples/jax/mnist/test_single_gpu_mnist.py +++ b/examples/jax/mnist/test_single_gpu_mnist.py @@ -75,15 +75,13 @@ def loss_fn(var_collect, disable_dropout=False): @partial(jax.jit, static_argnums=2) -def update_model(state, grads, use_fp8): +def update_model(state, grads): """Update model params and FP8 meta.""" state = state.apply_gradients(grads=grads[PARAMS_KEY]) - if use_fp8: - grads = te.update_fp8_metas(grads) return state, grads -def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8): +def train_epoch(state, train_ds, batch_size, rngs, var_collect): """Train for a single epoch.""" train_ds_size = len(train_ds['image']) steps_per_epoch = train_ds_size // batch_size @@ -97,7 +95,7 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8): batch_images = train_ds['image'][perm, ...] batch_labels = train_ds['label'][perm, ...] grads, loss, accuracy = apply_model(state, batch_images, batch_labels, var_collect, rngs) - state, var_collect = update_model(state, grads, use_fp8) + state, var_collect = update_model(state, grads) epoch_loss.append(loss) epoch_accuracy.append(accuracy) @@ -150,7 +148,7 @@ def get_datasets(): def check_fp8(state, var_collect, input_shape, label_shape): "Check if model includes FP8." - assert "Float8" in str( + assert "f8_" in str( jax.make_jaxpr(apply_model)(state, jnp.empty(input_shape, dtype=jnp.bfloat16), jnp.empty(label_shape, dtype=jnp.bfloat16), var_collect)) @@ -195,7 +193,7 @@ def train_and_evaluate(args): rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} state, train_loss, train_accuracy, var_collect = train_epoch( - state, train_ds, args.batch_size, rngs, var_collect, args.use_fp8) + state, train_ds, args.batch_size, rngs, var_collect) test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect) print(f"Epoch: {epoch:>2} " diff --git a/tests/jax/distributed_configs_helper.py b/tests/jax/distributed_configs_helper.py deleted file mode 100644 index b8d9ccb357..0000000000 --- a/tests/jax/distributed_configs_helper.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -import jax -from itertools import product -from transformer_engine.jax.sharding import ShardingType -from transformer_engine.jax.softmax import SoftmaxType -from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType\ - - -class DistributedConfigsHelper(object): - - def __init__(self, num_gpus=len(jax.devices())): - super().__init__() - self.device_count = min(num_gpus, 8) - if self.device_count < 2: - self.layernorm_refs = [] - self.softmax_types = [] - self.softmax_refs = [] - self.self_attn_bias_types = [] - self.self_attn_mask_types = [] - self.self_attn_refs = [] - self.cross_attn_mask_types = [] - self.cross_attn_refs = [] - return - - mesh_configs = [ - ((self.device_count, 1), ("dp", None), ShardingType.DP), - ((self.device_count, 1), ("tp", None), ShardingType.TP_COL), - ((self.device_count, 1), ("tp", None), ShardingType.TP_ROW) - ] - if self.device_count >= 4: - mesh_configs += [ - ((self.device_count//2, 2), ("dp", "tp"), ShardingType.DP_TP_COL), - ((self.device_count//2, 2), ("dp", "tp"), ShardingType.DP_TP_ROW), - ] - if self.device_count >= 6: - mesh_configs += [ - ((2, self.device_count//2), ("dp", "tp"), ShardingType.DP_TP_COL), - ((2, self.device_count//2), ("dp", "tp"), ShardingType.DP_TP_ROW), - ] - - layernorm_collectives = { - ShardingType.DP : {'all-reduce': 2, 'other': 0}, - ShardingType.TP_COL : {'all-reduce': 0, 'other': 0}, - ShardingType.DP_TP_COL : {'all-reduce': 2, 'other': 0} - } - self.layernorm_refs = [ - mesh_config + (layernorm_collectives[mesh_config[2]], ) \ - for mesh_config in mesh_configs \ - if mesh_config[2] not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW) - ] - - self.softmax_types = [ - SoftmaxType.SCALED, - SoftmaxType.SCALED_MASKED, - SoftmaxType.SCALED_UPPER_TRIANG_MASKED - ] - softmax_collectives = { - ShardingType.DP : {'all-reduce': 1, 'other': 0}, - ShardingType.TP_COL : {'all-reduce': 1, 'other': 0}, - ShardingType.TP_ROW : {'all-reduce': 1, 'other': 0}, - ShardingType.DP_TP_COL : {'all-reduce': 1, 'other': 0}, - ShardingType.DP_TP_ROW : {'all-reduce': 1, 'other': 0} - } - self.softmax_refs = [ - mesh_config + (softmax_collectives[mesh_config[2]], ) for mesh_config in mesh_configs - ] - - self.self_attn_bias_types = [ - AttnBiasType.NO_BIAS, - AttnBiasType.PRE_SCALE_BIAS, - AttnBiasType.POST_SCALE_BIAS - ] - self.self_attn_mask_types = [ - AttnMaskType.CAUSAL_MASK, - AttnMaskType.PADDING_MASK, - AttnMaskType.NO_MASK - ] - self_attn_collectives = { - ShardingType.DP : { - AttnBiasType.NO_BIAS : {'all-reduce': 1, 'other': 0}, - AttnBiasType.PRE_SCALE_BIAS : {'all-reduce': 2, 'other': 0}, - AttnBiasType.POST_SCALE_BIAS : {'all-reduce': 2, 'other': 0}, - }, - ShardingType.TP_COL : { - AttnBiasType.NO_BIAS : {'all-reduce': 1, 'other': 0}, - AttnBiasType.PRE_SCALE_BIAS : {'all-reduce': 1, 'other': 0}, - AttnBiasType.POST_SCALE_BIAS : {'all-reduce': 1, 'other': 0} - }, - ShardingType.DP_TP_COL : { - AttnBiasType.NO_BIAS : {'all-reduce': 1, 'other': 0}, - AttnBiasType.PRE_SCALE_BIAS : {'all-reduce': 2, 'other': 0}, - AttnBiasType.POST_SCALE_BIAS : {'all-reduce': 2, 'other': 0} - }, - } - self.self_attn_refs = [ - mesh_config + (bias_type, self_attn_collectives[mesh_config[2]][bias_type]) \ - for mesh_config, bias_type in product(mesh_configs, self.self_attn_bias_types) \ - if mesh_config[2] not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW) - ] - - self.cross_attn_mask_types = [ - AttnMaskType.PADDING_MASK, - AttnMaskType.NO_MASK - ] - self.cross_attn_refs = [ - mesh_config + ({'all-reduce': 1, 'other': 0}, ) \ - for mesh_config in mesh_configs \ - if mesh_config[2] not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW) - ] diff --git a/tests/jax/distributed_ops_helper.py b/tests/jax/distributed_ops_helper.py deleted file mode 100644 index e39ebe565a..0000000000 --- a/tests/jax/distributed_ops_helper.py +++ /dev/null @@ -1,318 +0,0 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -import os -import pytest -import numpy as np -from dataclasses import dataclass -from typing import Tuple -from enum import Enum -from functools import partial - -import jax -import jax.numpy as jnp -from jax import random -from jax.experimental.pjit import pjit, _UNSPECIFIED -from jax.sharding import PartitionSpec - -import flax - -from transformer_engine.jax.sharding import ShardingType -try: - # try importing the new custom partitioning implementation - from transformer_engine.jax.sharding import MeshResource -except ImportError: - # must be using an older TE/JAX version so fall back on the xmap sharding implementation - MeshResource = None -from transformer_engine.jax.layernorm import layernorm -from transformer_engine.jax.softmax import SoftmaxType, softmax -from transformer_engine.jax.fused_attn import \ - AttnBiasType, AttnMaskType, is_fused_attn_kernel_available, self_fused_attn, cross_fused_attn - - -class FusedAttnBackend(Enum): - Max512 = "0" - Arbitrary = "1" - - -@pytest.fixture(name="backend", params=[FusedAttnBackend.Max512, FusedAttnBackend.Arbitrary]) -def fixture_backend(request): - backend = request.param - os.environ["NVTE_FUSED_ATTN_BACKEND"] = backend.value - yield backend - os.environ["NVTE_FUSED_ATTN_BACKEND"] = "" - - -@dataclass -class DistributedOpsHelper: - qkv_shape: Tuple[int,int,int,int] = (32, 128, 16, 64) - pad_ratio: float = 0.3 - dropout_prob: float = 0.1 - dtype: type = jnp.float16 - - @staticmethod - def use_custom_partitioning(): - return (MeshResource is not None) - - @staticmethod - def get_sharding_spec(mesh_names, sharding_type): - P = PartitionSpec - if sharding_type is ShardingType.DP: - return P(mesh_names[0], None), P(None), P(None) - elif sharding_type is ShardingType.DP_TP_COL: - return P(mesh_names[0], mesh_names[1]), P(None), P(None) - else: - raise NotImplementedError - - @staticmethod - def get_sharding_resource(mesh_names, sharding_type): - dp_r = None - tp_r = None - - if sharding_type in (ShardingType.DP, ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW): - dp_r = mesh_names[0] - if sharding_type in (ShardingType.TP_COL, ShardingType.TP_ROW): - tp_r = mesh_names[0] - if sharding_type in (ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW): - tp_r = mesh_names[1] - - return MeshResource(dp_r, tp_r) - - @staticmethod - def make_mask(q_tokens, kv_tokens, mask_type, dtype=jnp.uint8): - if mask_type == AttnMaskType.CAUSAL_MASK: - causal = flax.linen.make_causal_mask(q_tokens, dtype=dtype) - padding = flax.linen.make_attention_mask(q_tokens > 0, kv_tokens > 0, dtype=dtype) - return flax.linen.combine_masks(causal, padding) - else: - return flax.linen.make_attention_mask(q_tokens > 0, kv_tokens > 0, dtype=dtype) - - @staticmethod - def count_collectives(hlo): - tmp = hlo.splitlines() - symb = "-start" - result = { - "all-reduce" : 0, - "other" : 0 - } - for line in tmp: - txt = line.split() - if len(txt) > 0 and symb in txt[0]: - if "all-reduce" in txt[0]: - result["all-reduce"] += 1 - else: - result["other"] += 1 - return result - - def get_tolerance(self, ref_val, relaxation=2./3., dtype=None): - if dtype is None: - dtype = self.dtype - - # slightly relax the machine epsilon for minimum tolerance - eps_relaxed = jax.lax.pow(jnp.finfo(dtype).eps, dtype(relaxation)) - - # calculate the "Unit of Least Precision" -- i.e. distance to the next representable number - spacing_high = jnp.nextafter(dtype(ref_val), jnp.finfo(dtype).max) - dtype(ref_val) - spacing_low = dtype(ref_val) - jnp.nextafter(dtype(ref_val, jnp.finfo(dtype).min)) - ulp = jax.lax.max(spacing_low, spacing_high) - - return jax.lax.max(eps_relaxed, ulp) - - def compare_ops(self, custom_func, ref_func, ref_count, - *args, grad_args=None, dtype=None, - in_shardings=_UNSPECIFIED, out_shardings=_UNSPECIFIED, - **kwargs): - if dtype is None: - dtype = self.dtype - - if isinstance(custom_func, partial): - func_name = custom_func.func.__name__ - else: - func_name = custom_func.__name__ - func_name = func_name.removeprefix('custom_') - if grad_args is None: - grad_args = tuple(range(len(args))) - - custom_gradded = jax.value_and_grad(custom_func, argnums=grad_args) - test_fwd, test_grads = custom_gradded(*args, **kwargs) - custom_pjitter = pjit(custom_gradded, - in_shardings=in_shardings, - out_shardings=out_shardings) - custom_hlo = custom_pjitter.lower(*args, **kwargs).compile().as_text() - custom_count = self.count_collectives(custom_hlo) - if ref_count is not None: - assert custom_count==ref_count, \ - f"`{func_name}`: Expected collective count is {ref_count}, but got {custom_count}." - else: - print(f"`{func_name}`: Output collective count is {custom_count}.") - - ref_gradded = jax.value_and_grad(ref_func, argnums=grad_args) - ref_fwd, ref_grads = ref_gradded(*args, **kwargs) - fwd_tol = self.get_tolerance(ref_fwd, dtype=dtype) - assert jnp.allclose(test_fwd, ref_fwd, rtol=0.0, atol=fwd_tol), \ - f"`{func_name}`: Output (fwd) error {jnp.max(jnp.abs(test_fwd - ref_fwd))}" + \ - f" exceeds tolerance ({fwd_tol})." - - if len(grad_args) == 1: - ref_grads = (ref_grads, ) - test_grads = (test_grads, ) - failed_grads = {} - for i, grads in enumerate(zip(test_grads, ref_grads)): - test_grad, ref_grad = grads - if test_grad is None and ref_grad is None: - continue - bwd_tol = self.get_tolerance(ref_grad, dtype=dtype) - if not jnp.allclose(test_grad, ref_grad, rtol=0.0, atol=bwd_tol): - failed_grads[i] = jnp.max(jnp.abs(test_grad - ref_grad)) - assert len(failed_grads) == 0, \ - f"`{func_name}`: Gradient (bwd) max errors" + \ - f" [{', '.join([f'Arg{k}={v}' for k,v in failed_grads.items()])}]" + \ - f" exceed tolerance ({bwd_tol})." - - @staticmethod - def check_fused_attn_inputs(self, q_seq, kv_seq, head_dim, pad_ratio, dropout_probability, - attn_bias_type, attn_mask_type, backend, dtype=jnp.float16): - if (q_seq > 512 or kv_seq > 512 or backend == FusedAttnBackend.Arbitrary) \ - and pad_ratio != 0: - pytest.skip( - "`fused_attention`: Arbitrary seqlen backend does not support padded input.") - - if not is_fused_attn_kernel_available( - dtype, dtype, attn_bias_type, attn_mask_type, - dropout_probability, q_seq, kv_seq, head_dim): - pytest.skip( - "`fused_attention`: Unsupported inputs combination or device compute capability.") - - def fused_attn_core(self, query, key, value, bias, mask, scale_factor, - attn_bias_type, attn_mask_type, dropout_rng, dropout_prob): - # Q*K matmul - query = jnp.squeeze(query) - key = jnp.squeeze(key) - value = jnp.squeeze(value) - attn_weights = jnp.einsum("...qhd,...khd->...hqk", query, key) - # scale and bias - if attn_bias_type == AttnBiasType.PRE_SCALE_BIAS: - attn_weights = scale_factor * (attn_weights + bias) - elif attn_bias_type == AttnBiasType.POST_SCALE_BIAS: - attn_weights = scale_factor * attn_weights + bias - else: - attn_weights = scale_factor * attn_weights - # padding mask - if attn_mask_type != AttnMaskType.NO_MASK and mask is not None: - big_neg = jnp.finfo(query.dtype).min - attn_weights = jnp.where(mask, attn_weights, big_neg) - # softmax - attn_weights = jax.nn.softmax(attn_weights).astype(query.dtype) - # dropout - if dropout_prob == 1.0: - attn_weights = jnp.zeros_like(attn_weights) - elif dropout_prob > 0.0: - keep_prob = 1.0 - dropout_prob - keep = random.bernoulli(dropout_rng, p=keep_prob, shape=attn_weights.shape) - multiplier = keep.astype(query.dtype) / jnp.asarray(keep_prob, dtype=query.dtype) - attn_weights = attn_weights * multiplier - # QK*V matmul - result = jnp.einsum('...hqk,...khd->...qhd', attn_weights, value) - return jnp.mean(result) - - @staticmethod - def custom_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, sharding_type): - result = layernorm(x, gamma, beta, - layernorm_type='layernorm', - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon, - sharding_type=sharding_type, - dp_dim_index=0) - return jnp.mean(result) - - def reference_layernorm(self, x, gamma, beta, zero_centered_gamma, epsilon): - x_ = jnp.asarray(x, jnp.float32) - mean = jnp.mean(x_, axis=-1, keepdims=True) - var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True) - normed_input = (x_ - mean) * jax.lax.rsqrt(var + epsilon) - if zero_centered_gamma: - result = jnp.asarray(normed_input * (gamma + 1) + beta).astype(x.dtype) - else: - result = jnp.asarray(normed_input * gamma + beta).astype(x.dtype) - return jnp.mean(result) - - @staticmethod - def custom_rmsnorm(x, gamma, epsilon, sharding_type): - result = layernorm(x, gamma, None, - layernorm_type='rmsnorm', - zero_centered_gamma=False, - epsilon=epsilon, - sharding_type=sharding_type, - dp_dim_index=0) - return jnp.mean(result) - - def reference_rmsnorm(self, x, gamma, epsilon): - x = jnp.asarray(x, jnp.float32) - mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True) - y = jnp.asarray(x * jax.lax.rsqrt(mean2 + epsilon), x.dtype) - result = y * gamma - return jnp.mean(result) - - @staticmethod - def custom_softmax(x, mask, scale_factor, softmax_type, sharding_type): - result = softmax(x, mask, - scale_factor=scale_factor, - softmax_type=softmax_type, - sharding_type=sharding_type) - return jnp.mean(result) - - def reference_softmax(self, x, mask, scale_factor, softmax_type): - attn_weights = scale_factor * x - if softmax_type != SoftmaxType.SCALED: - big_neg = jnp.finfo(x.dtype).min - attn_weights = jnp.where(mask, attn_weights, big_neg) - result = jax.nn.softmax(attn_weights).astype(x.dtype) - return jnp.mean(result) - - @staticmethod - def custom_self_fused_attn(qkv, bias, mask, rng_key, dropout_prob, - attn_bias_type, attn_mask_type, - scaling_factor, sharding_type): - mask = (mask == 0) # invert mask - bias_ = None if attn_bias_type == AttnBiasType.NO_BIAS else bias - result = self_fused_attn(qkv, bias_, mask, - seed=rng_key, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_prob, - is_training=True, - sharding_type=sharding_type) - return jnp.mean(result) - - def reference_self_fused_attn(self, qkv, bias, mask, rng_key, dropout_prob, - attn_bias_type, attn_mask_type, - scaling_factor): - # split interleaved QKV into separate matrices - query, key, value = jnp.split(qkv, [1, 2], axis=-3) - return self.fused_attn_core( - query, key, value, bias, mask, scaling_factor, - attn_bias_type, attn_mask_type, - rng_key, dropout_prob) - - @staticmethod - def custom_cross_fused_attn(query, key_value, mask, rng_key, dropout_prob, - attn_mask_type, scaling_factor, sharding_type): - mask = (mask == 0) # invert mask - result = cross_fused_attn(query, key_value, mask, - seed=rng_key, - attn_bias_type=AttnBiasType.NO_BIAS, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_prob, - is_training=True, - sharding_type=sharding_type) - return jnp.mean(result) - - def reference_cross_fused_attn(self, query, key_value, mask, rng_key, dropout_prob, - attn_mask_type, scaling_factor): - key, value = jnp.split(key_value, [1], axis=-3) - return self.fused_attn_core( - query, key, value, None, mask, scaling_factor, - AttnBiasType.NO_BIAS, attn_mask_type, - rng_key, dropout_prob) diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py new file mode 100644 index 0000000000..858d96ab52 --- /dev/null +++ b/tests/jax/distributed_test_base.py @@ -0,0 +1,133 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +import operator +import re +from functools import reduce + +import jax +from jax.experimental.pjit import pjit, _UNSPECIFIED + +from transformer_engine.jax.sharding import MeshResource + +from utils import assert_allclose, is_devices_enough + + +def generate_configs(): + configs = [] + if is_devices_enough(2): + configs.append([2, (2,), ('dp'), MeshResource(dp_resource='dp')]) + configs.append([2, (2,), ('tp'), MeshResource(tp_resource='tp')]) + + if is_devices_enough(4): + TP_size = 2 + DP_size = 2 + configs.append( + [4, (DP_size, TP_size), ('dp', 'tp'), + MeshResource(dp_resource='dp', tp_resource='tp')]) + + return configs + + +COLL_AR_KEY = "all-reduce" +COLL_AG_KEY = "all-gather" +COLL_OTHER_KEY = "other" + + +def generate_collectives_count(allreduce, allgather, other): + return {COLL_AR_KEY: allreduce, COLL_AG_KEY: allgather, COLL_OTHER_KEY: other} + + +def assert_equal_collectives(target_hlo, coll_count_ref): + target_splitted_hlo = target_hlo.splitlines() + start_symb = "-start" + + def count_bytes(hlo_text): + bytes_count = 0 + + def get_bytes_per_txt(t): + ''' + The pattern of t would be like: + 'f32[]', + '(f32[1024]{0}', + 'f32[1024]{0})', + 'f8E4M3FN[1024]{0}', + 'i32[1024]{0}', + 'bf16[1024,1024]{0}' + ''' + match = re.search(r'(i|f)(\d+).*\[([0-9,]*)\]', t) + _, bits_of_type, shape = match.groups() + bytes_of_type = int(bits_of_type) // 8 + if shape == '': + num_of_elements = 1 + else: + num_of_elements = reduce(operator.mul, map(int, shape.split(','))) + + return bytes_of_type * num_of_elements + + # ['xxx-start', '=', '(bf16[xxx]', 'bf16[xxx])', 'xxx-start(', ...] + if '(' in hlo_text[2]: + for txt in hlo_text[2:]: + bytes_count += get_bytes_per_txt(txt) + if ')' in txt: + break + else: # ['xxx-start', '=', 'fp32[]', 'xxx-start(', ...] + bytes_count = get_bytes_per_txt(hlo_text[2]) + + return bytes_count + + def count_collectives(splitted_hlo): + result = generate_collectives_count(0, 0, 0) + + for line in splitted_hlo: + txt = line.split() + if len(txt) > 0 and start_symb in txt[0]: + if COLL_AR_KEY in txt[0]: + result[COLL_AR_KEY] += count_bytes(txt) + elif COLL_AG_KEY in txt[0]: + result[COLL_AG_KEY] += count_bytes(txt) + else: + result[COLL_OTHER_KEY] += count_bytes(txt) + return result + + target_result = count_collectives(target_splitted_hlo) + assert target_result == coll_count_ref, \ + f"Expected collective count is {coll_count_ref}, but got {target_result}." + + +def compare_ops(target_func, + ref_func, + inputs, + coll_count_ref, + *, + grad_args=None, + metric_fwd_dtype=None, + metric_bwd_dtype=None, + in_shardings=_UNSPECIFIED, + out_shardings=_UNSPECIFIED, + **kwargs): + assert len(inputs) >= 1 + + if metric_fwd_dtype is None: + metric_fwd_dtype = inputs[0].dtype + if metric_bwd_dtype is None: + metric_bwd_dtype = inputs[0].dtype + + if grad_args is None: + grad_args = tuple(range(len(inputs))) + + target_grad_func = jax.value_and_grad(target_func, argnums=grad_args) + target_pjitter = pjit(target_grad_func, in_shardings=in_shardings, out_shardings=out_shardings) + target_fwd, target_grads = target_pjitter(*inputs, **kwargs) + target_hlo = target_pjitter.lower(*inputs, **kwargs).compile().as_text() + + ref_grad_func = jax.value_and_grad(ref_func, argnums=grad_args) + ref_pjitter = pjit(ref_grad_func, in_shardings=in_shardings, out_shardings=out_shardings) + ref_fwd, ref_grads = ref_pjitter(*inputs, **kwargs) + + assert_allclose(target_fwd, ref_fwd, dtype=metric_fwd_dtype) + + for i in range(len(target_grads)): + assert_allclose(target_grads[i], ref_grads[i], dtype=metric_bwd_dtype) + + assert_equal_collectives(target_hlo, coll_count_ref) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 05b7bc3603..5349b90e37 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -14,15 +14,13 @@ from flax import linen as nn from utils import assert_allclose -from transformer_engine.common.recipe import Format from transformer_engine.jax.cpp_extensions import dgated_gelu, gated_gelu from transformer_engine.jax.cpp_extensions import dgated_gelu_cast_transpose, gated_gelu_fp8 -from transformer_engine.jax.cpp_extensions import dequantize, quantize -from transformer_engine.jax.dot import fp8_dot -from transformer_engine.jax.fp8 import DType, FP8GemmPackage, FP8Helper, _format2dtypes +from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize +from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper from transformer_engine.jax.fp8 import is_fp8_available from transformer_engine.jax.layernorm import layernorm -from transformer_engine.jax.mlp import fp8_ln_mlp +from transformer_engine.jax.mlp import layernrom_geglu_fp8_mlp GEMM_CASES = [ (256, 256, 512), @@ -31,7 +29,7 @@ (2048, 2048, 1024), (2048, 1024, 1024), ] -FP8_COMPUTE_TYPE = [_format2dtypes(Format.E4M3), _format2dtypes(Format.HYBRID)] +FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2] LN_CASES = [(512, 1024)] DTYPES = [jnp.bfloat16, jnp.float32] is_fp8_supported, reason = is_fp8_available() @@ -51,67 +49,16 @@ class TestFP8Dot: @pytest.mark.skipif(not is_fp8_supported, reason=reason) def test_qdq(self): - FP8_E4M3_MAX = 448 + FP8_E4M3_MAX = (jnp.finfo(jnp.float8_e4m3fn).max).astype(jnp.float32) x = jnp.asarray([[-1, 0.1], [2, 3]], jnp.float32) amax = jnp.max(jnp.abs(x)).reshape(1) scale = jnp.asarray(FP8_E4M3_MAX / amax, jnp.float32).reshape(1) scale_inv = (1 / scale).reshape(1) - y, new_amax = quantize(x, amax, scale, scale_inv, out_dtype=DType.kFloat8E4M3) - assert_allclose(new_amax, 3.0, rtol=0, atol=0) + y = quantize(x, q_dtype=jnp.float8_e4m3fn, scale=scale) + z = dequantize(y, dq_dtype=jnp.float32, scale_inv=scale_inv) - no_use = jnp.zeros(1, jnp.float32) - z = dequantize(y, - no_use, - no_use, - scale_inv, - fp8_dtype=DType.kFloat8E4M3, - out_dtype=DType.kFloat32) - assert_allclose(z, x, dtype=DType.kFloat8E4M3) - - def test_compile_bf16(self): - key = jax.random.PRNGKey(0) - subkeys = jax.random.split(key, 2) - a = jax.random.normal(subkeys[0], (256, 512), jnp.bfloat16) - b = jax.random.normal(subkeys[1], (512, 256), jnp.bfloat16) - - def func(x, y): - fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM) - fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN), - jnp.float32) - fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32) - fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32) - # x = input, matrix 2d - # y = input, matrix 2d (weight) - fp8_gemm_pkg = FP8GemmPackage(1, x, [y], fp8_max, fp8_metas_amax, fp8_metas_scale, - fp8_metas_scale_inv) - return jnp.sum(fp8_dot(fp8_gemm_pkg, *_format2dtypes(None))) - - value_n_grad_func = value_and_grad(func, (0, 1)) - value_n_grad_func_compiled = jit(value_n_grad_func).lower(a, b).compile() - value_n_grad_func_compiled(a, b) - - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('compute_type', FP8_COMPUTE_TYPE) - def test_compile_fp8(self, compute_type): - key = jax.random.PRNGKey(0) - subkeys = jax.random.split(key, 2) - a = jax.random.normal(subkeys[0], (256, 512), jnp.bfloat16) - b = jax.random.normal(subkeys[1], (512, 256), jnp.bfloat16) - - def func(x, y): - fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM) - fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN), - jnp.float32) - fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32) - fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32) - fp8_gemm_pkg = FP8GemmPackage(1, x, [y], fp8_max, fp8_metas_amax, fp8_metas_scale, - fp8_metas_scale_inv) - return jnp.sum(fp8_dot(fp8_gemm_pkg, *compute_type)) - - value_n_grad_func = value_and_grad(func, (0, 1)) - value_n_grad_func_compiled = jit(value_n_grad_func).lower(a, b).compile() - value_n_grad_func_compiled(a, b) + assert_allclose(z, x, dtype=jnp.float8_e4m3fn) @pytest.mark.parametrize('m,n,k', GEMM_CASES) def test_forward_bf16(self, m, n, k): @@ -120,23 +67,14 @@ def test_forward_bf16(self, m, n, k): a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16) b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16) - fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM) - fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN), - jnp.float32) - fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32) - fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32) - fp8_gemm_pkg = FP8GemmPackage(1, a, [b], fp8_max, fp8_metas_amax, fp8_metas_scale, - fp8_metas_scale_inv) - fwd_dtype, bwd_dtype = _format2dtypes(None) - primitive_out = fp8_dot(fp8_gemm_pkg, fwd_dtype, bwd_dtype) + primitive_out = type_safe_dot_general(a, b) ref_out = jnp.dot(a, b) - assert_allclose(primitive_out, ref_out, dtype=fwd_dtype) + assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize('m,n,k', GEMM_CASES) - @pytest.mark.parametrize('compute_type', FP8_COMPUTE_TYPE) - def test_forward_fp8_randint(self, m, n, k, compute_type): + def test_forward_fp8_randint(self, m, n, k): key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 2) @@ -150,22 +88,24 @@ def test_forward_fp8_randint(self, m, n, k, compute_type): jnp.float32) fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32) fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32) - fp8_meta = [fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv] + fp8_meta_pkg = FP8MetaPackage(1, fp8_max, fp8_metas_amax, fp8_metas_scale, + fp8_metas_scale_inv) + + primitive_out = type_safe_dot_general(a, b, fp8_meta_pkg) - # calculate amax - fp8_gemm_pkg = FP8GemmPackage(1, a, [b], *fp8_meta) - primitive_out = fp8_dot(fp8_gemm_pkg, *compute_type) # calculate scale by amax - fp8_meta = FP8Helper._update_fp8_metas_impl(fp8_meta) + fp8_metas_scale, fp8_metas_scale_inv = FP8Helper.update_fp8_scale( + fp8_max, fp8_metas_amax, fp8_metas_scale) + fp8_meta_pkg = FP8MetaPackage(1, fp8_max, fp8_metas_amax, fp8_metas_scale, + fp8_metas_scale_inv) - fp8_gemm_pkg = FP8GemmPackage(1, a, [b], *fp8_meta) - primitive_out = fp8_dot(fp8_gemm_pkg, *compute_type) + primitive_out = type_safe_dot_general(a, b, fp8_meta_pkg) ref_out = jnp.dot(a, b) ref_out = ref_out.astype(jnp.float32) primitive_out = primitive_out.astype(jnp.float32) - assert_allclose(primitive_out, ref_out, dtype=compute_type[0]) + assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE) @pytest.mark.parametrize('m,n,k', GEMM_CASES) def test_grad_bf16(self, m, n, k): @@ -173,17 +113,10 @@ def test_grad_bf16(self, m, n, k): subkeys = jax.random.split(key, 2) a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16) b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16) - fwd_dtype, bwd_dtype = _format2dtypes(None) def primitive_func(x, y): - fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM) - fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN), - jnp.float32) - fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32) - fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32) - fp8_gemm_pkg = FP8GemmPackage(1, x, [y], fp8_max, fp8_metas_amax, fp8_metas_scale, - fp8_metas_scale_inv) - return jnp.mean(fp8_dot(fp8_gemm_pkg, fwd_dtype, bwd_dtype)) + primitive_out = type_safe_dot_general(x, y) + return jnp.mean(primitive_out) def ref_func(x, y): return jnp.mean(jnp.dot(x, y)) @@ -195,116 +128,75 @@ def ref_func(x, y): primitive_out, (primitive_a_grad, primitive_b_grad) = value_n_grad_primitive_func(a, b) ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b) - assert_allclose(primitive_out, ref_out, dtype=fwd_dtype) - assert_allclose(primitive_a_grad, ref_a_grad, dtype=bwd_dtype) - assert_allclose(primitive_b_grad, ref_b_grad, dtype=bwd_dtype) + assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) + assert_allclose(primitive_a_grad, ref_a_grad, dtype=jnp.bfloat16) + assert_allclose(primitive_b_grad, ref_b_grad, dtype=jnp.bfloat16) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize('m,n,k', GEMM_CASES) - @pytest.mark.parametrize('compute_type', FP8_COMPUTE_TYPE) - def test_grad_fp8_randint(self, m, n, k, compute_type): + def test_grad_fp8_dot(self, m, n, k): key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 2) - # TODO(rewang): add float random test - min_val, max_val = -8, 8 - a = jax.random.randint(subkeys[0], (m, k), min_val, max_val).astype(jnp.bfloat16) - b = jax.random.randint(subkeys[1], (k, n), min_val, max_val).astype(jnp.bfloat16) + a = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16) + b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16) fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM) fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN), jnp.float32) fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32) fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32) - fp8_meta = [fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv] - def primitive_func(x, y, metas): - fp8_gemm_pkg = FP8GemmPackage(1, x, [y], *metas) - return jnp.sum(fp8_dot(fp8_gemm_pkg, *compute_type)) + def primitive_func(x, y, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv): + fp8_meta_pkg = FP8MetaPackage(1, fp8_max, fp8_metas_amax, fp8_metas_scale, + fp8_metas_scale_inv) + primitive_out = type_safe_dot_general(x, y, fp8_meta_pkg) + return jnp.mean(primitive_out) def ref_func(x, y): - return jnp.sum(jnp.dot(x, y)) + return jnp.mean(jnp.dot(x, y)) - value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1)) + value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5)) value_n_grad_ref_func = value_and_grad(ref_func, (0, 1)) ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b) - # calculate amax - primitive_out, (primitive_a_grad, - primitive_b_grad) = value_n_grad_primitive_func(a, b, fp8_meta) + for _ in range(3): + primitive_out, (primitive_a_grad, primitive_b_grad, fp8_max, fp8_metas_amax, + fp8_metas_scale, fp8_metas_scale_inv) = value_n_grad_primitive_func( + a, b, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv) - # calculate scale by amax - fp8_meta = FP8Helper._update_fp8_metas_impl(fp8_meta) - primitive_out, (primitive_a_grad, - primitive_b_grad) = value_n_grad_primitive_func(a, b, fp8_meta) - - assert_allclose(primitive_out, ref_out, dtype=compute_type[0]) - assert_allclose(primitive_a_grad, ref_a_grad, dtype=compute_type[1]) - assert_allclose(primitive_b_grad, ref_b_grad, dtype=compute_type[1]) - - def test_contracting_dims_bf16(self): - key = jax.random.PRNGKey(0) - subkeys = jax.random.split(key, 2) - a = jax.random.normal(subkeys[0], (32, 8, 16, 64), jnp.bfloat16) - b = jax.random.normal(subkeys[1], (16, 64, 128), jnp.bfloat16) - fwd_dtype, bwd_dtype = _format2dtypes(None) - - def primitive_func(x, y): - fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM) - fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN), - jnp.float32) - fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32) - fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32) - fp8_gemm_pkg = FP8GemmPackage(1, x, [y], fp8_max, fp8_metas_amax, fp8_metas_scale, - fp8_metas_scale_inv) - return jnp.sum(fp8_dot(fp8_gemm_pkg, fwd_dtype, bwd_dtype, ((2, 3), (0, 1)))) - - def ref_func(x, y): - return jnp.sum(lax.dot_general(x, y, dimension_numbers=(((2, 3), (0, 1)), ((), ())))) - - value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1)) - value_n_grad_ref_func = value_and_grad(ref_func, (0, 1)) - primitive_out, (primitive_a_grad, primitive_b_grad) = value_n_grad_primitive_func(a, b) - ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b) - - assert_allclose(primitive_out, ref_out, dtype=fwd_dtype) - assert_allclose(primitive_a_grad, ref_a_grad, dtype=bwd_dtype) - assert_allclose(primitive_b_grad, ref_b_grad, dtype=bwd_dtype) + assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE) + assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE) + assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize('m,n,k', [(256, 256, 512), (16384, 1024, 2816), (16384, 2816, 1024), (16384, 1024, 1024)]) - def test_grad_fp8_mlp_randint(self, m, n, k): + def test_grad_ln_geglu_fp8_mlp(self, m, n, k): key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 4) activations = ('gelu', 'linear') - a = jax.random.uniform(subkeys[0], (m, k), jnp.bfloat16, 5, 8) - k1 = jax.random.uniform(subkeys[1], (k, n * len(activations)), jnp.bfloat16, 5, 8) - k2 = jax.random.uniform(subkeys[2], (n, k), jnp.bfloat16, 5, 8) - s = jax.random.uniform(subkeys[3], (k,), jnp.bfloat16, 5, 8) - - fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2) - fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM * 2, FP8Helper.AMAX_HISTORY_LEN), - jnp.float32) - fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32) - fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32) - fp8_meta = [fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv] - compute_type = _format2dtypes(Format.HYBRID) - - def primitive_func(x, ln_s, y, z, metas): + a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16) + k1 = jax.random.normal(subkeys[1], (k, len(activations), n), jnp.bfloat16) + k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) + s = jax.random.normal(subkeys[3], (k,), jnp.bfloat16) + + init_fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2) + init_fp8_metas_amax = jnp.zeros( + (FP8Helper.NUM_META_PER_GEMM * 2, FP8Helper.AMAX_HISTORY_LEN), jnp.float32) + init_fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32) + init_fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32) + + def primitive_func(x, ln_s, y, z, fp8_max, fp8_metas_amax, fp8_metas_scale, + fp8_metas_scale_inv): # x is input tensor, matrix 2d # y, z are weights, matrix 2d # out = (x * y) * z - fp8_gemm_pkg = FP8GemmPackage(2, x, [y, z], *metas) - return jnp.mean( - fp8_ln_mlp(fp8_gemm_pkg, - ln_s, - None, - "rmsnorm", - *compute_type, - activations=activations)) + fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale, + fp8_metas_scale_inv) + return jnp.mean(layernrom_geglu_fp8_mlp(x, ln_s, None, [y, z], fp8_meta_pkg, "rmsnorm")) def _convert_to_activation_function(fn_or_string): """Convert a string to an activation function.""" @@ -316,86 +208,82 @@ def _convert_to_activation_function(fn_or_string): return fn_or_string raise ValueError(f"don't know how to convert {fn_or_string} to an activation function") - def fp8_ln_mlp_py(inputs: jnp.ndarray, - ln_scale: jnp.ndarray, - kernel_1: jnp.ndarray, - kernel_2: jnp.ndarray, - fp8_maxs: jnp.ndarray, - amax: jnp.ndarray, - scale: jnp.ndarray, - scale_inv: jnp.ndarray, - fwd_dtype, - bwd_dtype, - epsilon=1e-6, - contracting_dims=((-1,), (0,)), - dp_dim_index=0, - activations=('gelu', 'linear')) -> jnp.ndarray: - x = jnp.asarray(inputs, jnp.float32) + def ln_geglu_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray, + kernel_2: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray, + scale: jnp.ndarray, scale_inv: jnp.ndarray) -> jnp.ndarray: + + x = jnp.asarray(x, jnp.float32) mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True) - y = jnp.asarray(x * jax.lax.rsqrt(mean2 + epsilon), jnp.bfloat16) + y = jnp.asarray(x * jax.lax.rsqrt(mean2 + 1e-6), jnp.bfloat16) ln_out = y * ln_scale ln_out = jnp.asarray(ln_out, jnp.bfloat16) - fp8_gemm_1_pkg = FP8GemmPackage(1, ln_out, [kernel_1], - fp8_maxs[:FP8Helper.NUM_META_PER_GEMM], + + fp8_gemm_1_pkg = FP8MetaPackage(1, fp8_maxs[:FP8Helper.NUM_META_PER_GEMM], amax[:FP8Helper.NUM_META_PER_GEMM], scale[:FP8Helper.NUM_META_PER_GEMM], scale_inv[:FP8Helper.NUM_META_PER_GEMM]) - linear_1_out = fp8_dot(fp8_gemm_1_pkg, - fwd_dtype, - bwd_dtype, - contracting_dims, - dp_dim_index=dp_dim_index) - x = jnp.split(linear_1_out, len(activations), axis=-1) + linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_gemm_1_pkg, ((1,), (0,))) + + x = jnp.split(linear_1_out, len(activations), axis=-2) acts = [] for idx, act_fn in enumerate(activations): x_i = _convert_to_activation_function(act_fn)(x[idx]) acts.append(x_i) x = functools.reduce(operator.mul, acts) - x = jnp.asarray(x, jnp.bfloat16) - fp8_gemm_2_pkg = FP8GemmPackage(1, x, [kernel_2], - fp8_maxs[FP8Helper.NUM_META_PER_GEMM:], + x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16) + + fp8_gemm_2_pkg = FP8MetaPackage(1, fp8_maxs[FP8Helper.NUM_META_PER_GEMM:], amax[FP8Helper.NUM_META_PER_GEMM:], scale[FP8Helper.NUM_META_PER_GEMM:], scale_inv[FP8Helper.NUM_META_PER_GEMM:]) - output = fp8_dot(fp8_gemm_2_pkg, - fwd_dtype, - bwd_dtype, - contracting_dims, - dp_dim_index=dp_dim_index) + output = type_safe_dot_general(x, kernel_2, fp8_gemm_2_pkg, ((1,), (0,))) return output - def ref_func(x, ln_s, y, z, metas): + def ref_func(x, ln_s, y, z, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv): return jnp.mean( - fp8_ln_mlp_py(x, ln_s, y, z, *metas, *compute_type, activations=activations)) - - value_n_grad_primitive_func = jit(value_and_grad(primitive_func, (0, 1, 2, 3))) - value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3))) - - ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, - ref_k2_grad) = value_n_grad_ref_func(a, s, k1, k2, fp8_meta) - - # calculate amax - primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad, - primitive_k2_grad) = value_n_grad_primitive_func(a, s, k1, k2, fp8_meta) - - # calculate scale by amax - fp8_meta = FP8Helper._update_fp8_metas_impl(fp8_meta) - primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad, - primitive_k2_grad) = value_n_grad_primitive_func(a, s, k1, k2, fp8_meta) - - assert_allclose(primitive_out, ref_out, dtype=compute_type[0]) + ln_geglu_fp8_mlp_ref(x, ln_s, y, z, fp8_max, fp8_metas_amax, fp8_metas_scale, + fp8_metas_scale_inv)) + + value_n_grad_primitive_func = jit(value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7))) + value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7))) + + ref_fp8_max = init_fp8_max + ref_fp8_metas_amax = init_fp8_metas_amax + ref_fp8_metas_scale = init_fp8_metas_scale + ref_fp8_metas_scale_inv = init_fp8_metas_scale_inv + + pri_fp8_max = init_fp8_max + pri_fp8_metas_amax = init_fp8_metas_amax + pri_fp8_metas_scale = init_fp8_metas_scale + pri_fp8_metas_scale_inv = init_fp8_metas_scale_inv + + for _ in range(3): + ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_fp8_max, + ref_fp8_metas_amax, ref_fp8_metas_scale, + ref_fp8_metas_scale_inv) = value_n_grad_ref_func( + a, s, k1, k2, ref_fp8_max, ref_fp8_metas_amax, ref_fp8_metas_scale, + ref_fp8_metas_scale_inv) + + for _ in range(3): + primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad, + primitive_k2_grad, pri_fp8_max, pri_fp8_metas_amax, pri_fp8_metas_scale, + pri_fp8_metas_scale_inv) = value_n_grad_primitive_func( + a, s, k1, k2, pri_fp8_max, pri_fp8_metas_amax, pri_fp8_metas_scale, + pri_fp8_metas_scale_inv) + + assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE) assert_allclose(jnp.asarray(primitive_a_grad, np.float32), jnp.asarray(ref_a_grad, np.float32), - dtype=compute_type[1]) + dtype=FP8Helper.BWD_DTYPE) assert_allclose(jnp.asarray(primitive_k1_grad, np.float32), jnp.asarray(ref_k1_grad, np.float32), - dtype=compute_type[1]) + dtype=FP8Helper.BWD_DTYPE) assert_allclose(jnp.asarray(primitive_k2_grad, np.float32), jnp.asarray(ref_k2_grad, np.float32), - dtype=compute_type[1]) + dtype=FP8Helper.BWD_DTYPE) assert_allclose(jnp.asarray(primitive_s_grad, np.float32), jnp.asarray(ref_s_grad, np.float32), - dtype=compute_type[1]) + dtype=FP8Helper.BWD_DTYPE) @pytest.fixture(name="random_inputs") @@ -411,10 +299,10 @@ class TestGatedGeLu: def ref_func(self, inputs): def jax_gated_gelu(x): - x = jnp.split(x, 2, axis=-1) + x = jnp.split(x, 2, axis=-2) acts = [jax.nn.gelu(x[0]), x[1]] x = functools.reduce(operator.mul, acts) - x = jnp.asarray(x, jnp.bfloat16) + x = jnp.asarray(jnp.squeeze(x, -2), jnp.bfloat16) return x func = jit(value_and_grad(lambda x: jnp.mean(jax_gated_gelu(x)))) @@ -438,17 +326,17 @@ def primitive_bwd(ctx, g): return (out,) primitive.defvjp(primitive_fwd, primitive_bwd) - func = jit(value_and_grad(lambda x: jnp.mean(primitive(x)))) + func = value_and_grad(lambda x: jnp.mean(primitive(x))) return func(inputs) - @pytest.mark.parametrize('shape', [(32, 64), (64, 256)]) + @pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)]) def test_gated_gelu(self, random_inputs): x = random_inputs prim_out, prim_grad = self.prim_func(x) ref_out, ref_grad = self.ref_func(x) - assert_allclose(prim_out, ref_out, rtol=1e-2) - assert_allclose(prim_grad, ref_grad, rtol=1e-1, atol=1e-3) + assert_allclose(prim_out, ref_out, dtype=x.dtype) + assert_allclose(prim_grad, ref_grad, dtype=x.dtype) class TestGatedGeLuFP8(TestGatedGeLu): @@ -461,31 +349,30 @@ def prim_func(self, inputs): @jax.custom_vjp def primitive(x, y, z): - out = primitive_fwd(x, y, z) + out = primitive_fwd(x) return out - def primitive_fwd(x, y, z): # pylint: disable=unused-argument - out, _ = gated_gelu_fp8(x, amax, scale, scale_inv, DType.kFloat8E5M2) - out = dequantize(out, no_use, no_use, scale_inv, DType.kFloat8E5M2, DType.kBFloat16) + def primitive_fwd(x, y, z): + out, _ = gated_gelu_fp8(x, amax, scale, scale_inv, jnp.float8_e4m3fn) + out = dequantize(out, x.dtype, scale_inv) ctx = x return out, ctx def primitive_bwd(ctx, g): x = ctx dgelu, dgelu_trans, amax_out = dgated_gelu_cast_transpose(g, x, amax, scale, scale_inv, - DType.kFloat8E5M2) - dgelu = dequantize(dgelu, no_use, no_use, scale_inv, DType.kFloat8E5M2, DType.kFloat32) - dgelu_trans = dequantize(dgelu_trans, no_use, no_use, scale_inv, DType.kFloat8E5M2, - DType.kFloat32) + jnp.float8_e5m2, -1) + dgelu = dequantize(dgelu, x.dtype, scale_inv) + dgelu_trans = dequantize(dgelu_trans, x.dtype, scale_inv) return dgelu, dgelu_trans, amax_out primitive.defvjp(primitive_fwd, primitive_bwd) - func = jit(value_and_grad(lambda x, y, z: jnp.mean(primitive(x, y, z)), (0, 1, 2))) + func = value_and_grad(lambda x, y, z: jnp.mean(primitive(x, y, z)), (0, 1, 2)) return func(inputs, no_use, no_use) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('shape', [(32, 64), (64, 256)]) + @pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)]) def test_gated_gelu(self, random_inputs): self.amax = jnp.zeros(1, jnp.float32) self.scale = jnp.ones(1, jnp.float32) @@ -495,10 +382,12 @@ def test_gated_gelu(self, random_inputs): prim_out, (prim_grad, prim_grad_trans, amax) = self.prim_func(x) ref_out, ref_grad = self.ref_func(x) - assert_allclose(prim_out, ref_out, rtol=1e-2) + assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE) assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2) - assert_allclose(prim_grad, ref_grad, rtol=1e-1, atol=1e-3) - assert_allclose(prim_grad_trans, jnp.transpose(ref_grad), rtol=1e-1, atol=1e-3) + assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE) + assert_allclose(prim_grad_trans, + jnp.transpose(ref_grad, (1, 2, 0)), + dtype=FP8Helper.BWD_DTYPE) class TestRMSNorm: @@ -529,14 +418,9 @@ def reference_rmsnorm(x, scale): primitive_out, (primitive_dx, primitive_dgamma) = jitted_primitive(x, scale) reference_out, (reference_dx, reference_dgamma) = jitted_reference(x, scale) - if dtype == jnp.float32: - assert_allclose(primitive_out, reference_out, rtol=1e-7) - assert_allclose(primitive_dx, reference_dx, rtol=1e-7) - assert_allclose(primitive_dgamma, reference_dgamma, rtol=1e-7) - else: - assert_allclose(primitive_out, reference_out, rtol=1e-3) - assert_allclose(primitive_dx, reference_dx, rtol=1e-4, atol=5e-8) - assert_allclose(primitive_dgamma, reference_dgamma, rtol=1e-4, atol=5e-8) + assert_allclose(primitive_out, reference_out, dtype=dtype) + assert_allclose(primitive_dx, reference_dx, dtype=dtype) + assert_allclose(primitive_dgamma, reference_dgamma, dtype=dtype) class TestLayerNorm: @@ -587,13 +471,7 @@ def compute_loss(x): reference_out, (reference_dx, reference_dgamma, reference_dbeta) = jitted_reference(x, scale, bias) - if dtype == jnp.float32: - assert_allclose(primitive_out, reference_out, rtol=1e-7) - assert_allclose(primitive_dx, reference_dx, rtol=1e-7) - assert_allclose(primitive_dgamma, reference_dgamma, rtol=1e-7) - assert_allclose(primitive_dbeta, reference_dbeta, rtol=1e-7) - else: - assert_allclose(primitive_out, reference_out, rtol=1e-7) - assert_allclose(primitive_dx, reference_dx, rtol=1e-5, atol=1e-6) - assert_allclose(primitive_dgamma, reference_dgamma, rtol=1e-5, atol=3e-5) - assert_allclose(primitive_dbeta, reference_dbeta, rtol=1e-5, atol=3e-5) + assert_allclose(primitive_out, reference_out, dtype=dtype) + assert_allclose(primitive_dx, reference_dx, dtype=dtype) + assert_allclose(primitive_dgamma, reference_dgamma, dtype=dtype) + assert_allclose(primitive_dbeta, reference_dbeta, dtype=dtype) diff --git a/tests/jax/test_custom_call_shape.py b/tests/jax/test_custom_call_shape.py deleted file mode 100644 index 32d645b668..0000000000 --- a/tests/jax/test_custom_call_shape.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -import pytest -import jax -import jax.numpy as jnp -from jax.core import ShapedArray - -from transformer_engine_jax import DType -from transformer_engine.jax.cpp_extensions import te_dtype_to_jax_dtype -from transformer_engine.jax.cpp_extensions import GemmPrimitive - -SHAPES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816, 1024), - (16384, 1024, 1024)] -NAMED_SHAPES = [{}, { - "data": 4 -}, { - "data": 2 -}, { - "model": 4 -}, { - "model": 2 -}, { - "data": 4, - "model": 2 -}, { - "model": 4, - "data": 2 -}] -DTYPE = [DType.kFloat32, DType.kFloat16, DType.kBFloat16] -TRANSPOSE = [True, False] - - -@pytest.fixture(autouse=True, scope='function') -def clear_live_arrays(): - """ - Clear all live arrays to keep the resource clean - """ - yield - for arr in jax.live_arrays(): - arr.delete() - - -class TestGEMMShapeInfer: - - @staticmethod - def _joint_named_shape(ns1, ns2): - output_named_shape = {**ns1} - need_assert = False - for key in ns2: - if key in output_named_shape and output_named_shape[key] != ns2[key]: - need_assert = True - else: - output_named_shape[key] = ns2[key] - return output_named_shape, need_assert - - @staticmethod - def _get_shapes(m, n, k, transa, transb): - # te_gemm only support TN and col-major, then we have to reorder a, b shape - # to compute row-major matrices calculate in col-major algos. - a = (m, k) if transa else (k, m) - b = (k, n) if transb else (n, k) - out = (n, m) - return a, b, out - - @pytest.mark.parametrize('shapes', SHAPES) - @pytest.mark.parametrize('named_shape1', NAMED_SHAPES) - @pytest.mark.parametrize('named_shape2', NAMED_SHAPES) - @pytest.mark.parametrize('te_dtype', DTYPE) - @pytest.mark.parametrize('transa', TRANSPOSE) - @pytest.mark.parametrize('transb', TRANSPOSE) - def test_shape_infer(self, shapes, named_shape1, named_shape2, te_dtype, transa, transb): - a_shape, b_shape, out_shape = TestGEMMShapeInfer._get_shapes(*shapes, transa, transb) - dtype = te_dtype_to_jax_dtype(te_dtype) - mat_a = ShapedArray(a_shape, dtype, named_shape=named_shape1) - mat_b = ShapedArray(b_shape, dtype, named_shape=named_shape2) - - scale_inv_a = ShapedArray((3, 1), jnp.float32) - scale_inv_b = ShapedArray((3, 1), jnp.float32) - - ref_out_named_shape, need_assert = TestGEMMShapeInfer._joint_named_shape( - named_shape1, named_shape2) - ref_out = ShapedArray(out_shape, dtype, named_shape=ref_out_named_shape) - - try: - test_out = GemmPrimitive.abstract(mat_a, - mat_b, - scale_inv_a, - scale_inv_b, - A_dtype=te_dtype, - B_dtype=te_dtype, - D_dtype=te_dtype, - transa=transa, - transb=transb, - use_split_accumulator=False) - assert not need_assert - assert ref_out == test_out - except AssertionError as ae: - assert need_assert, f"{ae.args}" diff --git a/tests/jax/test_distributed_custom_ops.py b/tests/jax/test_distributed_custom_ops.py deleted file mode 100644 index 3943b42e77..0000000000 --- a/tests/jax/test_distributed_custom_ops.py +++ /dev/null @@ -1,244 +0,0 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -import pytest -import numpy as np -from functools import partial - -import jax -import jax.numpy as jnp -from jax import random -from jax.sharding import NamedSharding - -from utils import is_devices_enough -from distributed_configs_helper import * -from distributed_ops_helper import * -from transformer_engine.jax.sharding import global_shard_guard -from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType - -configs = DistributedConfigsHelper() # default device count is len(jax.devices()) -ops = DistributedOpsHelper() # default data type is jnp.float16 - -@pytest.mark.skipif(not is_devices_enough(configs.device_count), - reason='Insufficient number of GPUs, need at least 2.') -@pytest.mark.skipif(not ops.use_custom_partitioning(), - reason='TE/JAX version does not support sharding with ' + \ - 'jax.experimental.custom_partitioning.') -class TestCustomPartitioningOpsGenerator: - - @pytest.mark.parametrize('mesh_shape, mesh_names, sharding_type, collective_ref', - configs.layernorm_refs) - @pytest.mark.parametrize('zero_centered_gamma', [False, True]) - def test_layernorm(self, mesh_shape, mesh_names, sharding_type, collective_ref, - zero_centered_gamma): - epsilon = 1e-6 - - custom_func = partial(ops.custom_layernorm, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon, - sharding_type=sharding_type) - - reference_func = partial(ops.reference_layernorm, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon) - - batch_size, _, num_heads, head_dim = ops.qkv_shape - hidden_size = num_heads*head_dim - input_shape = (batch_size, hidden_size) - other_shape = (hidden_size, ) - x_ = random.normal(random.PRNGKey(1124), input_shape, dtype=ops.dtype) - gamma_ = jnp.ones(other_shape, dtype=ops.dtype) - beta_ = jnp.ones(other_shape, dtype=ops.dtype) - - x_spec, gamma_spec, beta_spec = ops.get_sharding_spec(mesh_names, sharding_type) - devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape) - mesh = jax.sharding.Mesh(devices, mesh_names) - with mesh, global_shard_guard(ops.get_sharding_resource(mesh_names, sharding_type)): - x_ = jax.device_put(x_, NamedSharding(mesh, x_spec)) - gamma_ = jax.device_put(gamma_, NamedSharding(mesh, gamma_spec)) - beta_ = jax.device_put(beta_, NamedSharding(mesh, beta_spec)) - ops.compare_ops( - custom_func, reference_func, collective_ref, - x_, gamma_, beta_, grad_args=(0, 1, 2), dtype=ops.dtype, - in_shardings=[x_spec, gamma_spec, beta_spec], - out_shardings=(None, (x_spec, gamma_spec, beta_spec)) - ) - - @pytest.mark.parametrize('mesh_shape, mesh_names, sharding_type, collective_ref', - configs.layernorm_refs) - def test_rmsnorm(self, mesh_shape, mesh_names, sharding_type, collective_ref): - epsilon = 1e-6 - custom_func = partial(ops.custom_rmsnorm, epsilon=epsilon,sharding_type=sharding_type) - reference_func = partial(ops.reference_rmsnorm, epsilon=epsilon) - - batch_size, _, num_heads, head_dim = ops.qkv_shape - hidden_size = num_heads*head_dim - input_shape = (batch_size, hidden_size) - other_shape = (hidden_size, ) - x_ = random.normal(random.PRNGKey(1124), input_shape, dtype=ops.dtype) - gamma_ = jnp.ones(other_shape, dtype=ops.dtype) - - x_spec, gamma_spec = ops.get_sharding_spec(mesh_names, sharding_type) - devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape) - mesh = jax.sharding.Mesh(devices, mesh_names) - with mesh, global_shard_guard(ops.get_sharding_resource(mesh_names, sharding_type)): - x_ = jax.device_put(x_, NamedSharding(mesh, x_spec)) - gamma_ = jax.device_put(gamma_, NamedSharding(mesh, gamma_spec)) - ops.compare_ops( - custom_func, reference_func, collective_ref, - x_, gamma_, grad_args=(0, 1), dtype=ops.dtype, - in_shardings=[x_spec, gamma_spec], - out_shardings=(None, (x_spec, gamma_spec)) - ) - - @pytest.mark.parametrize('mesh_shape, mesh_names, sharding_type, collective_ref', - configs.softmax_refs) - @pytest.mark.parametrize('softmax_type', configs.softmax_types) - def test_softmax(self, mesh_shape, mesh_names, sharding_type, collective_ref, - softmax_type): - batch_size, seq_len, num_heads, head_dim = ops.qkv_shape - scale_factor = 1./jnp.sqrt(head_dim) - - custom_func = partial(ops.custom_softmax, - scale_factor=scale_factor, - softmax_type=softmax_type, - sharding_type=sharding_type) - reference_func = partial(ops.reference_softmax, - scale_factor=scale_factor, - softmax_type=softmax_type) - - input_size = (batch_size, num_heads, seq_len, seq_len) - x_ = random.normal(random.PRNGKey(1124), input_size, dtype=ops.dtype) - - pad_len = int(seq_len * ops.pad_ratio) - valid_len = seq_len - pad_len - tokens = jnp.concatenate((jnp.ones((batch_size, valid_len), dtype=ops.dtype), - jnp.zeros((batch_size, pad_len), dtype=ops.dtype)), - axis=-1) - mask_ = ops.make_mask(tokens, tokens, AttnMaskType.PADDING_MASK) - - x_spec, mask_spec = ops.get_sharding_spec(mesh_names, sharding_type) - devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape) - mesh = jax.sharding.Mesh(devices, mesh_names) - with mesh, global_shard_guard(ops.get_sharding_resource(mesh_names, sharding_type)): - x_ = jax.device_put(x_, NamedSharding(mesh, x_spec)) - mask_ = jax.device_put(mask_, NamedSharding(mesh, mask_spec)) - ops.compare_ops( - custom_func, reference_func, collective_ref, - (0), x_, mask_, grad_args=(0), dtype=ops.dtype, - in_shardings=[x_spec, mask_spec], - out_shardings=(None, (x_spec)) - ) - - @pytest.mark.parametrize( - 'mesh_shape, mesh_names, sharding_type, attn_bias_type, collective_ref', - configs.self_attn_refs) - @pytest.mark.parametrize('attn_mask_type', configs.self_attn_mask_types) - def test_self_fused_attn(self, mesh_shape, mesh_names, sharding_type, collective_ref, - attn_bias_type, attn_mask_type, backend): - batch_size, seq_len, num_heads, head_dim = ops.qkv_shape - ops.check_fused_attn_inputs(seq_len, seq_len, head_dim, - ops.pad_ratio, ops.dropout_prob, - attn_bias_type, attn_mask_type, backend) - - dropout_rng = random.PRNGKey(91023051) - split_rng = random.split(dropout_rng, configs.device_count) - scale_factor = 1./jnp.sqrt(head_dim) - - custom_func = partial(ops.custom_self_fused_attn, - rng_key=split_rng, - dropout_prob=ops.dropout_prob, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scale_factor, - sharding_type=sharding_type) - reference_func = partial(ops.reference_self_fused_attn, - rng_key=dropout_rng, - dropout_prob=ops.dropout_prob, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scale_factor) - - key = random.PRNGKey(1124) - subkeys = random.split(key, 2) - - qkv_shape = (batch_size, seq_len, 3, num_heads, head_dim) - qkv_ = random.normal(subkeys[0], qkv_shape, dtype=ops.dtype) - bias_shape = (1, num_heads, seq_len, seq_len) - bias_ = random.normal(subkeys[1], bias_shape, dtype=ops.dtype) - - pad_len = int(seq_len * ops.pad_ratio) - valid_len = seq_len - pad_len - tokens = jnp.concatenate((jnp.ones((batch_size, valid_len), dtype=ops.dtype), - jnp.zeros((batch_size, pad_len), dtype=ops.dtype)), - axis=-1) - mask_ = ops.make_mask(tokens, tokens, attn_mask_type) - - qkv_spec, bias_spec, mask_spec = ops.get_sharding_spec(mesh_names, sharding_type) - devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape) - mesh = jax.sharding.Mesh(devices, mesh_names) - with mesh, global_shard_guard(ops.get_sharding_resource(mesh_names, sharding_type)): - qkv_ = jax.device_put(qkv_, NamedSharding(mesh, qkv_spec)) - bias_ = jax.device_put(bias_, NamedSharding(mesh, bias_spec)) - mask_ = jax.device_put(mask_, NamedSharding(mesh, mask_spec)) - ops.compare_ops( - custom_func, reference_func, collective_ref, - qkv_, bias_, mask_, grad_args=(0, 1), dtype=ops.dtype, - in_shardings=[qkv_spec, bias_spec, mask_spec], - out_shardings=(None, (qkv_spec, bias_spec)) - ) - - @pytest.mark.parametrize('mesh_shape, mesh_names, sharding_type, collective_ref', - configs.cross_attn_refs) - @pytest.mark.parametrize('attn_mask_type', configs.cross_attn_mask_types) - def test_cross_fused_attn(self, mesh_shape, mesh_names, sharding_type, collective_ref, - attn_mask_type, backend): - batch_size, seq_len, num_heads, head_dim = ops.qkv_shape - ops.check_fused_attn_inputs(seq_len, seq_len, head_dim, - ops.pad_ratio, ops.dropout_prob, - AttnBiasType.NO_BIAS, attn_mask_type, backend) - - dropout_rng = random.PRNGKey(91023051) - split_rng = random.split(dropout_rng, configs.device_count) - scale_factor = 1./jnp.sqrt(head_dim) - - custom_func = partial(ops.custom_cross_fused_attn, - rng_key=split_rng, - dropout_prob=ops.dropout_prob, - attn_mask_type=attn_mask_type, - scaling_factor=scale_factor, - sharding_type=sharding_type) - reference_func = partial(ops.reference_cross_fused_attn, - rng_key=split_rng, - dropout_prob=ops.dropout_prob, - attn_mask_type=attn_mask_type, - scaling_factor=scale_factor) - - key = random.PRNGKey(1124) - subkeys = random.split(key, 2) - - q_shape = (batch_size, seq_len, num_heads, head_dim) - q_ = random.normal(subkeys[0], q_shape, dtype=ops.dtype) - kv_shape = (batch_size, seq_len, 2, num_heads, head_dim) - kv_ = random.normal(subkeys[1], kv_shape, dtype=ops.dtype) - - pad_len = int(seq_len * ops.pad_ratio) - valid_len = seq_len - pad_len - tokens = jnp.concatenate((jnp.ones((batch_size, valid_len), dtype=ops.dtype), - jnp.zeros((batch_size, pad_len), dtype=ops.dtype)), - axis=-1) - mask_ = ops.make_mask(tokens, tokens, attn_mask_type) - - q_spec, kv_spec, mask_spec = ops.get_sharding_spec(mesh_names, sharding_type) - devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape) - mesh = jax.sharding.Mesh(devices, mesh_names) - with mesh, global_shard_guard(ops.get_sharding_resource(mesh_names, sharding_type)): - q_ = jax.device_put(q_, NamedSharding(mesh, q_spec)) - kv_= jax.device_put(kv_, NamedSharding(mesh, kv_spec)) - mask_ = jax.device_put(mask_, NamedSharding(mesh, mask_spec)) - ops.compare_ops( - custom_func, reference_func, collective_ref, - q_, kv_, mask_, grad_args=(0, 1), dtype=ops.dtype, - in_shardings=[q_spec, kv_spec, mask_spec], - out_shardings=(None, (q_spec, kv_spec)) - ) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py new file mode 100644 index 0000000000..a8d7448379 --- /dev/null +++ b/tests/jax/test_distributed_fused_attn.py @@ -0,0 +1,239 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest + +import jax +import jax.numpy as jnp +import numpy as np +from flax.linen import dot_product_attention +from jax import random +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from distributed_test_base import generate_configs, generate_collectives_count +from distributed_test_base import compare_ops +from utils import make_causal_mask, make_self_mask +from transformer_engine.jax import fp8_autocast +from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available +from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn +from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout + +DTYPES = [jnp.float16, jnp.bfloat16] + + +class TestDistributedSelfAttn: + + def generate_collectives_count_ref(self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, + dtype): + jax_dtype = jax.dtypes.canonicalize_dtype(dtype) + _, seqlen, _, heads, _ = shape + is_dp_enabled = mesh_resource.dp_resource is not None + tp_size = 1 + if mesh_resource.tp_resource is not None: + idx = mesh_axes.index(mesh_resource.tp_resource) + tp_size = mesh_shape[idx] + + all_reduce_loss_bytes = 4 # 1 * FP32 + bias_bytes = int(with_bias) * (heads // tp_size) * seqlen * seqlen * jax_dtype.itemsize + allreduce_total_bytes = all_reduce_loss_bytes + (bias_bytes * is_dp_enabled) + # for loss and dbias + return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0) + + def generate_inputs(self, shape, mesh_resource, with_bias, attn_mask_type, dtype): + batch, seqlen, _, heads, _ = shape + + qkv = random.normal(random.PRNGKey(1124), shape, dtype=dtype) + + bias = random.normal(random.PRNGKey(1125), (1, heads, seqlen, seqlen), dtype) \ + if with_bias else None + + mask = None + if attn_mask_type == AttnMaskType.PADDING_MASK: + mask = make_causal_mask(batch, seqlen) + elif attn_mask_type == AttnMaskType.CAUSAL_MASK: + mask = make_self_mask(batch, seqlen) + + qkv_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, + None) + bias_pspec = PartitionSpec(None, mesh_resource.tp_resource, None, None) \ + if with_bias else None + mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) \ + if attn_mask_type != AttnMaskType.NO_MASK else None + + return (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) + + @pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs()) + @pytest.mark.parametrize('data_shape', [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]]) + @pytest.mark.parametrize( + 'attn_bias_type', + [AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS]) + @pytest.mark.parametrize('attn_mask_type', + [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]) + @pytest.mark.parametrize('dtype', DTYPES) + def test_self_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, + attn_bias_type, attn_mask_type, dtype): + dropout_prob = 0.0 + is_training = True + scaling_factor = 1.0 + + _, seqlen, _, _, hidden = data_shape + + if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type, + attn_mask_type, dropout_prob, seqlen, seqlen, hidden): + pytest.skip(f"No FusedAttn backwend found") + + def target_func(qkv, bias, mask): + return jnp.mean( + self_fused_attn(qkv, + bias, + mask, + None, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_prob, + is_training=is_training)) + + def ref_func(qkv, bias, mask): + query, key, value = jnp.split(qkv, [1, 2], axis=-3) + query = jnp.squeeze(query) + key = jnp.squeeze(key) + value = jnp.squeeze(value) + + output = dot_product_attention(query, + key, + value, + bias=bias, + mask=mask, + deterministic=is_training, + dropout_rate=dropout_prob, + dropout_rng=None, + dtype=jnp.float32) + + return jnp.mean(output).astype(dtype) + + with_bias = attn_bias_type != AttnBiasType.NO_BIAS + (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) = \ + self.generate_inputs(data_shape, mesh_resource, with_bias, + attn_mask_type, dtype) + collective_count_ref = self.generate_collectives_count_ref(mesh_shape, mesh_axes, + mesh_resource, with_bias, + data_shape, dtype) + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + with mesh, fp8_autocast(mesh_resource=mesh_resource): + qkv_ = jax.device_put(qkv, NamedSharding(mesh, qkv_pspec)) + bias_ = jax.device_put(bias, NamedSharding(mesh, bias_pspec)) \ + if bias is not None else bias + mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) \ + if mask is not None else mask + + grad_args = (0, 1) if with_bias else (0,) + out_grad_shardings = (qkv_pspec, bias_pspec) if with_bias else (qkv_pspec,) + + compare_ops(target_func, + ref_func, [qkv_, bias_, mask_], + collective_count_ref, + grad_args=grad_args, + metric_fwd_dtype=dtype, + metric_bwd_dtype=dtype, + in_shardings=(qkv_pspec, bias_pspec, mask_pspec), + out_shardings=(None, out_grad_shardings)) + + +class TestDistributedCrossAttn: + + def generate_collectives_count_ref(self): + # for loss + all_reduce_loss_bytes = 4 # 1 * FP32 + return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0) + + def generate_inputs(self, shape, mesh_resource, attn_mask_type, dtype): + batch, seqlen, heads, hidden = shape + + q = random.normal(random.PRNGKey(1124), shape, dtype=dtype) + kv = random.normal(random.PRNGKey(1125), (batch, seqlen, 2, heads, hidden), dtype=dtype) + + mask = None + if attn_mask_type == AttnMaskType.PADDING_MASK: + mask = make_causal_mask(batch, seqlen) + elif attn_mask_type == AttnMaskType.CAUSAL_MASK: + mask = make_self_mask(batch, seqlen) + + q_pspec = PartitionSpec(mesh_resource.dp_resource, None, mesh_resource.tp_resource, None) + + kv_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, + None) + mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) \ + if attn_mask_type != AttnMaskType.NO_MASK else None + + return (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) + + @pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs()) + @pytest.mark.parametrize('data_shape', [[32, 128, 12, 64], [32, 512, 16, 64]]) + @pytest.mark.parametrize('attn_mask_type', + [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]) + @pytest.mark.parametrize('dtype', DTYPES) + def test_cross_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, + attn_mask_type, dtype): + attn_bias_type = AttnBiasType.NO_BIAS + dropout_prob = 0.0 + is_training = True + scaling_factor = 1.0 + + _, seqlen, _, hidden = data_shape + + if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BSHD_BS2HD, attn_bias_type, + attn_mask_type, dropout_prob, seqlen, seqlen, hidden): + pytest.skip(f"No FusedAttn backwend found") + + def target_func(q, kv, mask): + return jnp.mean( + cross_fused_attn(q, + kv, + mask, + None, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_prob, + is_training=is_training)) + + def ref_func(query, kv, mask): + key, value = jnp.split(kv, [1], axis=-3) + query = jnp.squeeze(query) + key = jnp.squeeze(key) + value = jnp.squeeze(value) + + output = dot_product_attention(query, + key, + value, + bias=None, + mask=mask, + deterministic=is_training, + dropout_rate=dropout_prob, + dropout_rng=None, + dtype=jnp.float32) + + return jnp.mean(output).astype(dtype) + + (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = \ + self.generate_inputs(data_shape, mesh_resource, attn_mask_type, dtype) + collective_count_ref = self.generate_collectives_count_ref() + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + with mesh, fp8_autocast(mesh_resource=mesh_resource): + q_ = jax.device_put(q, NamedSharding(mesh, q_pspec)) + kv_ = jax.device_put(kv, NamedSharding(mesh, kv_pspec)) + mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) \ + if mask is not None else mask + + compare_ops(target_func, + ref_func, [q_, kv_, mask_], + collective_count_ref, + grad_args=(0, 1), + metric_fwd_dtype=dtype, + metric_bwd_dtype=dtype, + in_shardings=(q_pspec, kv_pspec, mask_pspec), + out_shardings=(None, (q_pspec, kv_pspec))) diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py new file mode 100644 index 0000000000..cfd24cd1f1 --- /dev/null +++ b/tests/jax/test_distributed_layernorm.py @@ -0,0 +1,130 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest + +import jax +import jax.numpy as jnp +import numpy as np +from jax import random +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from distributed_test_base import generate_configs, generate_collectives_count +from distributed_test_base import compare_ops +from transformer_engine.jax import fp8_autocast +from transformer_engine.jax.layernorm import layernorm + +DTYPES = [jnp.bfloat16, jnp.float32] + + +class TestDistributedLayernorm: + + def generate_inputs(self, shape, mesh_resource, dtype): + weight_shape = (shape[-1],) + + x = random.normal(random.PRNGKey(1124), shape, dtype=dtype) + gamma = jnp.ones(weight_shape, dtype=dtype) + beta = jnp.ones(weight_shape, dtype=dtype) + + if len(shape) == 2: + x_pspec = PartitionSpec(mesh_resource.dp_resource, None) + elif len(shape) == 3: + x_pspec = PartitionSpec(mesh_resource.dp_resource, None, None) + else: + raise NotImplementedError + + g_pspec = b_pspec = PartitionSpec(None) + + return (x, gamma, beta), (x_pspec, g_pspec, b_pspec) + + def generate_collectives_count_ref(self, mesh_resource, ln_type, shape, dtype): + jax_dtype = jax.dtypes.canonicalize_dtype(dtype) + is_dp_enabled = mesh_resource.dp_resource is not None + assert ln_type in ['layernorm', 'rmsnorm'] + all_reduce_loss_bytes = 4 # 1 * FP32 + # for loss, dgamma and dbeta + weight_count = 2 if ln_type == 'layernorm' else 1 + allreduce_total_bytes = all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize + return generate_collectives_count(allreduce=allreduce_total_bytes * int(is_dp_enabled), + allgather=0, + other=0) + + @pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs()) + @pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]]) + @pytest.mark.parametrize('dtype', DTYPES) + @pytest.mark.parametrize('zero_centered_gamma', [False, True]) + def test_layernorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, + zero_centered_gamma): + epsilon = 1e-6 + ln_type = 'layernorm' + + def target_func(x, gamma, beta): + return jnp.mean(layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon)) + + def ref_func(x, gamma, beta): + x_ = jnp.asarray(x, jnp.float32) + mean = jnp.mean(x_, axis=-1, keepdims=True) + var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True) + normed_input = (x_ - mean) * jax.lax.rsqrt(var + epsilon) + if zero_centered_gamma: + output = jnp.asarray(normed_input * (gamma + 1) + beta).astype(x.dtype) + else: + output = jnp.asarray(normed_input * gamma + beta).astype(x.dtype) + return jnp.mean(output) + + (x, gamma, beta), (x_pspec, g_pspec, b_pspec) = \ + self.generate_inputs(data_shape, mesh_resource, dtype) + collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type, + data_shape, dtype) + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + with mesh, fp8_autocast(mesh_resource=mesh_resource): + x_ = jax.device_put(x, NamedSharding(mesh, x_pspec)) + gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec)) + beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec)) + + compare_ops(target_func, + ref_func, [x_, gamma_, beta_], + collective_count_ref, + grad_args=(0, 1, 2), + metric_fwd_dtype=dtype, + metric_bwd_dtype=dtype, + in_shardings=(x_pspec, g_pspec, b_pspec), + out_shardings=(None, (x_pspec, g_pspec, b_pspec))) + + @pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs()) + @pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]]) + @pytest.mark.parametrize('dtype', DTYPES) + def test_rmsnorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype): + epsilon = 1e-6 + ln_type = 'rmsnorm' + + def target_func(x, gamma): + return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon)) + + def ref_func(x, gamma): + x = jnp.asarray(x, jnp.float32) + mean2 = jnp.mean(jnp.square(x), axis=-1, keepdims=True) + y = jnp.asarray(x * jax.lax.rsqrt(mean2 + epsilon), dtype) + output = y * gamma + return jnp.mean(output) + + (x, gamma, _), (x_pspec, g_pspec, _) = \ + self.generate_inputs(data_shape, mesh_resource, dtype) + collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type, + data_shape, dtype) + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + with mesh, fp8_autocast(mesh_resource=mesh_resource): + x_ = jax.device_put(x, NamedSharding(mesh, x_pspec)) + gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec)) + + compare_ops(target_func, + ref_func, [x_, gamma_], + collective_count_ref, + grad_args=(0, 1), + metric_fwd_dtype=dtype, + metric_bwd_dtype=dtype, + in_shardings=(x_pspec, g_pspec), + out_shardings=(None, (x_pspec, g_pspec))) diff --git a/tests/jax/test_distributed_softmax.py b/tests/jax/test_distributed_softmax.py new file mode 100644 index 0000000000..2fc5a037dd --- /dev/null +++ b/tests/jax/test_distributed_softmax.py @@ -0,0 +1,83 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest + +import jax +import jax.numpy as jnp +import numpy as np +from jax import random +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from distributed_test_base import generate_configs, generate_collectives_count +from distributed_test_base import compare_ops +from utils import make_causal_mask, make_self_mask +from transformer_engine.jax import fp8_autocast +from transformer_engine.jax.softmax import SoftmaxType, softmax + +DTYPES = [jnp.float16, jnp.bfloat16] + + +class TestDistributedSoftmax: + + def generate_collectives_count_ref(self): + # for loss + all_reduce_loss_bytes = 4 # 1 * FP32 + return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0) + + def generate_inputs(self, shape, mesh_resource, softmax_type, dtype): + batch, _, sqelen, _ = shape + + x = random.normal(random.PRNGKey(1124), shape, dtype=dtype) + if softmax_type == SoftmaxType.SCALED_UPPER_TRIANG_MASKED: + mask = make_causal_mask(batch, sqelen) + else: + mask = make_self_mask(batch, sqelen) + + x_pspec = PartitionSpec(mesh_resource.dp_resource, mesh_resource.tp_resource, None, None) + mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) + + return (x, mask), (x_pspec, mask_pspec) + + @pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs()) + @pytest.mark.parametrize('data_shape', [[32, 12, 128, 128], [64, 16, 1024, 1024]]) + @pytest.mark.parametrize( + 'softmax_type', + [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED]) + @pytest.mark.parametrize('scale_factor', [1.0, 3.0]) + @pytest.mark.parametrize('dtype', DTYPES) + def test_softmax(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, + softmax_type, scale_factor, dtype): + + def target_func(x, mask): + return jnp.mean(softmax(x, mask, scale_factor=scale_factor, softmax_type=softmax_type)) + + def ref_func(x, mask): + bias = None + if mask is not None: + bias = jax.lax.select(mask > 0, + jnp.full(mask.shape, -1e10).astype(dtype), + jnp.full(mask.shape, 0.).astype(dtype)) + if bias is not None: + x = x + bias.astype(dtype) + output = jax.nn.softmax(x * scale_factor) + return jnp.mean(output) + + (x, mask), (x_pspec, mask_pspec) = \ + self.generate_inputs(data_shape, mesh_resource, softmax_type, dtype) + collective_count_ref = self.generate_collectives_count_ref() + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + with mesh, fp8_autocast(mesh_resource=mesh_resource): + x_ = jax.device_put(x, NamedSharding(mesh, x_pspec)) + mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) + + compare_ops(target_func, + ref_func, [x_, mask_], + collective_count_ref, + grad_args=(0,), + metric_fwd_dtype=dtype, + metric_bwd_dtype=dtype, + in_shardings=(x_pspec, mask_pspec), + out_shardings=(None, (x_pspec,))) diff --git a/tests/jax/test_helper.py b/tests/jax/test_helper.py index 815aab6099..4910d969be 100644 --- a/tests/jax/test_helper.py +++ b/tests/jax/test_helper.py @@ -14,9 +14,7 @@ from transformer_engine.common.recipe import Format as FP8Format from transformer_engine.jax import fp8_autocast, get_delayed_scaling from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available, AmaxComputeAlgo -from transformer_engine.jax.sharding import infer_major_sharding_type -from transformer_engine.jax.sharding import MajorShardingType -from transformer_engine.jax.sharding import ShardingResource +from transformer_engine.jax.sharding import MeshResource, global_mesh_resource is_fp8_supported, reason = is_fp8_available() @@ -160,7 +158,6 @@ class TestFP8Functions(unittest.TestCase): def _check_defult_state(self): self.assertFalse(FP8Helper.is_fp8_enabled()) - self.assertEqual(infer_major_sharding_type(), MajorShardingType.SINGLE) def _compare_delay_scaling(self, ref, test): self.assertTrue(ref.margin == test.margin) @@ -201,27 +198,20 @@ def test_fp8_autocast_with_sharding_resource(self): ds = DelayedScaling(margin=5.0, interval=3, fp8_format=FP8Format.E4M3, amax_history_len=1) - # TODO (Ming Huang): Suport multi-GPUs testing. # pylint: disable=fixme - # srs = ( - # (ShardingResource(None, None), MajorShardingType.SINGLE), - # (ShardingResource('dp', None), MajorShardingType.DP), - # (ShardingResource(None, 'tp'), MajorShardingType.TP), - # (ShardingResource('dp', 'tp'), MajorShardingType.DPTP), - # ) - srs = ( - (ShardingResource(None, None), MajorShardingType.SINGLE), - (ShardingResource('dp', None), MajorShardingType.SINGLE), - (ShardingResource(None, 'tp'), MajorShardingType.SINGLE), - (ShardingResource('dp', 'tp'), MajorShardingType.SINGLE), + mesh_s = ( + (MeshResource(None, None)), + (MeshResource('dp', None)), + (MeshResource(None, 'tp')), + (MeshResource('dp', 'tp')), ) # TODO (Ming Huang): Suport multi-GPUs testing. # pylint: disable=fixme mesh_shape = (1, 1) devices = np.asarray(jax.devices()[:1]).reshape(*mesh_shape) with jax.sharding.Mesh(devices, ('dp', 'tp')): - for sr, mst in srs: - with fp8_autocast(enabled=True, fp8_recipe=ds, sharding_resource=sr): + for sr in mesh_s: + with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=sr): self.assertTrue(FP8Helper.is_fp8_enabled()) self._compare_delay_scaling(get_delayed_scaling(), ds) - self.assertEqual(infer_major_sharding_type(), mst) + self.assertEqual(sr, global_mesh_resource()) self._check_defult_state() diff --git a/tests/jax/test_sharding.py b/tests/jax/test_sharding.py index ea216ac514..22cfbd41c9 100644 --- a/tests/jax/test_sharding.py +++ b/tests/jax/test_sharding.py @@ -2,40 +2,10 @@ # # See LICENSE for license information. -import jax -import numpy as np import pytest -from utils import is_devices_enough from transformer_engine.jax.flax import extend_logical_axis_rules -from transformer_engine.jax.sharding import get_dot_sharding_meta -from transformer_engine.jax.sharding import get_elementwise_sharding_meta -from transformer_engine.jax.sharding import get_fp8_meta_sharding_meta -from transformer_engine.jax.sharding import global_shard_guard -from transformer_engine.jax.sharding import infer_major_sharding_type -from transformer_engine.jax.sharding import is_dp_enabled, is_tp_enabled -from transformer_engine.jax.sharding import ShardingMeta, ShardingResource, ShardingType - - -def _get_sharding_resource(mesh_names, sharding_type): - dp_r = None - tp_r = None - - if sharding_type in (ShardingType.DP, ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW): - dp_r = mesh_names[0] - - if sharding_type in (ShardingType.TP_COL, ShardingType.TP_ROW): - tp_r = mesh_names[0] - - if sharding_type in (ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW): - tp_r = mesh_names[1] - return ShardingResource(dp_r, tp_r) - - -DEVICE_COUNT = 4 -MESH_CONFIG = [((4,), ("dp",), ShardingType.DP), ((4,), ("tp",), ShardingType.TP_COL), - ((4,), ("tp",), ShardingType.TP_ROW), ((2, 2), ("dp", "tp"), ShardingType.DP_TP_COL), - ((2, 2), ("dp", "tp"), ShardingType.DP_TP_ROW)] +from transformer_engine.jax.sharding import global_shard_guard, MeshResource LOGICAL_RULES = [ [(('a1', None), ('a2', 'ma2')), False], @@ -44,18 +14,19 @@ def _get_sharding_resource(mesh_names, sharding_type): [(('a1', None), ('a2', 'ma2'), ('batch', 'batch_1200234')), True], [(('a1', None), ('a2', 'ma2'), ('a2', 'ma1'), ('batch', 'model'), ('batch', 'data')), True], ] -SRS = [ - ShardingResource(), - ShardingResource('data', None), - ShardingResource(None, 'model'), - ShardingResource('data', 'model') + +MeshS = [ + MeshResource(), + MeshResource('data', None), + MeshResource(None, 'model'), + MeshResource('data', 'model') ] class TestShardingSideAPI: @pytest.mark.parametrize('base_rules,need_assert', LOGICAL_RULES) - @pytest.mark.parametrize('sr', SRS) + @pytest.mark.parametrize('sr', MeshS) def test_extend_logical_axis_rules(self, base_rules, need_assert, sr): with global_shard_guard(sr): try: @@ -65,270 +36,3 @@ def test_extend_logical_axis_rules(self, base_rules, need_assert, sr): assert not need_assert except AssertionError as ae: assert need_assert, f"{ae.args}" - - -class TestGeneralFunc: - - @pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG) - @pytest.mark.skipif(not is_devices_enough(DEVICE_COUNT), reason='Num of GPU is not enough') - def test_infer_major_sharding_type( - self, - mesh_shape, # pylint: disable=unused-argument - mesh_names, - sharding_type): - devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape) - with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)): - with jax.sharding.Mesh(devices, mesh_names): - assert infer_major_sharding_type() is sharding_type.value[0] - - @pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG) - def test_is_dp_enabled( - self, - mesh_shape, # pylint: disable=unused-argument - mesh_names, # pylint: disable=unused-argument - sharding_type): - if sharding_type in (ShardingType.DP, ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW): - assert is_dp_enabled(sharding_type.value[0]) - else: - assert not is_dp_enabled(sharding_type.value[0]) - - @pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG) - def test_is_tp_enabled( - self, - mesh_shape, # pylint: disable=unused-argument - mesh_names, # pylint: disable=unused-argument - sharding_type): - if sharding_type is ShardingType.DP: - assert not is_tp_enabled(sharding_type.value[0]) - else: - assert is_tp_enabled(sharding_type.value[0]) - - -class TestShardingMetaGenerator: - - BATCH_AXIS_NAME = 'batch' - MODEL_AXIS_NAME = 'model' - - @pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG) - @pytest.mark.skipif(not is_devices_enough(DEVICE_COUNT), reason='Num of GPU is not enough') - def test_fp8_meta(self, mesh_shape, mesh_names, sharding_type, num_of_fp8_meta=4): - - def stack_axes_meta(mapping): - return tuple(mapping for _ in range(num_of_fp8_meta)) - - def get_ref_sm(): - if sharding_type == ShardingType.DP: - return ShardingMeta(stack_axes_meta({}), stack_axes_meta({}), - {TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0]}, (), - ()) - - if sharding_type == ShardingType.TP_COL: - return ShardingMeta(stack_axes_meta({}), stack_axes_meta({}), - {TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[0]}, (), - ()) - - if sharding_type == ShardingType.TP_ROW: - return ShardingMeta(stack_axes_meta({}), stack_axes_meta({}), - {TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[0]}, (), - ()) - - if sharding_type == ShardingType.DP_TP_COL: - return ShardingMeta( - stack_axes_meta({}), stack_axes_meta({}), { - TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0], - TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[1] - }, (), ()) - - if sharding_type == ShardingType.DP_TP_ROW: - return ShardingMeta( - stack_axes_meta({}), stack_axes_meta({}), { - TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0], - TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[1] - }, (), ()) - return None - - devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape) - with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)): - with jax.sharding.Mesh(devices, mesh_names): - test_sm = get_fp8_meta_sharding_meta( - sharding_type, - num_of_fp8_meta, - dp_axis_name=TestShardingMetaGenerator.BATCH_AXIS_NAME, - tp_axis_name=TestShardingMetaGenerator.MODEL_AXIS_NAME) - assert test_sm == get_ref_sm() - - @pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG) - @pytest.mark.parametrize('a_shape, b_shape', [((64, 128, 256), (256, 512)), - ((128, 64, 512), (512, 256))]) - @pytest.mark.parametrize('batch_dim_of_a', [0, 1]) - @pytest.mark.skipif(not is_devices_enough(DEVICE_COUNT), reason='Num of GPU is not enough') - def test_dot(self, mesh_shape, mesh_names, sharding_type, a_shape, b_shape, batch_dim_of_a): - model_dim_of_a = len(a_shape) - 1 - model_dim_of_b = 0 if sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW) else 1 - contracting_dims = ((-1,), (0,)) - - def get_ref_sm(): - out_shape = (*a_shape[:min(contracting_dims[0])], - *b_shape[max(contracting_dims[1]) + 1:]) - if sharding_type == ShardingType.DP: - a_new_shape = (*a_shape[:batch_dim_of_a], mesh_shape[0], -1, - *a_shape[batch_dim_of_a + 1:]) - return ShardingMeta(({ - batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME - }, {}), ({ - batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME - }), {TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0]}, - [a_new_shape, b_shape], [out_shape]) - - if sharding_type == ShardingType.TP_COL: - b_new_shape = (b_shape[0], mesh_shape[0], b_shape[1] // mesh_shape[0]) - return ShardingMeta(({}, { - 1: TestShardingMetaGenerator.MODEL_AXIS_NAME - }), ({ - len(out_shape) - 1: TestShardingMetaGenerator.MODEL_AXIS_NAME - }), {TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[0]}, - [a_shape, b_new_shape], [out_shape]) - - if sharding_type == ShardingType.TP_ROW: - a_new_shape = (*a_shape[:-1], mesh_shape[0], a_shape[-1] // mesh_shape[0]) - b_new_shape = (mesh_shape[0], b_shape[0] // mesh_shape[0], b_shape[1]) - return ShardingMeta(({ - len(a_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME - }, { - 0: TestShardingMetaGenerator.MODEL_AXIS_NAME - }), ({}), {TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[0]}, - [a_new_shape, b_new_shape], [out_shape]) - - if sharding_type == ShardingType.DP_TP_COL: - a_new_shape = (*a_shape[:batch_dim_of_a], mesh_shape[0], - a_shape[batch_dim_of_a] // mesh_shape[0], - *a_shape[batch_dim_of_a + 1:]) - b_new_shape = (b_shape[0], mesh_shape[1], b_shape[1] // mesh_shape[1]) - return ShardingMeta( - ({ - batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME - }, { - 1: TestShardingMetaGenerator.MODEL_AXIS_NAME - }), ({ - batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME, - len(out_shape): TestShardingMetaGenerator.MODEL_AXIS_NAME - }), { - TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0], - TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[1] - }, [a_new_shape, b_new_shape], [out_shape]) - - if sharding_type == ShardingType.DP_TP_ROW: - a_new_shape = (*a_shape[:batch_dim_of_a], mesh_shape[0], - a_shape[batch_dim_of_a] // mesh_shape[0], - *a_shape[batch_dim_of_a + 1:-1], mesh_shape[1], - a_shape[-1] // mesh_shape[1]) - b_new_shape = (mesh_shape[1], b_shape[0] // mesh_shape[1], b_shape[1]) - return ShardingMeta( - ({ - batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME, - len(a_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME - }, { - 0: TestShardingMetaGenerator.MODEL_AXIS_NAME - }), ({ - batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME - }), { - TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0], - TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[1] - }, [a_new_shape, b_new_shape], [out_shape]) - return None - - devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape) - with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)): - with jax.sharding.Mesh(devices, mesh_names): - test_sm = get_dot_sharding_meta( - sharding_type, - a_shape, - b_shape, - batch_dim_of_a, - model_dim_of_a, - model_dim_of_b, - contracting_dims, - dp_axis_name=TestShardingMetaGenerator.BATCH_AXIS_NAME, - tp_axis_name=TestShardingMetaGenerator.MODEL_AXIS_NAME) - assert test_sm == get_ref_sm() - - @pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG) - @pytest.mark.parametrize('input_shape', [(64, 128, 256), (128, 64, 512)]) - @pytest.mark.parametrize('other_shape', [(256,), (512,)]) - @pytest.mark.parametrize('batch_dim', [0, 1]) - @pytest.mark.skipif(not is_devices_enough(DEVICE_COUNT), reason='Num of GPU is not enough') - def test_elementwise(self, mesh_shape, mesh_names, sharding_type, input_shape, other_shape, - batch_dim): - - def get_ref_sm(): - need_assert = True - ref_sharding_meta = None - if input_shape[-1] != other_shape[0]: - need_assert = True - ref_sharding_meta = None - elif sharding_type is (ShardingType.DP_TP_COL, ShardingType.DP): - need_assert = False - input_new_shape = (*input_shape[:batch_dim], mesh_shape[0], -1, - *input_shape[batch_dim + 1:]) - ref_sharding_meta = ShardingMeta(({ - batch_dim: TestShardingMetaGenerator.BATCH_AXIS_NAME - }, {}), ({ - batch_dim: TestShardingMetaGenerator.BATCH_AXIS_NAME - }), {TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0]}, - [input_new_shape, other_shape], [input_shape]) - elif sharding_type is ShardingType.TP_COL: - need_assert = False - ref_sharding_meta = ShardingMeta(({}, {}), ({}), {}, [input_shape, other_shape], - [input_shape]) - elif sharding_type is ShardingType.TP_ROW: - need_assert = False - input_new_shape = (*input_shape[:-1], mesh_shape[0], -1) - other_new_shape = (mesh_shape[0], -1) - - ref_sharding_meta = ShardingMeta(({ - len(input_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME - }, { - 0: TestShardingMetaGenerator.MODEL_AXIS_NAME - }), ({ - len(input_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME - }), {TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[0]}, - [input_new_shape, other_new_shape], [input_shape]) - elif sharding_type is ShardingType.DP_TP_ROW: - need_assert = False - input_new_shape = (*input_shape[:batch_dim], mesh_shape[0], -1, - *input_shape[batch_dim + 1:-1], mesh_shape[1], - input_shape[-1] // mesh_shape[1]) - other_new_shape = (mesh_shape[0], -1) - - ref_sharding_meta = ShardingMeta( - ({ - batch_dim: TestShardingMetaGenerator.BATCH_AXIS_NAME, - len(input_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME - }, { - 0: TestShardingMetaGenerator.MODEL_AXIS_NAME - }), ({ - batch_dim: TestShardingMetaGenerator.BATCH_AXIS_NAME, - len(input_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME - }), { - TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0], - TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[1] - }, [input_new_shape, other_new_shape], [input_shape]) - - return ref_sharding_meta, need_assert - - devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape) - with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)): - with jax.sharding.Mesh(devices, mesh_names): - ref_sm, need_assert = get_ref_sm() - try: - test_sm = get_elementwise_sharding_meta( - sharding_type, - input_shape, - other_shape, - batch_dim, - dp_axis_name=TestShardingMetaGenerator.BATCH_AXIS_NAME, - tp_axis_name=TestShardingMetaGenerator.MODEL_AXIS_NAME) - assert not need_assert - assert test_sm == ref_sm - except (NotImplementedError, AssertionError) as e: - assert need_assert, f"{e.args}" diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 118b5c3b00..f065c53c2b 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -26,6 +26,7 @@ lax.Precision]] Initializer = Callable[[PRNGKey, Shape, DType], Array] + def is_devices_enough(required): return len(jax.devices()) >= required @@ -1010,6 +1011,24 @@ def __call__(self, return z +def make_causal_mask(batch, seqlen, dtype=jnp.uint8): + shape = (batch, seqlen) + idxs = jnp.broadcast_to(jnp.arange(shape[-1], dtype=jnp.int32), shape) + + mask = jnp.greater_equal(jnp.expand_dims(idxs, axis=-1), jnp.expand_dims(idxs, axis=-2)) + mask = jnp.expand_dims(mask, axis=-3) + mask = 1 - mask + return mask.astype(dtype) + + +def make_self_mask(batch, seqlen, dtype=jnp.uint8): + shape = (batch, seqlen) + mask = jnp.ones((*shape, shape[-1])) + mask = jnp.expand_dims(mask, axis=-3) + mask = 1 - mask + return mask.astype(dtype) + + def assert_allclose( actual: Array, desired: Array, @@ -1092,7 +1111,7 @@ def dtype_tols( # Estimate floating-point error finfo = jnp.finfo(dtype) - eps_relaxed = math.pow(finfo.eps, 2/3) + eps_relaxed = math.pow(finfo.eps, 2 / 3) with jax.default_device(jax.devices("cpu")[0]): if isinstance(reference_value, (float, int)): reference_value = jnp.array(reference_value, dtype=dtype) diff --git a/transformer_engine/jax/__init__.py b/transformer_engine/jax/__init__.py index 0268cbfdd8..b339cb69ac 100644 --- a/transformer_engine/jax/__init__.py +++ b/transformer_engine/jax/__init__.py @@ -5,10 +5,30 @@ from . import flax from .fp8 import fp8_autocast, update_collections, update_fp8_metas, get_delayed_scaling +from .fp8 import NVTE_FP8_COLLECTION_NAME +from .sharding import MeshResource from .sharding import MajorShardingType, ShardingResource, ShardingType +from ..common.utils import deprecate_wrapper +from ..common.utils import DeprecatedEnum + +MajorShardingType = DeprecatedEnum(MajorShardingType, + "MajorShardingType is deprecating in the near feature.") +ShardingType = DeprecatedEnum(ShardingType, "ShardingType is deprecating in the near feature.") +ShardingResource = deprecate_wrapper( + ShardingResource, + "ShardingResource is renamed to MeshResource, and will be removed in the near feature.") __all__ = [ - 'fp8_autocast', 'update_collections', 'update_fp8_metas', 'get_delayed_scaling', - 'MajorShardingType', 'ShardingResource', 'ShardingType', 'flax', 'praxis', + 'NVTE_FP8_COLLECTION_NAME', + 'fp8_autocast', + 'update_collections', + 'update_fp8_metas', + 'get_delayed_scaling', + 'MeshResource', + 'MajorShardingType', + 'ShardingResource', + 'ShardingType', + 'flax', + 'praxis', ] diff --git a/transformer_engine/jax/cpp_extensions.py b/transformer_engine/jax/cpp_extensions.py index a876fe5315..adaf6ca3e0 100644 --- a/transformer_engine/jax/cpp_extensions.py +++ b/transformer_engine/jax/cpp_extensions.py @@ -14,9 +14,11 @@ import jax.numpy as jnp from jax.lib import xla_client from jax import core, dtypes -from jax.core import ShapedArray from jax.interpreters import xla, mlir +from jax.experimental.custom_partitioning import custom_partitioning from jax.interpreters.mlir import ir, dtype_to_ir_type +from jax.sharding import PartitionSpec, NamedSharding +from jax._src.interpreters import batching try: from jaxlib.hlo_helpers import custom_call @@ -32,6 +34,11 @@ from transformer_engine_jax import NVTE_QKV_Layout from transformer_engine_jax import NVTE_Fused_Attn_Backend +from .sharding import all_reduce_max_along_all_axes_except_PP +from .sharding import all_reduce_sum_along_dp_fsdp +from .sharding import get_all_mesh_axes, num_of_devices +from .sharding import get_padded_spec as te_get_padded_spec + for _name, _value in transformer_engine_jax.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="CUDA") @@ -41,17 +48,21 @@ def te_dtype_to_jax_dtype(te_dtype): convert TE dtype to jax dtype """ assert isinstance(te_dtype, TEDType) - if te_dtype == TEDType.kFloat32: - return jnp.float32 - if te_dtype == TEDType.kFloat16: - return jnp.float16 - if te_dtype == TEDType.kBFloat16: - return jnp.bfloat16 - if te_dtype == TEDType.kInt32: - return jnp.int32 - if te_dtype == TEDType.kInt64: - return jnp.int64 - return jnp.int8 + + converter = { + TEDType.kFloat32: jnp.float32, + TEDType.kFloat16: jnp.float16, + TEDType.kBFloat16: jnp.bfloat16, + TEDType.kInt32: jnp.int32, + TEDType.kInt64: jnp.int64, + TEDType.kFloat8E4M3: jnp.float8_e4m3fn, + TEDType.kFloat8E5M2: jnp.float8_e5m2, + } + + if te_dtype not in converter: + raise ValueError(f"Unsupported {te_dtype=}") + + return converter.get(te_dtype) def te_dtype_to_ir_dtype(te_dtype): @@ -61,65 +72,53 @@ def te_dtype_to_ir_dtype(te_dtype): return dtype_to_ir_type(np.dtype(te_dtype_to_jax_dtype(te_dtype))) +def jax_dtype_to_ir_dtype(jax_dtype): + """ + convert Jax dtype to MLIR dtype + """ + return dtype_to_ir_type(np.dtype(jax_dtype)) + + def jax_dtype_to_te_dtype(jax_dtype): """ convert jax dtype to TE dtype """ jax_dtype = dtypes.canonicalize_dtype(jax_dtype) - if jax_dtype == jnp.float32: - return TEDType.kFloat32 - if jax_dtype == jnp.float16: - return TEDType.kFloat16 - if jax_dtype == jnp.bfloat16: - return TEDType.kBFloat16 - raise ValueError(f"Not support the {jax_dtype=}") + converter = { + jnp.float32.dtype: TEDType.kFloat32, + jnp.float16.dtype: TEDType.kFloat16, + jnp.bfloat16.dtype: TEDType.kBFloat16, + jnp.int32.dtype: TEDType.kInt32, + jnp.int64.dtype: TEDType.kInt64, + jnp.float8_e4m3fn.dtype: TEDType.kFloat8E4M3, + jnp.float8_e5m2.dtype: TEDType.kFloat8E5M2, + } -@dataclass(frozen=True) -class FusedAttnHelper: - """ - Helper for the fused attention backend - """ + if jax_dtype not in converter: + raise ValueError(f"Unsupported {jax_dtype=}") - q_type: jnp.dtype - kv_type: jnp.dtype - qkv_layout: NVTE_QKV_Layout - attn_bias_type: NVTE_Bias_Type - attn_mask_type: NVTE_Mask_Type - dropout_probability: float - max_seqlen_q: int - max_seqlen_kv: int - head_dim: int + return converter.get(jax_dtype) - def is_fused_attn_kernel_available(self): - """Check if there is available fused attention kernel""" - return self.get_fused_attn_backend() != NVTE_Fused_Attn_Backend.NVTE_No_Backend - def get_fused_attn_backend(self): - """Get the fused attention kernel backend""" - return transformer_engine_jax.get_fused_attn_backend(jax_dtype_to_te_dtype(self.q_type), - jax_dtype_to_te_dtype(self.kv_type), - self.qkv_layout, self.attn_bias_type, - self.attn_mask_type, - self.dropout_probability, - self.max_seqlen_q, self.max_seqlen_kv, - self.head_dim) +def get_padded_spec(arg_info): + """ + Get padded spec for partitioning from arguments' information + """ + if arg_info.sharding is None: + return te_get_padded_spec(None, arg_info.ndim) + ndim, spec = arg_info.ndim, arg_info.sharding.spec + return te_get_padded_spec(spec, ndim) -def merge_named_shape(base, new): +def _check_valid_batch_dims(bdims): """ - merge named shape(ie, dict), no key conflict + Assert out non-supported bath dims """ - output_named_shape = {**base} - for key in new: - if key in output_named_shape: - assert output_named_shape[key] == new[key], \ - f"The value of named shape with a same name should be equal between" \ - f" base and new in merge_named_shape, but got base[{key}]=" \ - f"{output_named_shape[key]} and {new[key]=}" - else: - output_named_shape[key] = new[key] - return output_named_shape + for dim in bdims: + assert dim in [0, None], \ + "Currently only support batch_dim in [0, None], " \ + f"but got {dim=}" class BasePrimitive(metaclass=ABCMeta): @@ -143,17 +142,65 @@ def lowering(): """ return NotImplemented + @staticmethod + @abstractmethod + def impl(): + """ + to describe implementation + """ + return NotImplemented + + @staticmethod + @abstractmethod + def batcher(): + """ + to describe batch rules for vmap + """ + return NotImplemented + + @staticmethod + @abstractmethod + def infer_sharding_from_operands(): + """ + to describe infer_sharding_from_operands for custom_partitioning + """ + return NotImplemented + + @staticmethod + @abstractmethod + def partition(): + """ + to describe partition for custom_partitioning + """ + return NotImplemented + def register_primitive(cls): """ register jax primitive """ - p = core.Primitive(cls.name) - p.multiple_results = cls.multiple_results - p.def_impl(partial(xla.apply_primitive, p)) - p.def_abstract_eval(cls.abstract) - mlir.register_lowering(p, cls.lowering, platform='cuda') - return p + + def name_of_wrapper_p(): + return cls.name + "_wrapper" + + inner_p = core.Primitive(cls.name) + inner_p.multiple_results = cls.multiple_results + inner_p.def_impl(partial(xla.apply_primitive, inner_p)) + inner_p.def_abstract_eval(cls.abstract) + mlir.register_lowering(inner_p, cls.lowering, platform='cuda') + cls.inner_primitive = inner_p + + outer_p = core.Primitive(name_of_wrapper_p()) + outer_p.multiple_results = cls.multiple_results + outer_p.def_impl(cls.impl) + outer_p.def_abstract_eval(cls.abstract) + batching.primitive_batchers[outer_p] = cls.batcher + outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args) + outer_p_lower.def_partition(infer_sharding_from_operands=cls.infer_sharding_from_operands, + partition=cls.partition) + mlir.register_lowering(outer_p, + mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results)) + cls.outer_primitive = outer_p @dataclass @@ -226,2254 +273,3408 @@ def custom_caller(name, args, opaque, has_side_effect, **kwargs): return out -class TransposePrimitive(BasePrimitive): +class LayerNormFwdPrimitive(BasePrimitive): """ - Transpose Primitive + Layer Normalization Forward Primitive """ - name = "te_transpose" - multiple_results = False + name = "te_layernorm_forward" + multiple_results = True + impl_static_args = (3, 4) # zero_centered_gamma, epsilon + inner_primitive = None + outer_primitive = None @staticmethod - def abstract(inputs, *, dtype): + def abstract(x_aval, gamma_aval, beta_aval, **kwargs): # pylint: disable=unused-argument """ - _transpose abstract + LayerNorm fwd abstract """ - in_dtype = dtypes.canonicalize_dtype(inputs.dtype) - out_dtype = te_dtype_to_jax_dtype(dtype) + x_dtype = dtypes.canonicalize_dtype(x_aval.dtype) + assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + + mu_rsigama_dtype = jnp.float32 - assert len(inputs.shape) == 2 - assert isinstance(dtype, TEDType) - assert in_dtype == out_dtype + out_aval = core.raise_to_shaped(x_aval) + mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype) - return ShapedArray((inputs.shape[1], inputs.shape[0]), - in_dtype, - named_shape=inputs.named_shape) + assert gamma_aval.size == beta_aval.size + hidden_size = gamma_aval.size + assert x_aval.size % hidden_size == 0 + + return out_aval, mu_aval, rsigma_aval @staticmethod - def lowering(ctx, inputs, *, dtype): + def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon): """ - _transpose cuda lowering + LayerNorm fwd lowering rules """ + x_aval, gamma_aval, beta_aval = ctx.avals_in + assert gamma_aval.dtype == beta_aval.dtype + x_type = ir.RankedTensorType(x.type) + x_shape = x_type.shape + g_type = ir.RankedTensorType(gamma.type) + g_shape = g_type.shape + b_type = ir.RankedTensorType(beta.type) + b_shape = b_type.shape - in_aval = ctx.avals_in[0] - assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16, jnp.int8] - - ir_in_type = ir.RankedTensorType(inputs.type) - ir_in_shape = ir_in_type.shape - ir_out_dtype = te_dtype_to_ir_dtype(dtype) - - out_types = [ir.RankedTensorType.get([ir_in_shape[1], ir_in_shape[0]], ir_out_dtype)] - operands = [inputs] - operand_shapes = [ir_in_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - assert len(ir_in_shape) == 2 - opaque = transformer_engine_jax.pack_common_descriptor(ir_in_shape, dtype, dtype) - - out = custom_caller(TransposePrimitive.name, args, opaque, False) - - return [out] + assert g_type == b_type + assert g_shape == b_shape + # Output shape is same as the input shape, but the output type is same as the weight type. + # See ln_api.cpp + output_type = g_type.element_type + ir_mu_dtype = ir.F32Type.get() + ir_rsigma_dtype = ir.F32Type.get() -_transpose_p = register_primitive(TransposePrimitive) + out_shape = x_shape + hidden_size = reduce(operator.mul, g_shape) + batch_shape = out_shape[:-1] + batch_size = reduce(operator.mul, x_shape) // hidden_size + out_types = [ + ir.RankedTensorType.get(out_shape, output_type), + ir.RankedTensorType.get(batch_shape, ir_mu_dtype), + ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype), + ] + operands = [x, gamma, beta] + operand_shapes = [x_shape, g_shape, b_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) -def transpose(inputs: jnp.ndarray, dtype: TEDType) -> jnp.ndarray: - """ - transpose wrapper - Assume input has two dimension shape - """ - return _transpose_p.bind(inputs, dtype=dtype) + opaque = transformer_engine_jax.pack_norm_descriptor( + batch_size, + hidden_size, + jax_dtype_to_te_dtype(x_aval.dtype), + jax_dtype_to_te_dtype(gamma_aval.dtype), + zero_centered_gamma, + epsilon, + ) + out = custom_caller(LayerNormFwdPrimitive.name, args, opaque, False) -class CastTransposePrimitive(BasePrimitive): - """ - Cast Transpose Primitive - """ - name = "te_cast_transpose" - multiple_results = True + return out @staticmethod - def abstract(inputs, amax, scale, scale_inv, *, out_dtype): + def impl(x, gamma, beta, zero_centered_gamma, epsilon): """ - te_cast_transpose_p abstract + to describe implementation """ - dtype = dtypes.canonicalize_dtype(inputs.dtype) - assert len(inputs.shape) == 2 - assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert amax.dtype == jnp.float32 - assert scale.dtype == jnp.float32 - assert scale_inv.dtype == jnp.float32 - out_dtype = te_dtype_to_jax_dtype(out_dtype) - # input_cast, input_cast_trans, amax - return (ShapedArray((inputs.shape[0], inputs.shape[1]), - out_dtype, - named_shape=inputs.named_shape), - ShapedArray((inputs.shape[1], inputs.shape[0]), - out_dtype, - named_shape=inputs.named_shape), - ShapedArray((1,), amax.dtype, named_shape=amax.named_shape)) - - @staticmethod - def lowering(ctx, inputs, amax, scale, scale_inv, *, out_dtype): + assert LayerNormFwdPrimitive.inner_primitive is not None + out, mu, rsigma = LayerNormFwdPrimitive.inner_primitive.bind( + x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon) + return out, mu, rsigma + + @staticmethod + def batcher(batched_args, batch_dims, *, zero_centered_gamma, epsilon): """ - te_cast_transpose_p lowering rules + to describe batch rules for vmap """ - in_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in - assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert amax_aval.dtype == jnp.float32 - assert scale_aval.dtype == jnp.float32 - assert scale_inv_aval.dtype == jnp.float32 - ir_in_type = ir.RankedTensorType(inputs.type) - ir_in_shape = ir_in_type.shape - ir_out_dtype = te_dtype_to_ir_dtype(out_dtype) - ir_amax_type = ir.RankedTensorType(amax.type) - ir_amax_dtype = ir_amax_type.element_type - ir_amax_shape = ir_amax_type.shape - ir_scale_shape = ir_amax_shape - ir_scale_inv_shape = ir_amax_shape + _check_valid_batch_dims(batch_dims) + assert LayerNormFwdPrimitive.outer_primitive is not None + x, gamma, beta = batched_args + x_bdim, _, _ = batch_dims - out_types = [ - ir.RankedTensorType.get([ir_in_shape[0], ir_in_shape[1]], ir_out_dtype), - ir.RankedTensorType.get([ir_in_shape[1], ir_in_shape[0]], ir_out_dtype), - ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), - ] - operands = [inputs, amax, scale, scale_inv] - operand_shapes = [ir_in_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + out_bdims = x_bdim, x_bdim, x_bdim + return LayerNormFwdPrimitive.outer_primitive.bind(x, + gamma, + beta, + zero_centered_gamma=zero_centered_gamma, + epsilon=epsilon), out_bdims - assert len(ir_in_shape) == 2 - opaque = transformer_engine_jax.pack_common_descriptor(ir_in_shape, - jax_dtype_to_te_dtype(in_aval.dtype), - out_dtype) + @staticmethod + def infer_sharding_from_operands(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos): + del zero_centered_gamma, epsilon, result_infos + x_spec = get_padded_spec(arg_infos[0]) + if x_spec[-1] is not None: + warnings.warn( + f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! " \ + f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ + f"and hurt performance." + ) + out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) + mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1])) + return (out_sharding, mu_sharding, rsigma_sharding) - out = custom_caller(CastTransposePrimitive.name, - args, - opaque, - False, - operand_output_aliases={1: 2}) + @staticmethod + def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos): + del result_infos + x_spec, g_spec, b_spec = map(get_padded_spec, arg_infos) + if x_spec[-1] is not None: + warnings.warn( + f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! " \ + f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ + f"and hurt performance." + ) + x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) + g_sharding = NamedSharding(mesh, PartitionSpec(*g_spec)) + b_sharding = NamedSharding(mesh, PartitionSpec(*b_spec)) + out_sharding = x_sharding + mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1])) - return out + arg_shardings = (x_sharding, g_sharding, b_sharding) + out_shardings = (out_sharding, mu_sharding, rsigma_sharding) + impl = partial(LayerNormFwdPrimitive.impl, + zero_centered_gamma=zero_centered_gamma, + epsilon=epsilon) + return mesh, impl, out_shardings, arg_shardings -_cast_transpose_p = register_primitive(CastTransposePrimitive) +register_primitive(LayerNormFwdPrimitive) -def cast_transpose(inputs: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, - scale_inv: jnp.ndarray, - out_dtype: TEDType) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: +def layernorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, zero_centered_gamma: bool, + epsilon: float): """ - cast transpose wrapper - Return two tensors, FP8(inputs) and FP8(inputs.T), which are scaled by `scale` + Wrapper for TE layernorm fwd """ - return _cast_transpose_p.bind(inputs, amax, scale, scale_inv, out_dtype=out_dtype) + return LayerNormFwdPrimitive.outer_primitive.bind(x, + gamma, + beta, + zero_centered_gamma=zero_centered_gamma, + epsilon=epsilon) -class GatedGeluPrimitive(BasePrimitive): +class LayerNormBwdPrimitive(BasePrimitive): """ - Gated Gelu Primitive + Layer Normalization Backward Primitive """ - name = "te_gated_gelu" - multiple_results = False + name = "te_layernorm_backward" + multiple_results = True + impl_static_args = (5, 6) # zero_centered_gamma, epsilon + inner_primitive = None + outer_primitive = None @staticmethod - def abstract(inputs): + def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, **kwargs): # pylint: disable=unused-argument """ - te_gated_gelu_p abstract + Layernorm bwd abstract """ - dtype = dtypes.canonicalize_dtype(inputs.dtype) - assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - inputs_shape = inputs.shape - hidden_size = inputs_shape[-1] - # In Transformer, batch_shape = (batch, seqlen, ) - batch_shapes = inputs_shape[:-1] - assert hidden_size % 2 == 0 - inputs_shape = inputs.shape - out_shape = (batch_shapes) + (hidden_size // 2,) + w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype) + mu_dtype = dtypes.canonicalize_dtype(mu_aval.dtype) + rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype) + + assert dtypes.canonicalize_dtype(dz_aval.dtype) == w_dtype + assert dz_aval.shape == x_aval.shape + assert mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1] + assert mu_dtype == rsigma_dtype == jnp.float32 - return ShapedArray(out_shape, dtype, named_shape=inputs.named_shape) + dx_aval = core.raise_to_shaped(dz_aval) + dgamma_aval = dbeta_aval = core.raise_to_shaped(gamma_aval) + return dx_aval, dgamma_aval, dbeta_aval @staticmethod - def lowering(ctx, inputs): + def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon): """ - te_gated_gelu_p lowering rules + Layernorm bwd lowering rules """ - (in_aval,) = ctx.avals_in - assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - ir_in_type = ir.RankedTensorType(inputs.type) - ir_in_shape = ir_in_type.shape - out_shape = ir_in_shape[:-1] + [ir_in_shape[-1] // 2] + _, x_aval, _, _, gamma_aval = ctx.avals_in + x_type = ir.RankedTensorType(x.type) + x_shape = x_type.shape + g_type = ir.RankedTensorType(gamma.type) + g_shape = g_type.shape + b_type = ir.RankedTensorType(gamma.type) + b_shape = b_type.shape + assert g_type == b_type + assert g_shape == b_shape + + dz_shape = ir.RankedTensorType(dz.type).shape + mu_shape = ir.RankedTensorType(mu.type).shape + rsigma_shape = ir.RankedTensorType(rsigma.type).shape + + hidden_size = reduce(operator.mul, g_shape) + batch_size = reduce(operator.mul, x_shape) // hidden_size out_types = [ - ir.RankedTensorType.get(out_shape, ir_in_type.element_type), + ir.RankedTensorType.get(x_shape, x_type.element_type), + ir.RankedTensorType.get(g_shape, g_type.element_type), + ir.RankedTensorType.get(b_shape, b_type.element_type), ] - operands = [inputs] - operand_shapes = [ir_in_shape] + operands = [dz, mu, rsigma, x, gamma] + operand_shapes = [dz_shape, mu_shape, rsigma_shape, x_shape, g_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - hidden_size = ir_in_shape[-1] - # In Transformer, batch_size = batch x seqlen - batch_size = reduce(operator.mul, ir_in_shape[:-1]) - in_dtype = jax_dtype_to_te_dtype(in_aval.dtype) - opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size // 2), - in_dtype, in_dtype) + opaque = transformer_engine_jax.pack_norm_descriptor( + batch_size, + hidden_size, + jax_dtype_to_te_dtype(x_aval.dtype), + jax_dtype_to_te_dtype(gamma_aval.dtype), + zero_centered_gamma, + epsilon, + ) - out = custom_caller(GatedGeluPrimitive.name, args, opaque, False) + out = custom_caller(LayerNormBwdPrimitive.name, args, opaque, False) - return [out] + return out + @staticmethod + def impl(dz, x, mu, rsigma, gamma, zero_centered_gamma, epsilon): + assert LayerNormBwdPrimitive.inner_primitive is not None + dx, dgamma, dbeta = LayerNormBwdPrimitive.inner_primitive.bind( + dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon) + return dx, dgamma, dbeta -_gated_gelu_p = register_primitive(GatedGeluPrimitive) + @staticmethod + def batcher(batched_args, batch_dims, *, zero_centered_gamma, epsilon): + _check_valid_batch_dims(batch_dims) + assert LayerNormBwdPrimitive.outer_primitive is not None + dz, x, mu, rsigma, gamma = batched_args + _, x_bdim, _, _, gamma_bdim = batch_dims + + out_bdims = x_bdim, gamma_bdim, gamma_bdim + return LayerNormBwdPrimitive.outer_primitive.bind(dz, + x, + mu, + rsigma, + gamma, + zero_centered_gamma=zero_centered_gamma, + epsilon=epsilon), out_bdims + @staticmethod + def infer_sharding_from_operands(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos): + del zero_centered_gamma, epsilon, result_infos + x_spec = get_padded_spec(arg_infos[1]) + if x_spec[-1] is not None: + warnings.warn( + f"Does not support to shard hidden dim in {LayerNormBwdPrimitive.name}! " \ + f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ + f"and hurt performance." + ) + g_b_spec = get_padded_spec(arg_infos[4]) + dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) + dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(*g_b_spec)) + return dx_sharding, dgamma_sharding, dbeta_sharding -def gated_gelu(inputs: jnp.ndarray) -> jnp.ndarray: + @staticmethod + def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos): + del result_infos + x_spec = get_padded_spec(arg_infos[1]) + if x_spec[-1] is not None: + warnings.warn( + f"Does not support to shard hidden dim in {LayerNormBwdPrimitive.name}! " \ + f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ + f"and hurt performance." + ) + g_b_spec = get_padded_spec(arg_infos[4]) + dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) + dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(*g_b_spec)) + out_shardings = dx_sharding, dgamma_sharding, dbeta_sharding + x_shardings = (dx_sharding,) * 2 # dz and x should have the same sharding. + mu_shardings = (NamedSharding(mesh, PartitionSpec(*x_spec[:-1])),) * 2 + arg_shardings = (*x_shardings, *mu_shardings, NamedSharding(mesh, PartitionSpec(*g_b_spec))) + + def sharded_impl(dz, x, mu, rsigma, gamma): + local_dx, local_dgamma, local_dbeta = \ + LayerNormBwdPrimitive.impl(dz, x, mu, rsigma, gamma, + zero_centered_gamma=zero_centered_gamma, + epsilon=epsilon) + global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma) + global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta) + return local_dx, global_dgamma, global_dbeta + + return mesh, sharded_impl, out_shardings, arg_shardings + + +register_primitive(LayerNormBwdPrimitive) + + +def layernorm_bwd(dz: jnp.ndarray, x: jnp.ndarray, mu: jnp.ndarray, rsigma: jnp.ndarray, + gamma: jnp.ndarray, zero_centered_gamma: bool, epsilon: float): """ - gated gelu wrapper - Return FP8(geglu(inputs)) - Assume inputs has two dimensions shape and the memory layout is (N, 2, H) + Wrapper for TE layernorm bwd """ - return _gated_gelu_p.bind(inputs) + return LayerNormBwdPrimitive.outer_primitive.bind(dz, + x, + mu, + rsigma, + gamma, + zero_centered_gamma=zero_centered_gamma, + epsilon=epsilon) -class GatedGeluFp8Primitive(BasePrimitive): +class RmsNormFwdPrimitive(BasePrimitive): """ - Gated Gelu FP8 Primitive + RMS Normalization Forward Primitive """ - name = "te_gated_gelu_fp8" + name = "te_rmsnorm_forward" multiple_results = True + impl_static_args = (2,) # epsilon + inner_primitive = None + outer_primitive = None @staticmethod - def abstract(inputs, amax, scale, scale_inv, *, out_dtype): + def abstract(x_aval, gamma_aval, **kwargs): # pylint: disable=unused-argument """ - te_gated_gelu_p abstract + RMSNorm fwd abstract """ - dtype = dtypes.canonicalize_dtype(inputs.dtype) - assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert amax.dtype == jnp.float32 - assert scale.dtype == jnp.float32 - assert scale_inv.dtype == jnp.float32 - out_dtype = te_dtype_to_jax_dtype(out_dtype) + x_dtype = dtypes.canonicalize_dtype(x_aval.dtype) + assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + + rsigama_dtype = jnp.float32 + + out_aval = core.raise_to_shaped(x_aval) + rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigama_dtype) - assert len(inputs.shape) == 2 - hidden_size = inputs.shape[1] - batch_size = inputs.shape[0] # In Transformer, batch_size = batch x seqlen + hidden_size = gamma_aval.size + assert x_aval.size % hidden_size == 0 - # input_cast, input_cast_trans, amax - return (ShapedArray((batch_size, hidden_size // 2), - out_dtype, - named_shape=inputs.named_shape), - ShapedArray((1,), amax.dtype, named_shape=amax.named_shape)) + return out_aval, rsigma_aval @staticmethod - def lowering(ctx, inputs, amax, scale, scale_inv, *, out_dtype): + def lowering(ctx, x, gamma, *, epsilon): """ - te_gated_gelu_p lowering rules + RMSNorm fwd lowering rules """ - in_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in - assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert amax_aval.dtype == jnp.float32 - assert scale_aval.dtype == jnp.float32 - assert scale_inv_aval.dtype == jnp.float32 - ir_in_type = ir.RankedTensorType(inputs.type) - ir_in_shape = ir_in_type.shape - ir_out_dtype = te_dtype_to_ir_dtype(out_dtype) - ir_amax_type = ir.RankedTensorType(amax.type) - ir_amax_dtype = ir_amax_type.element_type - ir_amax_shape = ir_amax_type.shape - ir_scale_shape = ir_amax_shape - ir_scale_inv_shape = ir_amax_shape + x_aval, gamma_aval = ctx.avals_in + x_type = ir.RankedTensorType(x.type) + x_shape = x_type.shape + g_type = ir.RankedTensorType(gamma.type) + g_shape = g_type.shape + rsigma_element_type = ir.F32Type.get() + + out_shape = x_shape + hidden_size = reduce(operator.mul, g_shape) + batch_shape = out_shape[:-1] + batch_size = reduce(operator.mul, x_shape) // hidden_size - hidden_size = ir_in_shape[1] - batch_size = ir_in_shape[0] # In Transformer, batch_size = batch x seqlen out_types = [ - ir.RankedTensorType.get([batch_size, hidden_size // 2], ir_out_dtype), - ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), + ir.RankedTensorType.get(out_shape, x_type.element_type), + ir.RankedTensorType.get(batch_shape, rsigma_element_type), ] - operands = [inputs, amax, scale, scale_inv] - operand_shapes = [ir_in_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] + operands = [x, gamma] + operand_shapes = [x_shape, g_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - opaque = transformer_engine_jax.pack_common_descriptor( - (ir_in_shape[0], ir_in_shape[1] // 2), jax_dtype_to_te_dtype(in_aval.dtype), out_dtype) + opaque = transformer_engine_jax.pack_norm_descriptor( + batch_size, + hidden_size, + jax_dtype_to_te_dtype(x_aval.dtype), + jax_dtype_to_te_dtype(gamma_aval.dtype), + False, # RMSNorm doesn't support zero_centered_gamma + epsilon, + ) - out = custom_caller(GatedGeluFp8Primitive.name, - args, - opaque, - False, - operand_output_aliases={1: 1}) + out = custom_caller(RmsNormFwdPrimitive.name, args, opaque, False) return out - -_gated_gelu_fp8_p = register_primitive(GatedGeluFp8Primitive) - - -def gated_gelu_fp8(inputs: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, - scale_inv: jnp.ndarray, - out_dtype: TEDType) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: - """ - cast gated gelu wrapper - Return FP8(geglu(inputs)) - Assume inputs has two dimensions shape and the memory layout is (N, 2, H) - """ - return _gated_gelu_fp8_p.bind(inputs, amax, scale, scale_inv, out_dtype=out_dtype) - - -class DgatedGeluPrimitive(BasePrimitive): - """ - Dgated Gelu Primitive - """ - name = "te_dgated_gelu" - multiple_results = False - @staticmethod - def abstract(inputs, gelu_inputs): + def impl(x, gamma, epsilon): """ - te_dgated_gelu_p abstract + to describe implementation """ - dtype = dtypes.canonicalize_dtype(inputs.dtype) - assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert gelu_inputs.dtype == dtype - for axis in range(len(inputs.shape) - 1): - assert inputs.shape[axis] == gelu_inputs.shape[axis] - - i_hidden_size = inputs.shape[-1] - g_hidden_szie = gelu_inputs.shape[-1] - assert i_hidden_size * 2 == g_hidden_szie - return ShapedArray(gelu_inputs.shape, dtype, named_shape=inputs.named_shape) + assert RmsNormFwdPrimitive.inner_primitive is not None + out, rsigma = RmsNormFwdPrimitive.inner_primitive.bind(x, gamma, epsilon=epsilon) + return out, rsigma @staticmethod - def lowering(ctx, inputs, gelu_inputs): + def batcher(batched_args, batch_dims, *, epsilon): """ - te_dgated_gelu_p lowering rules + to describe batch rules for vmap """ - in_aval, gi_aval = ctx.avals_in - assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert gi_aval.dtype == in_aval.dtype - ir_in_type = ir.RankedTensorType(inputs.type) - ir_in_shape = ir_in_type.shape - gi_type = ir.RankedTensorType(gelu_inputs.type) - gi_shape = gi_type.shape - for axis in range(len(ir_in_shape) - 1): - assert ir_in_shape[axis] == gi_shape[axis] - - # In Transformer, batch_size = batch x seqlen - ir_batch_szie = reduce(operator.mul, ir_in_shape[:-1]) - i_hidden_size = ir_in_shape[-1] - g_hidden_szie = gi_shape[-1] - assert i_hidden_size * 2 == g_hidden_szie - out_dtype = ir_in_type.element_type - out_shape = gi_shape + _check_valid_batch_dims(batch_dims) + assert RmsNormFwdPrimitive.outer_primitive is not None + x, gamma = batched_args + x_bdim, _ = batch_dims - out_types = [ - ir.RankedTensorType.get(out_shape, out_dtype), - ] - operands = [inputs, gelu_inputs] - operand_shapes = [ir_in_shape, gi_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - in_dtype = jax_dtype_to_te_dtype(in_aval.dtype) - opaque = transformer_engine_jax.pack_common_descriptor((ir_batch_szie, i_hidden_size), - in_dtype, in_dtype) + out_bdims = x_bdim, x_bdim + return RmsNormFwdPrimitive.outer_primitive.bind(x, gamma, epsilon=epsilon), out_bdims - out = custom_caller(DgatedGeluPrimitive.name, args, opaque, False) + @staticmethod + def infer_sharding_from_operands(epsilon, mesh, arg_infos, result_infos): + del epsilon, result_infos + x_spec = get_padded_spec(arg_infos[0]) + if x_spec[-1] is not None: + warnings.warn( + f"Does not support to shard hidden dim in {RmsNormFwdPrimitive.name}! " \ + f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ + f"and hurt performance." + ) + out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) + rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1])) + return (out_sharding, rsigma_sharding) - return [out] + @staticmethod + def partition(epsilon, mesh, arg_infos, result_infos): + del result_infos + x_spec, g_spec = map(get_padded_spec, arg_infos) + if x_spec[-1] is not None: + warnings.warn( + f"Does not support to shard hidden dim in {RmsNormFwdPrimitive.name}! " \ + f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ + f"and hurt performance." + ) + x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) + g_sharding = NamedSharding(mesh, PartitionSpec(*g_spec)) + out_sharding = x_sharding + rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1])) + arg_shardings = (x_sharding, g_sharding) + out_shardings = (out_sharding, rsigma_sharding) + impl = partial(RmsNormFwdPrimitive.impl, epsilon=epsilon) + return mesh, impl, out_shardings, arg_shardings -_dgated_gelu_p = register_primitive(DgatedGeluPrimitive) +register_primitive(RmsNormFwdPrimitive) -def dgated_gelu(inputs: jnp.ndarray, gelu_inputs: jnp.ndarray) -> jnp.ndarray: +def rmsnorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, epsilon: float): """ - dgated_gelu fusion wrapper - Return dgeglu(inputs) + Wrapper for TE rmsnorm fwd """ - return _dgated_gelu_p.bind(inputs, gelu_inputs) + return RmsNormFwdPrimitive.outer_primitive.bind(x, gamma, epsilon=epsilon) -class DgatedGeluCastTransposePrimitive(BasePrimitive): +class RmsNormBwdPrimitive(BasePrimitive): """ - Dgated Gelu Cast Transpose Primitive + RMS Normalization Backward Primitive """ - name = "te_dgated_gelu_cast_transpose" + name = "te_rmsnorm_backward" multiple_results = True + impl_static_args = (4,) # epsilon + inner_primitive = None + outer_primitive = None @staticmethod - def abstract(inputs, gelu_inputs, amax, scale, scale_inv, *, out_dtype): + def abstract( + dz_aval, + x_aval, + rsigma_aval, + gamma_aval, + **kwargs # pylint: disable=unused-argument + ): """ - te_dgated_gelu_cast_transpose_p abstract + RMSNorm bwd abstract """ - dtype = dtypes.canonicalize_dtype(inputs.dtype) - assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert gelu_inputs.dtype == dtype - assert len(inputs.shape) == 2 - assert len(gelu_inputs.shape) == 2 - ir_batch_szie = inputs.shape[0] - gi_batch_size = gelu_inputs.shape[0] - assert ir_batch_szie == gi_batch_size - ir_hidden_szie = inputs.shape[1] - gi_hidden_size = gelu_inputs.shape[1] - assert ir_hidden_szie * 2 == gi_hidden_size - assert amax.dtype == jnp.float32 - assert scale.dtype == jnp.float32 - assert scale_inv.dtype == jnp.float32 - out_dtype = te_dtype_to_jax_dtype(out_dtype) - # input_cast, input_cast_trans, amax - return (ShapedArray((gi_batch_size, gi_hidden_size), - out_dtype, - named_shape=inputs.named_shape), - ShapedArray((gi_hidden_size, gi_batch_size), - out_dtype, - named_shape=inputs.named_shape), - ShapedArray((1,), amax.dtype, named_shape=amax.named_shape)) - - @staticmethod - def lowering(ctx, inputs, gelu_inputs, amax, scale, scale_inv, *, out_dtype): + w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype) + rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype) + + assert dtypes.canonicalize_dtype(dz_aval.dtype) == w_dtype + assert dz_aval.shape == x_aval.shape + assert rsigma_aval.shape == x_aval.shape[:-1] + assert rsigma_dtype == jnp.float32 + + dx_aval = core.raise_to_shaped(dz_aval) + dgamma_aval = core.raise_to_shaped(gamma_aval) + return dx_aval, dgamma_aval + + @staticmethod + def lowering(ctx, dz, x, rsigma, gamma, *, epsilon): """ - te_dgated_gelu_cast_transpose_p lowering rules + RMSNorm bwd lowering rules """ - in_aval, gi_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in - assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert gi_aval.dtype == in_aval.dtype - assert amax_aval.dtype == jnp.float32 - assert scale_aval.dtype == jnp.float32 - assert scale_inv_aval.dtype == jnp.float32 - ir_in_type = ir.RankedTensorType(inputs.type) - ir_in_shape = ir_in_type.shape - gi_type = ir.RankedTensorType(gelu_inputs.type) - gi_shape = gi_type.shape - ir_batch_szie = ir_in_shape[0] - gi_batch_size = gi_shape[0] - assert ir_batch_szie == gi_batch_size - ir_hidden_szie = ir_in_shape[1] - gi_hidden_size = gi_shape[1] - assert ir_hidden_szie * 2 == gi_hidden_size - ir_out_dtype = te_dtype_to_ir_dtype(out_dtype) - ir_amax_type = ir.RankedTensorType(amax.type) - ir_amax_dtype = ir_amax_type.element_type - ir_amax_shape = ir_amax_type.shape - ir_scale_shape = ir_amax_shape - ir_scale_inv_shape = ir_amax_shape + _, x_aval, _, gamma_aval = ctx.avals_in + x_type = ir.RankedTensorType(x.type) + x_shape = x_type.shape + g_type = ir.RankedTensorType(gamma.type) + g_shape = g_type.shape + dz_shape = ir.RankedTensorType(dz.type).shape + rsigma_shape = ir.RankedTensorType(rsigma.type).shape + + hidden_size = reduce(operator.mul, g_shape) + batch_size = reduce(operator.mul, x_shape) // hidden_size out_types = [ - ir.RankedTensorType.get([gi_batch_size, gi_hidden_size], ir_out_dtype), - ir.RankedTensorType.get([gi_hidden_size, gi_batch_size], ir_out_dtype), - ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), + ir.RankedTensorType.get(x_shape, x_type.element_type), + ir.RankedTensorType.get(g_shape, g_type.element_type), ] - operands = [inputs, gelu_inputs, amax, scale, scale_inv] - operand_shapes = [ir_in_shape, gi_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] + operands = [dz, rsigma, x, gamma] + operand_shapes = [dz_shape, rsigma_shape, x_shape, g_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - opaque = transformer_engine_jax.pack_common_descriptor((ir_batch_szie, ir_hidden_szie), - jax_dtype_to_te_dtype(in_aval.dtype), - out_dtype) + opaque = transformer_engine_jax.pack_norm_descriptor( + batch_size, + hidden_size, + jax_dtype_to_te_dtype(x_aval.dtype), + jax_dtype_to_te_dtype(gamma_aval.dtype), + False, # RMSNorm doesn't support zero_centered_gamma + epsilon, + ) - out = custom_caller(DgatedGeluCastTransposePrimitive.name, - args, - opaque, - False, - operand_output_aliases={2: 2}) + out = custom_caller(RmsNormBwdPrimitive.name, args, opaque, False) return out + @staticmethod + def impl(dz, x, rsigma, gamma, epsilon): + assert RmsNormBwdPrimitive.inner_primitive is not None + dx, dgamma = RmsNormBwdPrimitive.inner_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon) + return dx, dgamma + + @staticmethod + def batcher(batched_args, batch_dims, *, epsilon): + _check_valid_batch_dims(batch_dims) + assert RmsNormBwdPrimitive.outer_primitive is not None + dz, x, rsigma, gamma = batched_args + _, x_bdim, _, gamma_bdim = batch_dims + + out_bdims = x_bdim, gamma_bdim + return RmsNormBwdPrimitive.outer_primitive.bind(dz, x, rsigma, gamma, + epsilon=epsilon), out_bdims + + @staticmethod + def infer_sharding_from_operands(epsilon, mesh, arg_infos, result_infos): + del epsilon, result_infos + x_spec = get_padded_spec(arg_infos[1]) + if x_spec[-1] is not None: + warnings.warn( + f"Does not support to shard hidden dim in {RmsNormBwdPrimitive.name}! " \ + f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ + f"and hurt performance." + ) + g_spec = get_padded_spec(arg_infos[3]) + dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) + dgamma_sharding = NamedSharding(mesh, PartitionSpec(*g_spec)) + return dx_sharding, dgamma_sharding + + @staticmethod + def partition(epsilon, mesh, arg_infos, result_infos): + del result_infos + x_spec = get_padded_spec(arg_infos[1]) + if x_spec[-1] is not None: + warnings.warn( + f"Does not support to shard hidden dim in {RmsNormBwdPrimitive.name}! " \ + f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ + f"and hurt performance." + ) + g_spec = get_padded_spec(arg_infos[3]) + dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) + dgamma_sharding = NamedSharding(mesh, PartitionSpec(*g_spec)) + out_shardings = dx_sharding, dgamma_sharding + x_shardings = (dx_sharding,) * 2 # dz and x should have the same sharding. + rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1])) + arg_shardings = (*x_shardings, rsigma_sharding, NamedSharding(mesh, PartitionSpec(*g_spec))) + + def sharded_impl(dz, x, rsigma, gamma): + local_dx, local_dgamma = \ + RmsNormBwdPrimitive.impl(dz, x, rsigma, gamma, epsilon=epsilon) + global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma) + return local_dx, global_dgamma + + return mesh, sharded_impl, out_shardings, arg_shardings + -_dgated_gelu_cast_transpose_p = register_primitive(DgatedGeluCastTransposePrimitive) +register_primitive(RmsNormBwdPrimitive) -def dgated_gelu_cast_transpose(inputs: jnp.ndarray, gelu_inputs: jnp.ndarray, amax: jnp.ndarray, - scale: jnp.ndarray, scale_inv: jnp.ndarray, - out_dtype: TEDType) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: +def rmsnorm_bwd(dz: jnp.ndarray, x: jnp.ndarray, rsigma: jnp.ndarray, gamma: jnp.ndarray, + epsilon: float): """ - cast transpose d_gated_gelu fusion wrapper - Return FP8(dgeglu(inputs)) + Wrapper for TE layernorm bwd """ - return _dgated_gelu_cast_transpose_p.bind(inputs, - gelu_inputs, - amax, - scale, - scale_inv, - out_dtype=out_dtype) + return RmsNormBwdPrimitive.outer_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon) -class GemmPrimitive(BasePrimitive): +class SoftmaxPrimitive(BasePrimitive): """ - Gemm Primitive + Softmax Primitive """ - name = "te_gemm" - multiple_results = False + max_k_seqlen_supported = 4096 @staticmethod - def abstract(A, B, A_scale_inv, B_scale_inv, *, A_dtype, B_dtype, D_dtype, transa, transb, - use_split_accumulator): # pylint: disable=unused-argument - """ - te_gemm_p abstract - """ - atype = dtypes.canonicalize_dtype(A.dtype) - btype = dtypes.canonicalize_dtype(B.dtype) - assert atype == te_dtype_to_jax_dtype(A_dtype) - assert btype == te_dtype_to_jax_dtype(B_dtype) - assert A_scale_inv.dtype == jnp.float32 - assert B_scale_inv.dtype == jnp.float32 + @abstractmethod + def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int, + dtype: jnp.dtype) -> bool: + """Check Softmax kernel availability based on size""" + raise NotImplementedError - m = A.shape[0] if transa else A.shape[1] - k = A.shape[1] if transa else A.shape[0] - n = B.shape[1] if transb else B.shape[0] - assert (transb and k == B.shape[0]) or k == B.shape[1] + @staticmethod + def get_batch_per_block(k_seqlen: int) -> int: + """Get batch per CTA in Softmax kernels""" + threads_per_warp = 32 + threads_per_block = 128 # Depends on the kernel implmentation - out_dtype = te_dtype_to_jax_dtype(D_dtype) - return ShapedArray((n, m), - out_dtype, - named_shape=merge_named_shape(A.named_shape, B.named_shape)) + pow2 = 1 << (k_seqlen - 1).bit_length() + warp_size = pow2 if pow2 < threads_per_warp else threads_per_warp + batches_per_warp = 2 if pow2 <= 128 else 1 + warps_per_block = threads_per_block // warp_size + batches_per_block = warps_per_block * batches_per_warp + return batches_per_block @staticmethod - def lowering(ctx, A, B, A_scale_inv, B_scale_inv, *, A_dtype, B_dtype, D_dtype, transa, transb, - use_split_accumulator): + def forward_abstract(logits_aval, scale_factor): """ - te_gemm_p lowering rules + softmax_forward abstract """ - A_aval, B_aval, A_scale_inv_aval, B_scale_inv_aval = ctx.avals_in - assert A_aval.dtype == te_dtype_to_jax_dtype(A_dtype) - assert B_aval.dtype == te_dtype_to_jax_dtype(B_dtype) - assert A_scale_inv_aval.dtype == jnp.float32 - assert B_scale_inv_aval.dtype == jnp.float32 - A_type = ir.RankedTensorType(A.type) - B_type = ir.RankedTensorType(B.type) - A_shape = A_type.shape - B_shape = B_type.shape - A_scale_inv_shape = ir.RankedTensorType(A_scale_inv.type).shape - B_scale_inv_shape = ir.RankedTensorType(B_scale_inv.type).shape + del scale_factor + i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype) + assert i_dtype in [jnp.float16, jnp.bfloat16] + i_shape = logits_aval.shape + # Assume [...Batch, Head, Q_Seqlen, K_Seqlen] + q_seqlen = i_shape[-2] + k_seqlen = i_shape[-1] + assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported + assert q_seqlen > 1 - m = A_shape[0] if transa else A_shape[1] - k = A_shape[1] if transa else A_shape[0] - n = B_shape[1] if transb else B_shape[0] - assert (transb and k == B_shape[0]) or k == B_shape[1] + out_aval = core.raise_to_shaped(logits_aval) + return out_aval - ir_out_dtype = dtype_to_ir_type(np.dtype(te_dtype_to_jax_dtype(D_dtype))) - out_types = [ - ir.RankedTensorType.get([n, m], ir_out_dtype), - ] - operands = [A, B, A_scale_inv, B_scale_inv] - operand_shapes = [A_shape, B_shape, A_scale_inv_shape, B_scale_inv_shape] + @staticmethod + def forward_lowering(name, ctx, logits, *, scale_factor): + """ + softmax_forward lowering rules + """ + i_aval, = ctx.avals_in + i_type = ir.RankedTensorType(logits.type) + i_shape = i_type.shape + # Assume [...Batch, Head, Q_Seqlen, K_Seqlen] + batch = reduce(operator.mul, i_shape[:-3]) + pad_batch = batch + heads = i_shape[-3] + q_seqlen = i_shape[-2] + k_seqlen = i_shape[-1] + + out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)] + operands = [logits] + operand_shapes = [i_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - # m, n, k here should be equal to transa=False and transb=False, - # due to te_gemm's implementation. - # Therefore, m=A_shape[1], n=B_shape[0], k=A_shape[0] - opaque = transformer_engine_jax.pack_gemm_descriptor(A_shape[1], B_shape[0], A_shape[0], - A_dtype, B_dtype, D_dtype, transa, - transb, use_split_accumulator) + opaque = transformer_engine_jax.pack_softmax_descriptor(batch, pad_batch, heads, q_seqlen, + k_seqlen, + jax_dtype_to_te_dtype(i_aval.dtype), + scale_factor) - out = custom_caller(GemmPrimitive.name, args, opaque, False) + out = custom_caller(name, args, opaque, False) return [out] + @staticmethod + def forward_impl(primitive, logits, scale_factor): + """ + softmax_forward implementation + """ + assert primitive is not None + output = primitive.bind(logits, scale_factor=scale_factor) + return output -_gemm_p = register_primitive(GemmPrimitive) - - -def gemm(A: jnp.ndarray, - A_scale_inv: jnp.ndarray, - A_type: TEDType, - transa: bool, - B: jnp.ndarray, - B_scale_inv: jnp.ndarray, - B_type: TEDType, - transb: bool, - D_type: TEDType, - use_split_accumulator: bool = False) -> jnp.ndarray: - """ - gemm wrapper - """ - return _gemm_p.bind(A, - B, - A_scale_inv, - B_scale_inv, - A_dtype=A_type, - B_dtype=B_type, - D_dtype=D_type, - transa=transa, - transb=transb, - use_split_accumulator=use_split_accumulator) + @staticmethod + def forward_batcher(primitive, batched_args, batch_dims, *, scale_factor): + """ + softmax_forward batcher + """ + assert primitive is not None + logits, = batched_args + logits_bdim, = batch_dims + out_bdims = logits_bdim + return primitive.bind(logits, scale_factor=scale_factor), out_bdims -class LayerNormFwdPrimitive(BasePrimitive): - """ - Layer Normalization Forward Primitive - """ - name = "te_layernorm_forward" - multiple_results = True + @staticmethod + def forward_infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): + """ + softmax_forward infer_sharding_from_operands + """ + del scale_factor, result_infos # Unused. + logits_spec = get_padded_spec(arg_infos[0]) + out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec)) + return out_sharding @staticmethod - def abstract(x, gamma, beta, **kwargs): # pylint: disable=unused-argument + def forward_partition(impl, scale_factor, mesh, arg_infos, result_infos): """ - LayerNorm fwd abstract + softmax_forward partitioning """ - x_dtype = dtypes.canonicalize_dtype(x.dtype) - assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + del result_infos + logits_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0]))) + out_spec = logits_spec + arg_shardings = (logits_spec,) + out_shardings = out_spec + impl = partial(impl, scale_factor=scale_factor) + return mesh, impl, out_shardings, arg_shardings - mu_dtype = jnp.float32 - rsigma_dtype = jnp.float32 + @staticmethod + def backward_abstract(dz_aval, softmax_out_aval, scale_factor=None): # pylint: disable=unused-argument + """ + softmax_backward abstract + """ + dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype) + softmax_out_dtype = dtypes.canonicalize_dtype(softmax_out_aval.dtype) + assert dz_dtype == softmax_out_dtype + assert dz_dtype in [jnp.float16, jnp.bfloat16] + assert softmax_out_dtype in [jnp.float16, jnp.bfloat16] - assert gamma.size == beta.size - hidden_size = gamma.size - assert x.size % hidden_size == 0 - # In Transformer, batch_size = batch x seqlen - batch_size = x.size // hidden_size + assert dz_aval.shape == softmax_out_aval.shape - return ( - ShapedArray(x.shape, x_dtype, named_shape=x.named_shape), # output - ShapedArray((batch_size,), mu_dtype, named_shape=x.named_shape), # mu - ShapedArray((batch_size,), rsigma_dtype, named_shape=x.named_shape), # rsigma - ) + dx_aval = core.raise_to_shaped(softmax_out_aval) + return dx_aval @staticmethod - def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon): + def backward_lowering(name, ctx, dz, softmax_out, *, scale_factor): """ - LayerNorm fwd lowering rules + softmax_backward lowering rules """ - x_aval, gamma_aval, beta_aval = ctx.avals_in - assert gamma_aval.dtype == beta_aval.dtype - x_type = ir.RankedTensorType(x.type) - x_shape = x_type.shape - w_type = ir.RankedTensorType(gamma.type) - w_shape = w_type.shape - b_type = ir.RankedTensorType(beta.type) - b_shape = b_type.shape + dz_aval, _ = ctx.avals_in - assert w_type == b_type - assert w_shape == b_shape + dz_type = ir.RankedTensorType(dz.type) + dz_shape = dz_type.shape - # Output shape is same as the input shape, but the output type is same as the weight type. - # See ln_api.cpp - out_shape = x_shape - output_type = w_type.element_type - ir_mu_dtype = ir.F32Type.get() - ir_rsigma_dtype = ir.F32Type.get() + # Assume [...Batch, Head, Q_Seqlen, K_Seqlen] + batch = reduce(operator.mul, dz_shape[:-3]) + pad_batch = batch # unused + heads = dz_shape[-3] + q_seqlen = dz_shape[-2] + k_seqlen = dz_shape[-1] - hidden_size = reduce(operator.mul, w_shape) - # In Transformer, batch_size = batch x seqlen - batch_size = reduce(operator.mul, x_shape) // hidden_size + softmax_out_type = ir.RankedTensorType(softmax_out.type) + softmax_out_shape = softmax_out_type.shape - out_types = [ - ir.RankedTensorType.get(out_shape, output_type), - ir.RankedTensorType.get((batch_size,), ir_mu_dtype), - ir.RankedTensorType.get((batch_size,), ir_rsigma_dtype), - ] - operands = [x, gamma, beta] - operand_shapes = [x_shape, w_shape, b_shape] + out_types = [ir.RankedTensorType.get(softmax_out_shape, softmax_out_type.element_type)] + operands = [dz, softmax_out] + operand_shapes = [dz_shape, softmax_out_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - opaque = transformer_engine_jax.pack_norm_descriptor( - batch_size, - hidden_size, - jax_dtype_to_te_dtype(x_aval.dtype), - jax_dtype_to_te_dtype(gamma_aval.dtype), - zero_centered_gamma, - epsilon, - ) - - out = custom_caller(LayerNormFwdPrimitive.name, args, opaque, False) - - return out - - -_layernorm_fwd_p = register_primitive(LayerNormFwdPrimitive) - - -def layernorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, zero_centered_gamma: bool, - epsilon: float): - """ - Wrapper for TE layernorm fwd - """ - return _layernorm_fwd_p.bind(x, - gamma, - beta, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon) + opaque = transformer_engine_jax.pack_softmax_descriptor( + batch, pad_batch, heads, q_seqlen, k_seqlen, jax_dtype_to_te_dtype(dz_aval.dtype), + scale_factor) + out = custom_caller(name, args, opaque, False) -class LayerNormFwdFp8Primitive(BasePrimitive): - """ - Layer Normalization Forward FP8 Primitive - """ - name = "te_layernorm_forward_fp8" - multiple_results = True + return [out] @staticmethod - def abstract( - x, - gamma, - beta, - amax, - scale, - scale_inv, - **kwargs # pylint: disable=unused-argument - ): + def backward_impl(primitive, dz, softmax_out, scale_factor): """ - LayerNorm fwd (fp8 out) abstract + softmax_backward implementation """ - x_dtype = dtypes.canonicalize_dtype(x.dtype) - - assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert amax.dtype == jnp.float32 - assert scale.dtype == jnp.float32 - assert scale_inv.dtype == jnp.float32 + assert primitive is not None + dx = primitive.bind(dz, softmax_out, scale_factor=scale_factor) + return dx - out_dtype = jnp.int8 - mu_dtype = jnp.float32 - rsigma_dtype = jnp.float32 - - assert gamma.size == beta.size - - hidden_szie = gamma.size - # In Transformer, batch_size = batch x seqlen - batch_size = x.size // hidden_szie + @staticmethod + def backward_batcher(primitive, batched_args, batch_dims, *, scale_factor): + """ + softmax_backward batcher + """ + assert primitive is not None + dz, softmax_out = batched_args + _, softmax_out_bdim = batch_dims - return ( - ShapedArray(x.shape, out_dtype, named_shape=x.named_shape), # output - ShapedArray((batch_size,), mu_dtype, named_shape=x.named_shape), # mu - ShapedArray((batch_size,), rsigma_dtype, named_shape=x.named_shape), # rsigma - ShapedArray((1,), amax.dtype, named_shape=amax.named_shape), # amax - ) + out_bdims = softmax_out_bdim + return primitive.bind(dz, softmax_out, scale_factor=scale_factor), out_bdims @staticmethod - def lowering(ctx, x, gamma, beta, amax, scale, scale_inv, *, zero_centered_gamma, epsilon): + def backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): """ - LayerNorm fwd (fp8 out) lowering rules + softmax_backward infer_sharding_from_operands """ - x_aval, gamma_aval, beta_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in + del scale_factor, result_infos # Unused. + softmax_out_spec = get_padded_spec(arg_infos[1]) + dx_sharding = NamedSharding(mesh, PartitionSpec(*softmax_out_spec)) + return dx_sharding - assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert gamma_aval.dtype == beta_aval.dtype - assert amax_aval.dtype == jnp.float32 - assert scale_aval.dtype == jnp.float32 - assert scale_inv_aval.dtype == jnp.float32 + @staticmethod + def backward_partition(impl, scale_factor, mesh, arg_infos, result_infos): + """ + softmax_backward partition + """ + del result_infos + dz_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0]))) + softmax_out_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) + dx_spec = softmax_out_spec + arg_shardings = (dz_spec, softmax_out_spec) + out_shardings = dx_spec + impl = partial(impl, scale_factor=scale_factor) + return mesh, impl, out_shardings, arg_shardings - x_type = ir.RankedTensorType(x.type) - x_shape = x_type.shape - w_type = ir.RankedTensorType(gamma.type) - w_shape = w_type.shape - b_type = ir.RankedTensorType(beta.type) - b_shape = b_type.shape - ir_out_dtype = dtype_to_ir_type(np.dtype(np.int8)) - ir_mu_dtype = ir.F32Type.get() - ir_rsigma_dtype = ir.F32Type.get() - ir_amax_type = ir.RankedTensorType(amax.type) - ir_amax_dtype = ir_amax_type.element_type - ir_amax_shape = ir_amax_type.shape - ir_scale_shape = ir_amax_shape - ir_scale_inv_shape = ir_amax_shape +class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive): + """ + Scaled Softmax Fwd Primitive + """ + name = "te_scaled_softmax_forward" + multiple_results = False + impl_static_args = (1,) # scale_factor + inner_primitive = None + outer_primitive = None - hidden_size = reduce(operator.mul, w_shape) - # In Transformer, batch_size = batch x seqlen - batch_size = reduce(operator.mul, x_shape) // hidden_size + @staticmethod + def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int, + dtype: jnp.dtype) -> bool: + """Check Softmax kernel availability based on size""" + attn_batches = batch * heads - out_types = [ - ir.RankedTensorType.get(x_shape, ir_out_dtype), - ir.RankedTensorType.get((batch_size,), ir_mu_dtype), - ir.RankedTensorType.get((batch_size,), ir_rsigma_dtype), - ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), - ] - operands = [x, gamma, beta, amax, scale, scale_inv] - operand_shapes = [ - x_shape, w_shape, b_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape - ] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + dtype = dtypes.canonicalize_dtype(dtype) + if (dtype in [jnp.float16, jnp.bfloat16] + and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported + # k_seqlen must be 16 ~ 4096 + and q_seqlen % 4 == 0 # q_seqlen must be divisor of 4 + and attn_batches % 4 == 0 # batch * heads must be divisor of 4 + ): + if 0 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported: + batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen) + return q_seqlen % batch_per_block == 0 + return False - opaque = transformer_engine_jax.pack_norm_descriptor( - batch_size, - hidden_size, - jax_dtype_to_te_dtype(x_aval.dtype), - jax_dtype_to_te_dtype(gamma_aval.dtype), - zero_centered_gamma, - epsilon, - ) + @staticmethod + def abstract(logits_aval, scale_factor): # pylint: disable=unused-argument + """ + te_scaled_softmax_forward abstract + """ + return SoftmaxPrimitive.forward_abstract(logits_aval, scale_factor) - out = custom_caller(LayerNormFwdFp8Primitive.name, - args, - opaque, - False, - operand_output_aliases={3: 3}) + @staticmethod + def lowering(ctx, logits, *, scale_factor): + """ + te_scaled_softmax_forward lowering rules + """ + return SoftmaxPrimitive.forward_lowering(ScaledSoftmaxFwdPrimitive.name, + ctx, + logits, + scale_factor=scale_factor) - return out + @staticmethod + def impl(logits, scale_factor): + return SoftmaxPrimitive.forward_impl(ScaledSoftmaxFwdPrimitive.inner_primitive, logits, + scale_factor) + + @staticmethod + def batcher(batched_args, batch_dims, *, scale_factor): + _check_valid_batch_dims(batch_dims) + return SoftmaxPrimitive.forward_batcher(ScaledSoftmaxFwdPrimitive.outer_primitive, + batched_args, + batch_dims, + scale_factor=scale_factor) + @staticmethod + def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): + return SoftmaxPrimitive.forward_infer_sharding_from_operands(scale_factor, mesh, arg_infos, + result_infos) + + @staticmethod + def partition(scale_factor, mesh, arg_infos, result_infos): + return SoftmaxPrimitive.forward_partition(ScaledSoftmaxFwdPrimitive.impl, scale_factor, + mesh, arg_infos, result_infos) -_layernorm_fwd_fp8_p = register_primitive(LayerNormFwdFp8Primitive) +register_primitive(ScaledSoftmaxFwdPrimitive) -def layernorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, amax: jnp.ndarray, - scale: jnp.ndarray, scale_inv: jnp.ndarray, zero_centered_gamma: bool, - epsilon: float): + +def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray: """ - Wrapper for TE layernorm fwd (fp8 out) + scaled_softmax_forward wrapper + Return FP16/BF16 tensor """ - return _layernorm_fwd_fp8_p.bind(x, - gamma, - beta, - amax, - scale, - scale_inv, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon) + return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor) -class LayerNormBwdPrimitive(BasePrimitive): +class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive): """ - Layer Normalization Backward Primitive + Scaled Softmax Bwd Primitive """ - name = "te_layernorm_backward" - multiple_results = True + name = "te_scaled_softmax_backward" + multiple_results = False + impl_static_args = (2,) # scale_factor + inner_primitive = None + outer_primitive = None + + @staticmethod + def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int, + dtype: jnp.dtype) -> bool: + """Check Softmax kernel availability based on size""" + return ScaledSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen, + dtype) @staticmethod - def abstract(grad_output, mu, rsigma, x, gamma, **kwargs): # pylint: disable=unused-argument + def abstract(dz_aval, softmax_out_aval, scale_factor): """ - Layernorm bwd abstract + te_scaled_softmax_backward abstract """ - x_dtype = dtypes.canonicalize_dtype(x.dtype) - w_dtype = dtypes.canonicalize_dtype(gamma.dtype) - mu_dtype = dtypes.canonicalize_dtype(mu.dtype) - rsigma_dtype = dtypes.canonicalize_dtype(rsigma.dtype) + return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor) - hidden_size = gamma.size - # In Transformer, batch_size = batch x seqlen - batch_size = x.size // hidden_size + @staticmethod + def lowering(ctx, dz, softmax_out, *, scale_factor): + """ + te_scaled_softmax_backward lowering rules + """ + out = SoftmaxPrimitive.backward_lowering(ScaledSoftmaxBwdPrimitive.name, + ctx, + dz, + softmax_out, + scale_factor=scale_factor) - assert dtypes.canonicalize_dtype(grad_output.dtype) == w_dtype - assert grad_output.shape == x.shape - assert mu.shape == rsigma.shape == (batch_size,) - assert mu_dtype == rsigma_dtype == jnp.float32 - assert grad_output.named_shape == x.named_shape + return out - return ( - ShapedArray(x.shape, x_dtype, named_shape=grad_output.named_shape), # grad input - ShapedArray(gamma.shape, w_dtype, named_shape=gamma.named_shape), # grad gamma - ShapedArray(gamma.shape, w_dtype, named_shape=gamma.named_shape), # grad beta - ) + @staticmethod + def impl(dz, softmax_out, scale_factor): + return SoftmaxPrimitive.backward_impl(ScaledSoftmaxBwdPrimitive.inner_primitive, + dz, + softmax_out, + scale_factor=scale_factor) + + @staticmethod + def batcher(batched_args, batch_dims, *, scale_factor): + _check_valid_batch_dims(batch_dims) + return SoftmaxPrimitive.backward_batcher(ScaledSoftmaxBwdPrimitive.outer_primitive, + batched_args, + batch_dims, + scale_factor=scale_factor) + + @staticmethod + def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): + return SoftmaxPrimitive.backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos, + result_infos) @staticmethod - def lowering(ctx, grad_output, mu, rsigma, x, gamma, *, zero_centered_gamma, epsilon): + def partition(scale_factor, mesh, arg_infos, result_infos): + return SoftmaxPrimitive.backward_partition(ScaledSoftmaxBwdPrimitive.impl, scale_factor, + mesh, arg_infos, result_infos) + + +register_primitive(ScaledSoftmaxBwdPrimitive) + + +def scaled_softmax_bwd(dz: jnp.ndarray, softmax_out: jnp.ndarray, + scale_factor: float) -> jnp.ndarray: + """ + scaled_backward wrapper + Return FP16/BF16 tensor + """ + return ScaledSoftmaxBwdPrimitive.outer_primitive.bind(dz, + softmax_out, + scale_factor=scale_factor) + + +class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): + """ + Scaled Masked Softmax Fwd Primitive + """ + name = "te_scaled_masked_softmax_forward" + multiple_results = False + impl_static_args = (2,) # scale_factor + inner_primitive = None + outer_primitive = None + + @staticmethod + def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int, + dtype: jnp.dtype) -> bool: + """Check Softmax kernel availability based on size""" + attn_batches = batch * heads + + dtype = dtypes.canonicalize_dtype(dtype) + if (dtype in [jnp.float16, jnp.bfloat16] + and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported + # k_seqlen must be 16 ~ 4096 + and q_seqlen % 4 == 0 # q_seqlen must be divisor of 4 + and attn_batches % 4 == 0 # batch * heads must be divisor of 4 + ): + if 0 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported: + batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen) + return q_seqlen % batch_per_block == 0 + return False + + @staticmethod + def abstract(logits_aval, mask_aval, scale_factor): # pylint: disable=unused-argument """ - Layernorm bwd lowering rules + te_scaled_masked_softmax_forward abstract """ - _, _, _, x_aval, gamma_aval = ctx.avals_in - x_type = ir.RankedTensorType(x.type) - x_shape = x_type.shape - w_type = ir.RankedTensorType(gamma.type) - w_shape = w_type.shape - b_type = ir.RankedTensorType(gamma.type) - b_shape = b_type.shape - assert w_type == b_type - assert w_shape == b_shape - go_shape = ir.RankedTensorType(grad_output.type).shape - mu_shape = ir.RankedTensorType(mu.type).shape - rsigma_shape = ir.RankedTensorType(rsigma.type).shape + i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype) + assert i_dtype in [jnp.float16, jnp.bfloat16] + i_shape = logits_aval.shape - hidden_size = reduce(operator.mul, w_shape) - # In Transformer, batch_size = batch x seqlen - batch_size = reduce(operator.mul, x_shape) // hidden_size + # Assume [...Batch, Head, Q_Seqlen, K_Seqlen] + batch = reduce(operator.mul, i_shape[:-3]) + q_seqlen = i_shape[-2] + k_seqlen = i_shape[-1] + assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported + assert q_seqlen > 1 - out_types = [ - ir.RankedTensorType.get(x_shape, x_type.element_type), - ir.RankedTensorType.get(w_shape, w_type.element_type), - ir.RankedTensorType.get(b_shape, b_type.element_type), + mask_dtype = dtypes.canonicalize_dtype(mask_aval.dtype) + assert mask_dtype in [ + jnp.uint8, ] - operands = [grad_output, mu, rsigma, x, gamma] - operand_shapes = [go_shape, mu_shape, rsigma_shape, x_shape, w_shape] + mask_shape = mask_aval.shape + pad_batch = batch = reduce(operator.mul, mask_shape[:-3]) + assert pad_batch in (1, batch) # 1 means broadcast + assert mask_shape[-3] == 1 # 1 means broadcast + assert mask_shape[-2] == q_seqlen + assert mask_shape[-1] == k_seqlen + + out_aval = core.raise_to_shaped(logits_aval) + return out_aval + + @staticmethod + def lowering(ctx, logits, mask, *, scale_factor): + """ + te_scaled_masked_softmax_forward lowering rules + """ + + logits_aval, _ = ctx.avals_in + i_type = ir.RankedTensorType(logits.type) + i_shape = i_type.shape + # Assume [...Batch, Head, Q_Seqlen, K_Seqlen] + batch = reduce(operator.mul, i_shape[:-3]) + heads = i_shape[-3] + q_seqlen = i_shape[-2] + k_seqlen = i_shape[-1] + + mask_type = ir.RankedTensorType(mask.type) + mask_shape = mask_type.shape + pad_batch = reduce(operator.mul, mask_shape[:-3]) + + out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)] + operands = [logits, mask] + operand_shapes = [i_shape, mask_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - opaque = transformer_engine_jax.pack_norm_descriptor( - batch_size, - hidden_size, - jax_dtype_to_te_dtype(x_aval.dtype), - jax_dtype_to_te_dtype(gamma_aval.dtype), - zero_centered_gamma, - epsilon, - ) + opaque = transformer_engine_jax.pack_softmax_descriptor( + batch, pad_batch, heads, q_seqlen, k_seqlen, jax_dtype_to_te_dtype(logits_aval.dtype), + scale_factor) - out = custom_caller(LayerNormBwdPrimitive.name, args, opaque, False) + out = custom_caller(ScaledMaskedSoftmaxFwdPrimitive.name, args, opaque, False) + + return [out] + + @staticmethod + def impl(logits, mask, scale_factor): + assert ScaledMaskedSoftmaxFwdPrimitive.inner_primitive is not None + output = ScaledMaskedSoftmaxFwdPrimitive.inner_primitive.bind(logits, + mask, + scale_factor=scale_factor) + return output + + @staticmethod + def batcher(batched_args, batch_dims, *, scale_factor): + _check_valid_batch_dims(batch_dims) + assert ScaledMaskedSoftmaxFwdPrimitive.outer_primitive is not None + logits, mask = batched_args + logits_bdim, _ = batch_dims + + out_bdims = logits_bdim + return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind( + logits, mask, scale_factor=scale_factor), out_bdims + + @staticmethod + def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): + del scale_factor, result_infos # Unused. + logits_spec = get_padded_spec(arg_infos[0]) + out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec)) + return out_sharding + + @staticmethod + def partition(scale_factor, mesh, arg_infos, result_infos): + del result_infos + logits_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0]))) + mask_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) + arg_shardings = (logits_spec, mask_spec) + out_shardings = logits_spec + impl = partial(ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor=scale_factor) + return mesh, impl, out_shardings, arg_shardings + + +register_primitive(ScaledMaskedSoftmaxFwdPrimitive) + + +def scaled_masked_softmax_fwd(logits: jnp.ndarray, mask: jnp.ndarray, + scale_factor: float) -> jnp.ndarray: + """ + scaled_masked_softmax_forward wrapper + Return FP16/BF16 tensor + """ + return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(logits, + mask, + scale_factor=scale_factor) + + +class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive): + """ + Scaled Masked Softmax Bwd Primitive + """ + name = "te_scaled_masked_softmax_backward" + multiple_results = False + impl_static_args = (2,) # scale_factor + inner_primitive = None + outer_primitive = None + + @staticmethod + def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int, + dtype: jnp.dtype) -> bool: + """Check Softmax kernel availability based on size""" + return ScaledSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen, + dtype) + + @staticmethod + def abstract(dz_aval, softmax_out_aval, *, scale_factor): + """ + te_scaled_upper_triang_masked_backward abstract + """ + return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor) + + @staticmethod + def lowering(ctx, dz, softmax_out, *, scale_factor): + """ + te_scaled_upper_triang_masked_backward lowering rules + """ + out = SoftmaxPrimitive.backward_lowering(ScaledMaskedSoftmaxBwdPrimitive.name, + ctx, + dz, + softmax_out, + scale_factor=scale_factor) + + return out + + @staticmethod + def impl(dz, softmax_out, scale_factor): + return SoftmaxPrimitive.backward_impl(ScaledMaskedSoftmaxBwdPrimitive.inner_primitive, + dz, + softmax_out, + scale_factor=scale_factor) + + @staticmethod + def batcher(batched_args, batch_dims, *, scale_factor): + _check_valid_batch_dims(batch_dims) + return SoftmaxPrimitive.backward_batcher(ScaledMaskedSoftmaxBwdPrimitive.outer_primitive, + batched_args, + batch_dims, + scale_factor=scale_factor) + + @staticmethod + def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): + return SoftmaxPrimitive.backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos, + result_infos) + + @staticmethod + def partition(scale_factor, mesh, arg_infos, result_infos): + return SoftmaxPrimitive.backward_partition(ScaledMaskedSoftmaxBwdPrimitive.impl, + scale_factor, mesh, arg_infos, result_infos) + + +register_primitive(ScaledMaskedSoftmaxBwdPrimitive) + + +def scaled_masked_softmax_bwd(dz: jnp.ndarray, softmax_out: jnp.ndarray, + scale_factor: float) -> jnp.ndarray: + """ + scaled_masked_backward wrapper + Return FP16/BF16 tensor + """ + return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind(dz, + softmax_out, + scale_factor=scale_factor) + + +class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): + """ + Scaled Upper Triang Masked Softmax Fwd Primitive + """ + name = "te_scaled_upper_triang_masked_softmax_forward" + multiple_results = False + impl_static_args = (1,) # scale_factor + inner_primitive = None + outer_primitive = None + + @staticmethod + def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int, + dtype: jnp.dtype) -> bool: + """Check Softmax kernel availability based on size""" + attn_batches = batch * heads + + dtype = dtypes.canonicalize_dtype(dtype) + if (dtype in [jnp.float16, jnp.bfloat16] + and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported + # k_seqlen must be 16 ~ 4096 + and q_seqlen % 4 == 0 # q_seqlen must be divisor of 4 + and attn_batches % 4 == 0 # batch * heads must be divisor of 4 + ): + if 0 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported: + batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen) + return attn_batches % batch_per_block == 0 + return False + + @staticmethod + def abstract(logits_aval, scale_factor): # pylint: disable=unused-argument + """ + te_scaled_upper_triang_masked_softmax_forward abstract + """ + q_seqlen = logits_aval.shape[2] + k_seqlen = logits_aval.shape[3] + assert q_seqlen == k_seqlen + return SoftmaxPrimitive.forward_abstract(logits_aval, scale_factor) + + @staticmethod + def lowering(ctx, logits, *, scale_factor): + """ + te_scaled_upper_triang_masked_softmax_forward lowering rules + """ + return SoftmaxPrimitive.forward_lowering(ScaledUpperTriangMaskedSoftmaxFwdPrimitive.name, + ctx, + logits, + scale_factor=scale_factor) + + @staticmethod + def impl(logits, scale_factor): + return SoftmaxPrimitive.forward_impl( + ScaledUpperTriangMaskedSoftmaxFwdPrimitive.inner_primitive, logits, scale_factor) + + @staticmethod + def batcher(batched_args, batch_dims, *, scale_factor): + _check_valid_batch_dims(batch_dims) + return SoftmaxPrimitive.forward_batcher( + ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive, + batched_args, + batch_dims, + scale_factor=scale_factor) + + @staticmethod + def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): + return SoftmaxPrimitive.forward_infer_sharding_from_operands(scale_factor, mesh, arg_infos, + result_infos) + + @staticmethod + def partition(scale_factor, mesh, arg_infos, result_infos): + return SoftmaxPrimitive.forward_partition(ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl, + scale_factor, mesh, arg_infos, result_infos) + + +register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive) + + +def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray: + """ + scaled_upper_triang_masked_softmax_forward wrapper + Return FP16/BF16 tensor + """ + return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind( + logits, scale_factor=scale_factor) + + +class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive): + """ + Scaled Upper Triang Masked Softmax Bwd Primitive + """ + name = "te_scaled_upper_triang_masked_softmax_backward" + multiple_results = False + impl_static_args = (2,) # scale_factor + inner_primitive = None + outer_primitive = None + + @staticmethod + def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int, + dtype: jnp.dtype) -> bool: + """Check Softmax kernel availability based on size""" + return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available( + batch, heads, q_seqlen, k_seqlen, dtype) + + @staticmethod + def abstract(dz_aval, softmax_out_aval, *, scale_factor): + """ + te_scaled_upper_triang_masked_backward abstract + """ + return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor) + + @staticmethod + def lowering(ctx, dz, softmax_out, *, scale_factor): + """ + te_scaled_upper_triang_masked_backward lowering rules + """ + out = SoftmaxPrimitive.backward_lowering(ScaledUpperTriangMaskedSoftmaxBwdPrimitive.name, + ctx, + dz, + softmax_out, + scale_factor=scale_factor) return out + @staticmethod + def impl(dz, softmax_out, scale_factor): + return SoftmaxPrimitive.backward_impl( + ScaledUpperTriangMaskedSoftmaxBwdPrimitive.inner_primitive, + dz, + softmax_out, + scale_factor=scale_factor) -_layernorm_bwd_p = register_primitive(LayerNormBwdPrimitive) + @staticmethod + def batcher(batched_args, batch_dims, *, scale_factor): + _check_valid_batch_dims(batch_dims) + return SoftmaxPrimitive.backward_batcher( + ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive, + batched_args, + batch_dims, + scale_factor=scale_factor) + @staticmethod + def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): + return SoftmaxPrimitive.backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos, + result_infos) -def layernorm_bwd(g: jnp.ndarray, mu: jnp.ndarray, rsigma: jnp.ndarray, x: jnp.ndarray, - gamma: jnp.ndarray, zero_centered_gamma: bool, epsilon: float): + @staticmethod + def partition(scale_factor, mesh, arg_infos, result_infos): + return SoftmaxPrimitive.backward_partition(ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl, + scale_factor, mesh, arg_infos, result_infos) + + +register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive) + + +def scaled_upper_triang_masked_softmax_bwd(dz: jnp.ndarray, softmax_out: jnp.ndarray, + scale_factor: float) -> jnp.ndarray: """ - Wrapper for TE layernorm bwd + scaled_upper_triang_masked_backward wrapper + Return FP16/BF16 tensor """ - return _layernorm_bwd_p.bind(g, - mu, - rsigma, - x, - gamma, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon) + return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive.bind( + dz, softmax_out, scale_factor=scale_factor) -class RmsNormFwdPrimitive(BasePrimitive): +@dataclass(frozen=True) +class FusedAttnHelper: """ - RMS Normalization Forward Primitive + Helper for the fused attention backend """ - name = "te_rmsnorm_forward" + + q_type: jnp.dtype + kv_type: jnp.dtype + qkv_layout: NVTE_QKV_Layout + attn_bias_type: NVTE_Bias_Type + attn_mask_type: NVTE_Mask_Type + dropout_probability: float + max_seqlen_q: int + max_seqlen_kv: int + head_dim: int + + def is_fused_attn_kernel_available(self): + """Check if there is available fused attention kernel""" + return self.get_fused_attn_backend() != NVTE_Fused_Attn_Backend.NVTE_No_Backend + + def get_fused_attn_backend(self): + """Get the fused attention kernel backend""" + return transformer_engine_jax.get_fused_attn_backend(jax_dtype_to_te_dtype(self.q_type), + jax_dtype_to_te_dtype(self.kv_type), + self.qkv_layout, self.attn_bias_type, + self.attn_mask_type, + self.dropout_probability, + self.max_seqlen_q, self.max_seqlen_kv, + self.head_dim) + + +@dataclass(frozen=True) +class _FusedAttnRNGStateChecker: + """ + Checker for guarding the fused attention rng state. + The fused attention backend requires a 64 bits seed and a 64 bits offset. + However, JAX doesn't enable 64 bits by default, + so we have to emulate seed as two 32 bits array. + The offset calculation is maintained in the backend. + """ + rng_state_dtype: jnp.dtype = jnp.uint32 + # (seed,) with internal dtype int64 + seed_size: int = 2 + # (seed, offset) with internal dtype int64 + rng_state_size: int = 2 * 2 + + def check_seed(self, seed, dropout_probability, is_training): + """ + Check the seed and convert the data type of seed if possible. + """ + # Jax can't bind None, create a dummy tensor for None + if seed is None: + dropout_enabled = dropout_probability > 0 and is_training + assert not dropout_enabled, "seed is not allowed to be None when dropout is enabled." + seed = jnp.zeros(2, dtype=self.rng_state_dtype) + seed = jnp.repeat(seed, num_of_devices()) + + if seed.dtype != self.rng_state_dtype: + warnings.warn( + f"Requested {seed.dtype=} is not available, and will be " + f"casted to dtype {self.rng_state_dtype}. " + f"Please use threefry/rbg/unsafe_rbg PRNG implementations to remove this warning.") + seed = seed.astype(self.rng_state_dtype) + + assert seed.dtype == self.rng_state_dtype + # Backend takes an int64_t seed, so only the first two u32 elements are taken + assert seed.size >= self.seed_size + + return seed + + +def generate_cu_seqlen(mask): + """ + Generating cumsum seqlen for a batch + """ + seqlen = jnp.sum(mask == 0, axis=(-1, -2), dtype=jnp.int32) + cu_seqlen = jnp.cumsum(seqlen) + cu_seqlen = jnp.hstack((0, cu_seqlen)) + return cu_seqlen + + +class SelfFusedAttnFwdPrimitive(BasePrimitive): + """ + Self Fused Attention Forward Primitive + """ + name = "te_self_fused_attn_forward" multiple_results = True + impl_static_args = (4, 5, 6, 7, 8) + inner_primitive = None + outer_primitive = None @staticmethod - def abstract(x, gamma, **kwargs): # pylint: disable=unused-argument + def abstract(qkv_aval, bias_aval, mask_or_cu_seqlen_aval, seed_aval, *, attn_bias_type, + attn_mask_type, scaling_factor, dropout_probability, is_training): """ - RMSNorm fwd abstract + Self fused attention fwd abstract """ - x_dtype = dtypes.canonicalize_dtype(x.dtype) - rsigma_dtype = jnp.float32 + # outer_primitve is squeezed_mask, inner_primitive is cu_seqlen + del mask_or_cu_seqlen_aval, scaling_factor, is_training + qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype) + *batch_shape, max_seqlen, nqkv, num_head, head_dim = qkv_aval.shape + assert nqkv == 3 + assert qkv_aval.dtype == bias_aval.dtype - hidden_size = gamma.size - # In Transformer, batch_size = batch x seqlen - batch_size = x.size // hidden_size + output_shape = (*batch_shape, max_seqlen, num_head, head_dim) + output_dtype = qkv_dtype - return ( - ShapedArray(x.shape, x_dtype, named_shape=x.named_shape), # output - ShapedArray((batch_size,), rsigma_dtype, named_shape=x.named_shape), # rsigma - ) + backend = FusedAttnHelper(qkv_dtype, qkv_dtype, NVTE_QKV_Layout.NVTE_BS3HD, attn_bias_type, + attn_mask_type, dropout_probability, max_seqlen, max_seqlen, + head_dim).get_fused_attn_backend() + + if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen: + softmax_aux_shape = (*batch_shape, num_head, max_seqlen, max_seqlen) + softmax_dtype = qkv_dtype + elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: + softmax_aux_shape = (*batch_shape, num_head, max_seqlen, 1) + softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) + else: + raise ValueError(f'Not supported {backend=}') + + checker = _FusedAttnRNGStateChecker() + seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype) + assert seed_dtype == checker.rng_state_dtype + rng_state_shape = (seed_aval.shape[0], checker.rng_state_size) + rng_state_dtype = seed_dtype + + out_aval = qkv_aval.update(shape=output_shape, dtype=output_dtype) + softmax_aux_aval = qkv_aval.update(shape=softmax_aux_shape, dtype=softmax_dtype) + rng_state_aval = qkv_aval.update(shape=rng_state_shape, dtype=rng_state_dtype) + return out_aval, softmax_aux_aval, rng_state_aval @staticmethod - def lowering(ctx, x, gamma, *, epsilon): + def lowering(ctx, qkv, bias, cu_seqlen, seed, *, attn_bias_type, attn_mask_type, scaling_factor, + dropout_probability, is_training): """ - RMSNorm fwd lowering rules + Self fused attention fwd lowering rules """ - x_aval, gamma_aval = ctx.avals_in - x_type = ir.RankedTensorType(x.type) - x_shape = x_type.shape - w_type = ir.RankedTensorType(gamma.type) - w_shape = w_type.shape - iv_element_type = ir.F32Type.get() + qkv_aval, _, _, _ = ctx.avals_in - hidden_size = reduce(operator.mul, w_shape) - # In Transformer, batch_size = batch x seqlen - batch_size = reduce(operator.mul, x_shape) // hidden_size + *batch_shape, max_seqlen, _, num_head, head_dim = qkv_aval.shape + batch = reduce(operator.mul, batch_shape) + operands = [qkv, bias, cu_seqlen, seed] + operand_shapes = map(lambda x: x.type.shape, operands) out_types = [ - ir.RankedTensorType.get(x_shape, w_type.element_type), - ir.RankedTensorType.get((batch_size,), iv_element_type), + ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) + for output in ctx.avals_out ] - operands = [x, gamma] - operand_shapes = [x_shape, w_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + opaque = transformer_engine_jax.pack_fused_attn_descriptor( + batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability, + attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training) - opaque = transformer_engine_jax.pack_norm_descriptor( - batch_size, - hidden_size, - jax_dtype_to_te_dtype(x_aval.dtype), - jax_dtype_to_te_dtype(gamma_aval.dtype), - False, # RMSNorm doesn't support zero_centered_gamma - epsilon, - ) + out = custom_caller(SelfFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) - out = custom_caller(RmsNormFwdPrimitive.name, args, opaque, False) + return out + + @staticmethod + def impl(qkv, bias, squeezed_mask, seed, attn_bias_type, attn_mask_type, scaling_factor, + dropout_probability, is_training): + assert SelfFusedAttnFwdPrimitive.inner_primitive is not None - return out + cu_seqlen = generate_cu_seqlen(squeezed_mask) + output, softmax_aux, rng_state = SelfFusedAttnFwdPrimitive.inner_primitive.bind( + qkv, + bias, + cu_seqlen, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) + return output, softmax_aux, rng_state -_rmsnorm_fwd_p = register_primitive(RmsNormFwdPrimitive) + @staticmethod + def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor, + dropout_probability, is_training): + _check_valid_batch_dims(batch_dims) + assert SelfFusedAttnFwdPrimitive.outer_primitive is not None + qkv, bias, cu_seqlen, seed = batched_args + qkv_bdim, _, _, seed_bdim = batch_dims + + out_bdims = qkv_bdim, qkv_bdim, seed_bdim + return SelfFusedAttnFwdPrimitive.outer_primitive.bind( + qkv, + bias, + cu_seqlen, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training), out_bdims + @staticmethod + def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor, + dropout_probability, is_training, mesh, arg_infos, + result_infos): + del attn_bias_type, attn_mask_type, scaling_factor + del dropout_probability, is_training, result_infos + x_spec = get_padded_spec(arg_infos[0]) # (...batch, seqlen, 3, head, hidden) + out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-3], *x_spec[-2:])) + softmax_aux_sharding = NamedSharding( + mesh, PartitionSpec(*x_spec[:-4], x_spec[-2], x_spec[-4], None)) + rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None)) + return (out_sharding, softmax_aux_sharding, rng_state_sharding) -def rmsnorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, epsilon: float): + @staticmethod + def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training, + mesh, arg_infos, result_infos): + del result_infos + x_spec = get_padded_spec(arg_infos[0]) # (...batch, seqlen, 3, head, hidden) + out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-3], *x_spec[-2:])) + softmax_aux_sharding = NamedSharding( + mesh, PartitionSpec(*x_spec[:-4], x_spec[-2], x_spec[-4], None)) + rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None)) + arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [rng_state_sharding]) + out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) + impl = partial(SelfFusedAttnFwdPrimitive.impl, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) + return mesh, impl, out_shardings, arg_shardings + + +register_primitive(SelfFusedAttnFwdPrimitive) + + +def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, squeezed_mask: jnp.ndarray, + seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, scaling_factor: float, + dropout_probability: float, is_training: bool): """ - Wrapper for TE rmsnorm fwd + Wrapper for TE self fused attention fwd + Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 """ - return _rmsnorm_fwd_p.bind(x, gamma, epsilon=epsilon) + checker = _FusedAttnRNGStateChecker() + seed = checker.check_seed(seed, dropout_probability, is_training) + + if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + assert bias is None + bias = jnp.zeros(0, dtype=qkv.dtype) + return SelfFusedAttnFwdPrimitive.outer_primitive.bind(qkv, + bias, + squeezed_mask, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) -class RmsNormFwdFp8Primitive(BasePrimitive): +class SelfFusedAttnBwdPrimitive(BasePrimitive): """ - RMS Normalization Forward FP8 Primitive + Self Fused Attention Backward Primitive """ - name = "te_rmsnorm_forward_fp8" + name = "te_self_fused_attn_backward" multiple_results = True + impl_static_args = (6, 7, 8, 9, 10) + inner_primitive = None + outer_primitive = None @staticmethod - def abstract( - x, - gamma, - amax, - scale, - scale_inv, - **kwargs # pylint: disable=unused-argument - ): + def abstract(qkv_aval, softmax_aux_aval, rng_state_aval, output_aval, doutput_aval, + mask_or_cu_seqlen_aval, *, attn_bias_type, attn_mask_type, scaling_factor, + dropout_probability, is_training): """ - RMSNorm fwd (fp8 out) abstract + Self fused attention bwd abstract """ - x_dtype = dtypes.canonicalize_dtype(x.dtype) - - assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert amax.dtype == jnp.float32 - assert scale.dtype == jnp.float32 - assert scale_inv.dtype == jnp.float32 - - out_dtype = jnp.int8 - rsigma_dtype = jnp.float32 + del softmax_aux_aval, rng_state_aval + # outer_primitve is squeezed_mask, inner_primitive is cu_seqlen + del mask_or_cu_seqlen_aval, attn_mask_type + del scaling_factor, dropout_probability, is_training + qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype) + assert qkv_aval.dtype == output_aval.dtype == doutput_aval.dtype + *batch_shape, max_seqlen, num_head, _ = output_aval.shape - hidden_size = gamma.size - # In Transformer, batch_size = batch x seqlen - batch_size = x.size // hidden_size + if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + bias_shape = (0,) + else: + bias_shape = (*batch_shape[:-1], 1, num_head, max_seqlen, max_seqlen) + bias_dtype = qkv_dtype - return ( - ShapedArray(x.shape, out_dtype, named_shape=x.named_shape), # output - ShapedArray((batch_size,), rsigma_dtype, named_shape=x.named_shape), # rsigma - ShapedArray((1,), amax.dtype, named_shape=amax.named_shape), # amax - ) + dqkv_aval = qkv_aval.update(shape=qkv_aval.shape, dtype=qkv_dtype) + dbias = qkv_aval.update(shape=bias_shape, dtype=bias_dtype) + return dqkv_aval, dbias @staticmethod - def lowering(ctx, x, gamma, amax, scale, scale_inv, *, epsilon): + def lowering(ctx, qkv, softmax_aux, rng_state, output, doutput, cu_seqlen, *, attn_bias_type, + attn_mask_type, scaling_factor, dropout_probability, is_training): """ - RMSNorm fwd (fp8 out) lowering rules + Self fused attention bwd lowering rules """ - x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in - - assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert amax_aval.dtype == jnp.float32 - assert scale_aval.dtype == jnp.float32 - assert scale_inv_aval.dtype == jnp.float32 - - x_type = ir.RankedTensorType(x.type) - x_shape = x_type.shape - w_type = ir.RankedTensorType(gamma.type) - w_shape = w_type.shape - - ir_out_dtype = dtype_to_ir_type(np.dtype(np.int8)) - ir_rsigma_dtype = ir.F32Type.get() - ir_amax_type = ir.RankedTensorType(amax.type) - ir_amax_dtype = ir_amax_type.element_type - ir_amax_shape = ir_amax_type.shape - ir_scale_shape = ir_amax_shape - ir_scale_inv_shape = ir_amax_shape + qkv_aval, _, _, _, _, _ = ctx.avals_in - hidden_size = reduce(operator.mul, w_shape) - # In Transformer, batch_size = batch x seqlen - batch_size = reduce(operator.mul, x_shape) // hidden_size + *batch_shape, max_seqlen, _, num_head, head_dim = qkv_aval.shape + batch = reduce(operator.mul, batch_shape) + operands = [qkv, softmax_aux, rng_state, output, doutput, cu_seqlen] + operand_shapes = map(lambda x: x.type.shape, operands) out_types = [ - ir.RankedTensorType.get(x_shape, ir_out_dtype), - ir.RankedTensorType.get((batch_size,), ir_rsigma_dtype), - ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), + ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) + for output in ctx.avals_out ] - operands = [x, gamma, amax, scale, scale_inv] - operand_shapes = [x_shape, w_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - opaque = transformer_engine_jax.pack_norm_descriptor( - batch_size, - hidden_size, - jax_dtype_to_te_dtype(x_aval.dtype), - jax_dtype_to_te_dtype(gamma_aval.dtype), - False, # RMSNorm doesn't support zero_centered_gamma - epsilon, - ) + opaque = transformer_engine_jax.pack_fused_attn_descriptor( + batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability, + attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training) - out = custom_caller(RmsNormFwdFp8Primitive.name, - args, - opaque, - False, - operand_output_aliases={2: 2}) + out = custom_caller(SelfFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) return out + @staticmethod + def impl(qkv, softmax_aux, rng_state, output, doutput, squeezed_mask, attn_bias_type, + attn_mask_type, scaling_factor, dropout_probability, is_training): + assert SelfFusedAttnBwdPrimitive.inner_primitive is not None + + cu_seqlen = generate_cu_seqlen(squeezed_mask) -_rmsnorm_fwd_fp8_p = register_primitive(RmsNormFwdFp8Primitive) + dqkv, dbias = SelfFusedAttnBwdPrimitive.inner_primitive.bind( + qkv, + softmax_aux, + rng_state, + output, + doutput, + cu_seqlen, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) + return dqkv, dbias + @staticmethod + def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor, + dropout_probability, is_training): + _check_valid_batch_dims(batch_dims) + assert SelfFusedAttnBwdPrimitive.outer_primitive is not None + qkv, softmax_aux, rng_state, output, doutput, cu_seqlen = batched_args + qkv_bdim, *_ = batch_dims + + out_bdims = qkv_bdim, qkv_bdim + return SelfFusedAttnBwdPrimitive.outer_primitive.bind( + qkv, + softmax_aux, + rng_state, + output, + doutput, + cu_seqlen, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training), out_bdims -def rmsnorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, - scale_inv: jnp.ndarray, epsilon: float): + @staticmethod + def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor, + dropout_probability, is_training, mesh, arg_infos, + result_infos): + del attn_mask_type, scaling_factor, dropout_probability, + del is_training, result_infos + x_spec = get_padded_spec(arg_infos[0]) + dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + dbias_spec = [None] + if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: + dbias_spec = [*x_spec[:-5], None, x_spec[-2], None, None] + dbias_sharding = NamedSharding(mesh, PartitionSpec(*dbias_spec)) + return (dx_sharding, dbias_sharding) + + @staticmethod + def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training, + mesh, arg_infos, result_infos): + del result_infos + x_spec = get_padded_spec(arg_infos[0]) + dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + dbias_spec = [None] + if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: + dbias_spec = [*x_spec[:-5], None, x_spec[-2], None, None] + dbias_sharding = NamedSharding(mesh, PartitionSpec(*dbias_spec)) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_shardings = (dx_sharding, dbias_sharding) + + def sharded_impl(qkv, softmax_aux, rng_state, output, doutput, cu_seqlen): + local_dx, local_dbias = SelfFusedAttnBwdPrimitive.impl( + qkv, + softmax_aux, + rng_state, + output, + doutput, + cu_seqlen, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) + global_dbias = local_dbias + if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: + global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias) + return local_dx, global_dbias + + return mesh, sharded_impl, out_shardings, arg_shardings + + +register_primitive(SelfFusedAttnBwdPrimitive) + + +def self_fused_attn_bwd(qkv: jnp.ndarray, softmax_aux: jnp.ndarray, rng_state: jnp.ndarray, + output: jnp.ndarray, doutput: jnp.ndarray, squeezed_mask: jnp.ndarray, + attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type, + scaling_factor: float, dropout_probability: float, is_training: bool): """ - Wrapper for TE rmsnorm fwd (fp8 out) + Wrapper for TE self fused attention bwd + Return the gradients of self fused attention with packed qkv input """ - return _rmsnorm_fwd_fp8_p.bind(x, gamma, amax, scale, scale_inv, epsilon=epsilon) + return SelfFusedAttnBwdPrimitive.outer_primitive.bind(qkv, + softmax_aux, + rng_state, + output, + doutput, + squeezed_mask, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) -class RmsNormBwdPrimitive(BasePrimitive): +class CrossFusedAttnFwdPrimitive(BasePrimitive): """ - RMS Normalization Backward Primitive + Cross Fused Attention Forward Primitive """ - name = "te_rmsnorm_backward" + name = "te_cross_fused_attn_forward" multiple_results = True + impl_static_args = (5, 6, 7, 8, 9) + inner_primitive = None + outer_primitive = None @staticmethod - def abstract( - grad_output, - rsigma, - x, - gamma, - **kwargs # pylint: disable=unused-argument - ): + def abstract(q_aval, kv_aval, q_mask_or_cu_seqlen_aval, kv_mask_or_cu_seqlen_aval, seed_aval, *, + attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training): """ - RMSNorm bwd abstract + Cross fused attention fwd abstract """ - w_dtype = dtypes.canonicalize_dtype(gamma.dtype) - x_dtype = dtypes.canonicalize_dtype(x.dtype) - rsigma_dtype = dtypes.canonicalize_dtype(rsigma.dtype) + del seed_aval, attn_bias_type, attn_mask_type + del scaling_factor, dropout_probability, is_training - hidden_size = gamma.size - # In Transformer, batch_size = batch x seqlen - batch_size = x.size // hidden_size + q_dtype = dtypes.canonicalize_dtype(q_aval.dtype) + *q_batch_shape, q_max_seqlen, q_num_head, q_head_dim = q_aval.shape - assert dtypes.canonicalize_dtype(grad_output.dtype) == w_dtype - assert grad_output.shape == x.shape - assert rsigma.shape == (batch_size,) - assert rsigma_dtype == jnp.float32 - assert grad_output.named_shape == x.named_shape + kv_dtype = dtypes.canonicalize_dtype(kv_aval.dtype) + *kv_batch_shape, kv_max_seqlen, nkv, kv_num_head, kv_head_dim = kv_aval.shape - return ( - ShapedArray(x.shape, x_dtype, named_shape=grad_output.named_shape), # grad input - ShapedArray(gamma.shape, w_dtype, named_shape=gamma.named_shape), # grad gamma - ) + assert q_dtype == kv_dtype + assert q_batch_shape == kv_batch_shape + assert q_num_head == kv_num_head + assert q_head_dim == kv_head_dim + assert nkv == 2 + # outer_primitve is squeezed_mask, inner_primitive is cu_seqlen + assert q_mask_or_cu_seqlen_aval.dtype == kv_mask_or_cu_seqlen_aval.dtype + + output_shape = q_aval.shape + output_dtype = q_dtype + softmax_aux_shape = (*q_batch_shape, q_num_head, q_max_seqlen, kv_max_seqlen) + softmax_aux_dtype = q_dtype + + out_aval = q_aval.update(shape=output_shape, dtype=output_dtype) + softmax_aux_aval = q_aval.update(shape=softmax_aux_shape, dtype=softmax_aux_dtype) + return out_aval, softmax_aux_aval @staticmethod - def lowering(ctx, grad_output, inv_var, x, gamma, *, epsilon): + def lowering(ctx, q, kv, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type, attn_mask_type, + scaling_factor, dropout_probability, is_training): """ - RMSNorm bwd lowering rules + Cross fused attention fwd lowering rules """ - _, _, x_aval, gamma_aval = ctx.avals_in - x_type = ir.RankedTensorType(x.type) - x_shape = x_type.shape - w_type = ir.RankedTensorType(gamma.type) - w_shape = w_type.shape - go_shape = ir.RankedTensorType(grad_output.type).shape - inv_var_shape = ir.RankedTensorType(inv_var.type).shape + q_aval, kv_aval, _, _, _ = ctx.avals_in + assert q_aval.dtype == kv_aval.dtype - hidden_size = reduce(operator.mul, w_shape) - # In Transformer, batch_size = batch x seqlen - batch_size = reduce(operator.mul, x_shape) // hidden_size + *batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape + batch = reduce(operator.mul, batch_shape) + kv_max_seqlen = kv_aval.shape[-4] + operands = [q, kv, q_cu_seqlen, kv_cu_seqlen, seed] + operand_shapes = map(lambda x: x.type.shape, operands) out_types = [ - ir.RankedTensorType.get(x_shape, x_type.element_type), - ir.RankedTensorType.get(w_shape, w_type.element_type), + ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) + for output in ctx.avals_out ] - operands = [grad_output, inv_var, x, gamma] - operand_shapes = [go_shape, inv_var_shape, x_shape, w_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - opaque = transformer_engine_jax.pack_norm_descriptor( - batch_size, - hidden_size, - jax_dtype_to_te_dtype(x_aval.dtype), - jax_dtype_to_te_dtype(gamma_aval.dtype), - False, # RMSNorm doesn't support zero_centered_gamma - epsilon, - ) + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + opaque = transformer_engine_jax.pack_fused_attn_descriptor( + batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, + scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, + jax_dtype_to_te_dtype(q_aval.dtype), is_training) - out = custom_caller(RmsNormBwdPrimitive.name, args, opaque, False) + out = custom_caller(CrossFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) return out + @staticmethod + def impl(q, kv, q_squeezed_mask, kv_squeezed_mask, seed, attn_bias_type, attn_mask_type, + scaling_factor, dropout_probability, is_training): + assert CrossFusedAttnFwdPrimitive.inner_primitive is not None -_rmsnorm_bwd_p = register_primitive(RmsNormBwdPrimitive) + q_cu_seqlen = generate_cu_seqlen(q_squeezed_mask) + kv_cu_seqlen = generate_cu_seqlen(kv_squeezed_mask) + output, softmax_aux = CrossFusedAttnFwdPrimitive.inner_primitive.bind( + q, + kv, + q_cu_seqlen, + kv_cu_seqlen, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) + return output, softmax_aux -def rmsnorm_bwd(grad: jnp.ndarray, inv_var: jnp.ndarray, x: jnp.ndarray, gamma: jnp.ndarray, - epsilon: float): + @staticmethod + def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor, + dropout_probability, is_training): + _check_valid_batch_dims(batch_dims) + assert CrossFusedAttnFwdPrimitive.outer_primitive is not None + q, kv, q_cu_seqlen, kv_cu_seqlen, seed = batched_args + q_bdim, *_ = batch_dims + + out_bdims = q_bdim, q_bdim + return CrossFusedAttnFwdPrimitive.outer_primitive.bind( + q, + kv, + q_cu_seqlen, + kv_cu_seqlen, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training), out_bdims + + @staticmethod + def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor, + dropout_probability, is_training, mesh, arg_infos, + result_infos): + del attn_bias_type, attn_mask_type, scaling_factor + del dropout_probability, is_training, result_infos + q_spec = get_padded_spec(arg_infos[0]) # (...batch, q_seqlen, head, hidden) + kv_spec = get_padded_spec(arg_infos[1]) # (...batch, kv_seqlen, 2, head, hidden) + out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + softmax_aux_sharding = NamedSharding( + mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], kv_spec[-4])) + return (out_sharding, softmax_aux_sharding) + + @staticmethod + def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training, + mesh, arg_infos, result_infos): + del result_infos + q_spec = get_padded_spec(arg_infos[0]) # (...batch, q_seqlen, head, hidden) + kv_spec = get_padded_spec(arg_infos[1]) # (...batch, kv_seqlen, 2, head, hidden) + out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + softmax_aux_sharding = NamedSharding( + mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], kv_spec[-4])) + seed_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None)) + arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding]) + out_shardings = (out_sharding, softmax_aux_sharding) + impl = partial(CrossFusedAttnFwdPrimitive.impl, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) + return mesh, impl, out_shardings, arg_shardings + + +register_primitive(CrossFusedAttnFwdPrimitive) + + +def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, q_squeezed_mask: jnp.ndarray, + kv_squeezed_mask: jnp.ndarray, seed: jnp.ndarray, + attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type, + scaling_factor: float, dropout_probability: float, is_training: bool): """ - Wrapper for TE rmsnorm bwd + Wrapper for TE cross fused attention fwd + Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 """ - return _rmsnorm_bwd_p.bind(grad, inv_var, x, gamma, epsilon=epsilon) + checker = _FusedAttnRNGStateChecker() + seed = checker.check_seed(seed, dropout_probability, is_training) + + return CrossFusedAttnFwdPrimitive.outer_primitive.bind(q, + kv, + q_squeezed_mask, + kv_squeezed_mask, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) -class QuantizePrimitive(BasePrimitive): +class CrossFusedAttnBwdPrimitive(BasePrimitive): """ - Quantize Primitive + Cross Fused Attention Backward Primitive """ - name = "te_quantize" + name = "te_cross_fused_attn_backward" multiple_results = True + impl_static_args = (6, 7, 8, 9, 10) + inner_primitive = None + outer_primitive = None @staticmethod - def abstract(inputs, amax, scale, scale_inv, *, out_dtype): + def abstract(q_aval, kv_aval, softmax_aux_aval, doutput_aval, q_cu_seqlen_aval, + kv_cu_seqlen_aval, *, attn_bias_type, attn_mask_type, scaling_factor, + dropout_probability, is_training): """ - te_quantize abstract + Cross fused attention bwd abstract """ - in_dtype = dtypes.canonicalize_dtype(inputs.dtype) - assert in_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - - assert isinstance(out_dtype, TEDType) - out_dtype = te_dtype_to_jax_dtype(out_dtype) - - assert amax.dtype == jnp.float32 - assert scale.dtype == jnp.float32 - assert scale_inv.dtype == jnp.float32 + del attn_bias_type, attn_mask_type + del scaling_factor, dropout_probability, is_training + q_dtype = dtypes.canonicalize_dtype(q_aval.dtype) + kv_dtype = dtypes.canonicalize_dtype(kv_aval.dtype) + softmax_aux_dtype = dtypes.canonicalize_dtype(softmax_aux_aval.dtype) + doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype) + assert q_dtype == kv_dtype == softmax_aux_dtype == doutput_dtype + # outer_primitve is squeezed_mask, inner_primitive is cu_seqlen + assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype - return (ShapedArray(inputs.shape, out_dtype, named_shape=inputs.named_shape), - ShapedArray((1,), amax.dtype, named_shape=amax.named_shape)) + dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) + dkv_aval = kv_aval.update(shape=kv_aval.shape, dtype=kv_dtype) + return dq_aval, dkv_aval @staticmethod - def lowering(ctx, inputs, amax, scale, scale_inv, *, out_dtype): + def lowering(ctx, q, kv, softmax_aux, doutput, q_cu_seqlen, kv_cu_seqlen, *, attn_bias_type, + attn_mask_type, scaling_factor, dropout_probability, is_training): """ - te_quantize lowering rules + Cross fused attention bwd lowering rules """ - in_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in - - assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert amax_aval.dtype == jnp.float32 - assert scale_aval.dtype == jnp.float32 - assert scale_inv_aval.dtype == jnp.float32 - - ir_in_type = ir.RankedTensorType(inputs.type) - ir_in_shape = ir_in_type.shape - - ir_out_dtype = te_dtype_to_ir_dtype(out_dtype) - ir_out_shape = ir_in_shape - - ir_amax_type = ir.RankedTensorType(amax.type) - ir_amax_shape = ir_amax_type.shape - ir_amax_dtype = ir_amax_type.element_type + q_aval, kv_aval, _, _, _, _ = ctx.avals_in + assert q_aval.dtype == kv_aval.dtype - ir_scale_shape = ir_amax_shape - ir_scale_inv_shape = ir_amax_shape + *batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape + batch = reduce(operator.mul, batch_shape) + kv_max_seqlen = kv_aval.shape[-4] + operands = [q, kv, softmax_aux, doutput, q_cu_seqlen, kv_cu_seqlen] + operand_shapes = map(lambda x: x.type.shape, operands) out_types = [ - ir.RankedTensorType.get(ir_out_shape, ir_out_dtype), - ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), + ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) + for output in ctx.avals_out ] - operands = [inputs, amax, scale, scale_inv] - operand_shapes = [ir_in_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - opaque = transformer_engine_jax.pack_common_descriptor(in_aval.shape, - jax_dtype_to_te_dtype(in_aval.dtype), - out_dtype) + # the dropout elements are encoded in the forward auxiliary tensor + # so seed is not needed in backward + opaque = transformer_engine_jax.pack_fused_attn_descriptor( + batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, + scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, + jax_dtype_to_te_dtype(q_aval.dtype), is_training) - out = custom_caller(QuantizePrimitive.name, - args, - opaque, - False, - operand_output_aliases={1: 1}) + out = custom_caller(CrossFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) return out + @staticmethod + def impl(q, kv, softmax_aux, doutput, q_squeezed_mask, kv_squeezed_mask, attn_bias_type, + attn_mask_type, scaling_factor, dropout_probability, is_training): + assert CrossFusedAttnBwdPrimitive.inner_primitive is not None + + q_cu_seqlen = generate_cu_seqlen(q_squeezed_mask) + kv_cu_seqlen = generate_cu_seqlen(kv_squeezed_mask) + + dq, dkv = CrossFusedAttnBwdPrimitive.inner_primitive.bind( + q, + kv, + softmax_aux, + doutput, + q_cu_seqlen, + kv_cu_seqlen, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) + return dq, dkv + + @staticmethod + def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor, + dropout_probability, is_training): + _check_valid_batch_dims(batch_dims) + assert CrossFusedAttnBwdPrimitive.outer_primitive is not None + q, kv, softmax_aux, doutput, q_cu_seqlen, kv_cu_seqlen = batched_args + q_bdim, kv_bdim, *_ = batch_dims + + out_bdims = q_bdim, kv_bdim + return CrossFusedAttnBwdPrimitive.outer_primitive.bind( + q, + kv, + softmax_aux, + doutput, + q_cu_seqlen, + kv_cu_seqlen, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training), out_bdims + + @staticmethod + def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor, + dropout_probability, is_training, mesh, arg_infos, + result_infos): + del attn_bias_type, attn_mask_type, scaling_factor + del dropout_probability, is_training, result_infos + q_spec = get_padded_spec(arg_infos[0]) + kv_spec = get_padded_spec(arg_infos[1]) + dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + dkv_sharding = NamedSharding(mesh, PartitionSpec(*kv_spec)) + return (dq_sharding, dkv_sharding) + + @staticmethod + def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training, + mesh, arg_infos, result_infos): + del result_infos + q_spec = get_padded_spec(arg_infos[0]) + kv_spec = get_padded_spec(arg_infos[1]) + dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + dkv_sharding = NamedSharding(mesh, PartitionSpec(*kv_spec)) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_shardings = (dq_sharding, dkv_sharding) + + impl = partial(CrossFusedAttnBwdPrimitive.impl, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) + + return mesh, impl, out_shardings, arg_shardings + -_quantize_p = register_primitive(QuantizePrimitive) +register_primitive(CrossFusedAttnBwdPrimitive) -def quantize(inputs: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, - out_dtype: TEDType) -> Tuple[jnp.ndarray, jnp.ndarray]: +def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, softmax_aux: jnp.ndarray, + doutput: jnp.ndarray, q_squeezed_mask: jnp.ndarray, + kv_squeezed_mask: jnp.ndarray, attn_bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, scaling_factor: float, + dropout_probability: float, is_training: bool): """ - quantize wrapper - Return FP8 tensor + Wrapper for TE cross fused attention bwd + Return the gradients of cross fused attention with packed kv input """ - return _quantize_p.bind(inputs, amax, scale, scale_inv, out_dtype=out_dtype) + return CrossFusedAttnBwdPrimitive.outer_primitive.bind(q, + kv, + softmax_aux, + doutput, + q_squeezed_mask, + kv_squeezed_mask, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) -class DequantizePrimitive(BasePrimitive): +class GatedGeluPrimitive(BasePrimitive): """ - Dequantize Primitive + Gated Gelu Froward Primitive """ - name = "te_dequantize" + name = "te_gated_gelu" multiple_results = False + inner_primitive = None + outer_primitive = None + impl_static_args = () @staticmethod - def abstract(inputs, amax, scale, scale_inv, *, fp8_dtype, out_dtype): + def abstract(x_aval): """ - te_dquantize abstract + gated_gelu abstract """ - in_dtype = dtypes.canonicalize_dtype(inputs.dtype) - assert in_dtype == jnp.int8 - assert isinstance(fp8_dtype, TEDType) - - assert isinstance(out_dtype, TEDType) - out_dtype = te_dtype_to_jax_dtype(out_dtype) - assert out_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - - assert amax.dtype == jnp.float32 - assert scale.dtype == jnp.float32 - assert scale_inv.dtype == jnp.float32 + dtype = dtypes.canonicalize_dtype(x_aval.dtype) + assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + x_shape = x_aval.shape + assert x_shape[-2] == 2 # Assume x in (....., 2, hidden) + hidden_size = x_shape[-1] + batch_shapes = x_shape[:-2] + x_shape = x_aval.shape + out_aval = core.raise_to_shaped(x_aval) + out_shape = (batch_shapes) + (hidden_size,) + out_aval = out_aval.update(shape=out_shape, dtype=dtype) - return ShapedArray(inputs.shape, out_dtype, named_shape=inputs.named_shape) + return out_aval @staticmethod - def lowering(ctx, inputs, amax, scale, scale_inv, *, fp8_dtype, out_dtype): + def lowering(ctx, x): """ - te_dquantize lowering rules + gated_gelu lowering rules """ - in_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in + (x_aval,) = ctx.avals_in + assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + ir_x_type = ir.RankedTensorType(x.type) + ir_x_shape = ir_x_type.shape + out_shape = ir_x_shape[:-2] + [ir_x_shape[-1]] - assert in_aval.dtype == jnp.int8 - assert amax_aval.dtype == jnp.float32 - assert scale_aval.dtype == jnp.float32 - assert scale_inv_aval.dtype == jnp.float32 + out_types = [ + ir.RankedTensorType.get(out_shape, ir_x_type.element_type), + ] + operands = [x] + operand_shapes = [ir_x_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - ir_in_type = ir.RankedTensorType(inputs.type) - ir_in_shape = ir_in_type.shape + hidden_size = ir_x_shape[-1] + batch_size = reduce(operator.mul, ir_x_shape[:-2]) + in_dtype = jax_dtype_to_te_dtype(x_aval.dtype) + opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size), in_dtype, + in_dtype) - ir_out_dtype = te_dtype_to_ir_dtype(out_dtype) - ir_out_shape = ir_in_shape + out = custom_caller(GatedGeluPrimitive.name, args, opaque, False) - ir_amax_type = ir.RankedTensorType(amax.type) - ir_amax_shape = ir_amax_type.shape + return [out] - ir_scale_shape = ir_amax_shape - ir_scale_inv_shape = ir_amax_shape + @staticmethod + def impl(x): + assert GatedGeluPrimitive.inner_primitive is not None + out = GatedGeluPrimitive.inner_primitive.bind(x) + return out - out_types = [ir.RankedTensorType.get(ir_out_shape, ir_out_dtype)] - operands = [inputs, amax, scale, scale_inv] - operand_shapes = [ir_in_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + @staticmethod + def batcher(batched_args, batch_dims): + """ + gated_gelu batcher + """ + _check_valid_batch_dims(batch_dims) + assert GatedGeluPrimitive.outer_primitive is not None + inputs, = batched_args + inputs_bdim, = batch_dims - opaque = transformer_engine_jax.pack_common_descriptor(in_aval.shape, fp8_dtype, out_dtype) + out_bdims = inputs_bdim + return GatedGeluPrimitive.outer_primitive.bind(inputs), out_bdims - out = custom_caller(DequantizePrimitive.name, args, opaque, False) + @staticmethod + def infer_sharding_from_operands(mesh, arg_infos, result_infos): + """ + gated_gelu infer_sharding_from_operands + """ + del result_infos # Unused. + x_spec = get_padded_spec(arg_infos[0]) + out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1])) + return out_sharding - return [out] + @staticmethod + def partition(mesh, arg_infos, result_infos): + """ + gated_gelu partitioning + """ + del result_infos + x_spec = get_padded_spec(arg_infos[0]) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1])) + impl = GatedGeluPrimitive.impl + return mesh, impl, out_sharding, arg_shardings -_dequantize_p = register_primitive(DequantizePrimitive) +register_primitive(GatedGeluPrimitive) -def dequantize(inputs: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, - fp8_dtype: TEDType, out_dtype: TEDType) -> jnp.ndarray: +def gated_gelu(inputs: jnp.ndarray) -> jnp.ndarray: """ - dequantize wrapper - Return FP16/BF16/FP32 tensor + gated gelu wrapper + Return FP8(geglu(inputs)) + Assume inputs has two dimensions shape and the memory layout is (N, 2, H) """ - return _dequantize_p.bind(inputs, - amax, - scale, - scale_inv, - fp8_dtype=fp8_dtype, - out_dtype=out_dtype) + return GatedGeluPrimitive.outer_primitive.bind(inputs) -class SoftmaxPrimitive(BasePrimitive): +class DgatedGeluPrimitive(BasePrimitive): """ - Softmax Primitive + Dgated Gelu Primitive """ - max_k_seqlen_supported = 4096 - - @staticmethod - def get_batch_per_block(k_seqlen: int) -> int: - """Get batch per CTA in Softmax kernels""" - threads_per_warp = 32 - threads_per_block = 128 # Depends on the kernel implmentation - - pow2 = 1 << (k_seqlen - 1).bit_length() - warp_size = pow2 if pow2 < threads_per_warp else threads_per_warp - batches_per_warp = 2 if pow2 <= 128 else 1 - warps_per_block = threads_per_block // warp_size - batches_per_block = warps_per_block * batches_per_warp - return batches_per_block - - @staticmethod - def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int, - dtype: jnp.dtype) -> bool: - """Check Softmax kernel availability based on size""" - raise NotImplementedError + name = "te_dgated_gelu" + multiple_results = False + inner_primitive = None + outer_primitive = None + impl_static_args = () @staticmethod - def softmax_backward_abstract(grad_outputs, softmax_outputs, scale_factor=None): # pylint: disable=unused-argument + def abstract(dz_aval, x_aval): """ - MLIR abstract + dgated_gelu abstract """ - grad_outputs_dtype = dtypes.canonicalize_dtype(grad_outputs.dtype) - softmax_outputs_dtype = dtypes.canonicalize_dtype(softmax_outputs.dtype) - assert grad_outputs_dtype == softmax_outputs_dtype - assert grad_outputs_dtype in [jnp.float16, jnp.bfloat16] - assert softmax_outputs_dtype in [jnp.float16, jnp.bfloat16] + dtype = dtypes.canonicalize_dtype(dz_aval.dtype) + assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert x_aval.dtype == dtype + for axis in range(len(dz_aval.shape) - 1): + assert dz_aval.shape[axis] == x_aval.shape[axis] - assert grad_outputs.shape == softmax_outputs.shape + assert x_aval.shape[-2] == 2 # Assume x in (....., 2, hidden) - return ShapedArray(softmax_outputs.shape, - softmax_outputs_dtype, - named_shape=softmax_outputs.named_shape) + i_hidden_size = dz_aval.shape[-1] + g_hidden_size = x_aval.shape[-1] + assert i_hidden_size == g_hidden_size + out_aval = core.raise_to_shaped(x_aval) + return out_aval @staticmethod - def softmax_backward_lowering(name, ctx, grad_outputs, softmax_outputs, scale_factor): + def lowering(ctx, dz, x): """ - MLIR abstract + dgated_gelu lowering rules """ - grad_outputs_aval, _ = ctx.avals_in - - grad_outputs_type = ir.RankedTensorType(grad_outputs.type) - grad_outputs_shape = grad_outputs_type.shape - - batch = grad_outputs_shape[0] - pad_batch = batch # unused - heads = grad_outputs_shape[1] - q_seqlen = grad_outputs_shape[2] - k_seqlen = grad_outputs_shape[3] + in_aval, gi_aval = ctx.avals_in + assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert gi_aval.dtype == in_aval.dtype + ir_in_type = ir.RankedTensorType(dz.type) + ir_in_shape = ir_in_type.shape + gi_type = ir.RankedTensorType(x.type) + gi_shape = gi_type.shape + for axis in range(len(ir_in_shape) - 1): + assert ir_in_shape[axis] == gi_shape[axis] - softmax_outputs_type = ir.RankedTensorType(softmax_outputs.type) - softmax_outputs_shape = softmax_outputs_type.shape + ir_batch_size = reduce(operator.mul, ir_in_shape[:-1]) + i_hidden_size = ir_in_shape[-1] + g_hidden_size = gi_shape[-1] + assert i_hidden_size == g_hidden_size + out_dtype = ir_in_type.element_type + out_shape = gi_shape out_types = [ - ir.RankedTensorType.get(softmax_outputs_shape, softmax_outputs_type.element_type) + ir.RankedTensorType.get(out_shape, out_dtype), ] - operands = [grad_outputs, softmax_outputs] - operand_shapes = [grad_outputs_shape, softmax_outputs_shape] + operands = [dz, x] + operand_shapes = [ir_in_shape, gi_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - opaque = transformer_engine_jax.pack_softmax_descriptor( - batch, pad_batch, heads, q_seqlen, k_seqlen, - jax_dtype_to_te_dtype(grad_outputs_aval.dtype), scale_factor) + in_dtype = jax_dtype_to_te_dtype(in_aval.dtype) + opaque = transformer_engine_jax.pack_common_descriptor((ir_batch_size, i_hidden_size), + in_dtype, in_dtype) - out = custom_caller(name, args, opaque, False) + out = custom_caller(DgatedGeluPrimitive.name, args, opaque, False) return [out] - -class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive): - """ - Scaled Softmax Fwd Primitive - """ - name = "te_scaled_softmax_forward" - multiple_results = False - @staticmethod - def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int, - dtype: jnp.dtype) -> bool: - """Check Softmax kernel availability based on size""" - attn_batches = batch * heads - - dtype = dtypes.canonicalize_dtype(dtype) - if (dtype in [jnp.float16, jnp.bfloat16] - and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported - # k_seqlen must be 16 ~ 4096 - and q_seqlen % 4 == 0 # q_seqlen must be divisor of 4 - and attn_batches % 4 == 0 # batch * heads must be divisor of 4 - ): - if 0 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported: - batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen) - return q_seqlen % batch_per_block == 0 - return False + def impl(dz, x): + """ + dgated_gelu implementation + """ + assert DgatedGeluPrimitive.inner_primitive is not None + dx = DgatedGeluPrimitive.inner_primitive.bind(dz, x) + return dx @staticmethod - def abstract(inputs, *, scale_factor): # pylint: disable=unused-argument + def batcher(batched_args, batch_dims): """ - te_scaled_softmax_forward abstract + dgated_gelu batcher """ - shape_rank = 4 # batch, heads, q_seqlen and k_seqlen - - i_dtype = dtypes.canonicalize_dtype(inputs.dtype) - assert i_dtype in [jnp.float16, jnp.bfloat16] - i_shape = inputs.shape - assert len(i_shape) == shape_rank - q_seqlen = i_shape[2] - k_seqlen = i_shape[3] - assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported - assert q_seqlen > 1 + _check_valid_batch_dims(batch_dims) + assert DgatedGeluPrimitive.outer_primitive is not None + dz, x = batched_args + _, x_bdim = batch_dims - return ShapedArray(inputs.shape, i_dtype, named_shape=inputs.named_shape) + out_bdims = x_bdim + return DgatedGeluPrimitive.outer_primitive.bind(dz, x), out_bdims @staticmethod - def lowering(ctx, inputs, *, scale_factor): + def infer_sharding_from_operands(mesh, arg_infos, result_infos): """ - te_scaled_softmax_forward lowering rules + dgated_gelu infer_sharding_from_operands """ - shape_rank = 4 # batch, heads, q_seqlen and k_seqlen + del result_infos # Unused. + gelu_out_spec = get_padded_spec(arg_infos[1]) + dx_sharding = NamedSharding(mesh, PartitionSpec(*gelu_out_spec)) + return dx_sharding - i_aval, = ctx.avals_in - i_type = ir.RankedTensorType(inputs.type) - i_shape = i_type.shape - assert len(i_shape) == shape_rank - batch = i_shape[0] - pad_batch = batch - heads = i_shape[1] - q_seqlen = i_shape[2] - k_seqlen = i_shape[3] + @staticmethod + def partition(mesh, arg_infos, result_infos): + """ + dgated_gelu partition + """ + del result_infos + dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_shardings = dx_sharding + impl = DgatedGeluPrimitive.impl + return mesh, impl, out_shardings, arg_shardings - out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)] - operands = [inputs] - operand_shapes = [i_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - opaque = transformer_engine_jax.pack_softmax_descriptor(batch, pad_batch, heads, q_seqlen, - k_seqlen, - jax_dtype_to_te_dtype(i_aval.dtype), - scale_factor) +register_primitive(DgatedGeluPrimitive) - out = custom_caller(ScaledSoftmaxFwdPrimitive.name, args, opaque, False) - return [out] +def dgated_gelu(inputs: jnp.ndarray, gelu_inputs: jnp.ndarray) -> jnp.ndarray: + """ + dgated_gelu fusion wrapper + Return dgeglu(inputs) + """ + return DgatedGeluPrimitive.outer_primitive.bind(inputs, gelu_inputs) -_scaled_softmax_fwd_p = register_primitive(ScaledSoftmaxFwdPrimitive) +def _normalize_axis_boundary(axis, ndim): + return axis if axis >= 0 else ndim + axis -def scaled_softmax_fwd(inputs: jnp.ndarray, scale_factor: float) -> jnp.ndarray: +def _multidim_transpose(shape, static_axis_boundary, transpose_axis_boundary): """ - scaled_softmax_forward wrapper - Return FP16/BF16 tensor + te_cast_transpose_p multi-dims transpose + + static_axis_boundary: int, Indicate those axes <= static_axis_boundary would not be + involved into transpose, -1 means all axes involve into transpose. + transpose_axis_boundary: int, Indicate how to split multi-dimensions tensors to 2D matrix for + transpose. Note, transpose_axis_boundary should be greater than static_axis_boundary + + examples: + X in shape (dim0, dim1, dim2, dim3, dim4) + + static_axis_boundary == -1, transpose_axis_boundary == 2 + Xt = (dim2, dim3, dim4, dim0, dim1) + + static_axis_boundary == 0, transpose_axis_boundary == 2 + Xt = (dim0, dim2, dim3, dim4, dim1) + + static_axis_boundary == 0, transpose_axis_boundary == 3 + Xt = (dim0, dim3, dim4, dim1. dim2) """ - return _scaled_softmax_fwd_p.bind(inputs, scale_factor=scale_factor) + if static_axis_boundary < 0: + static_axis_boundary = -1 # means no static axes + assert static_axis_boundary < len(shape) - 2 # at least 2 remaining for transpose. + transpose_start_idx = static_axis_boundary + 1 + transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, len(shape)) + assert transpose_start_idx < transpose_axis_boundary + return (*shape[:transpose_start_idx], *shape[transpose_axis_boundary:], + *shape[transpose_start_idx:transpose_axis_boundary]) -class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive): +class CastTransposePrimitive(BasePrimitive): """ - Scaled Softmax Bwd Primitive + Cast Transpose Primitive """ - name = "te_scaled_softmax_backward" - multiple_results = False + name = "te_cast_transpose" + multiple_results = True + impl_static_args = (4, 5, 6) + inner_primitive = None + outer_primitive = None @staticmethod - def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int, - dtype: jnp.dtype) -> bool: - """Check Softmax kernel availability based on size""" - return ScaledSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen, - dtype) + def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, static_axis_boundary, + transpose_axis_boundary): + """ + te_cast_transpose_p abstract + """ + dtype = dtypes.canonicalize_dtype(x_aval.dtype) + assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 + + transposed_x_shape = _multidim_transpose(x_aval.shape, static_axis_boundary, + transpose_axis_boundary) + + casted_x_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) + casted_xt_aval = x_aval.update(shape=transposed_x_shape, dtype=out_dtype) + updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) + + return casted_x_aval, casted_xt_aval, updated_amax_aval @staticmethod - def abstract(grad_outputs, softmax_outputs, *, scale_factor): + def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, + transpose_axis_boundary): """ - te_scaled_softmax_backward abstract + te_cast_transpose_p lowering rules """ - return SoftmaxPrimitive.softmax_backward_abstract(grad_outputs, softmax_outputs, - scale_factor) + x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in + assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 + ir_x_type = ir.RankedTensorType(x.type) + ir_x_shape = ir_x_type.shape + if static_axis_boundary >= 0: + for i in range(static_axis_boundary + 1): + assert ir_x_shape[i] == 1 + ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) + ir_amax_type = ir.RankedTensorType(amax.type) + ir_amax_dtype = ir_amax_type.element_type + ir_amax_shape = ir_amax_type.shape + ir_scale_shape = ir_amax_shape + ir_scale_inv_shape = ir_amax_shape + + transposed_x_shape = _multidim_transpose(ir_x_shape, static_axis_boundary, + transpose_axis_boundary) + + out_types = [ + ir.RankedTensorType.get(ir_x_shape, ir_out_dtype), + ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype), + ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), + ] + operands = [x, amax, scale, scale_inv] + operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + contracted_x_shape = (reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]), + reduce(operator.mul, ir_x_shape[transpose_axis_boundary:])) + opaque = transformer_engine_jax.pack_common_descriptor(contracted_x_shape, + jax_dtype_to_te_dtype(x_aval.dtype), + jax_dtype_to_te_dtype(out_dtype)) + + out = custom_caller(CastTransposePrimitive.name, + args, + opaque, + False, + operand_output_aliases={1: 2}) + + return out @staticmethod - def lowering(ctx, grad_outputs, softmax_outputs, *, scale_factor): + def impl(x, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis_boundary): """ - te_scaled_softmax_backward lowering rules + te_cast_transpose implementation """ - out = SoftmaxPrimitive.softmax_backward_lowering(ScaledSoftmaxBwdPrimitive.name, ctx, - grad_outputs, softmax_outputs, - scale_factor) + assert CastTransposePrimitive.inner_primitive is not None + casted_x, casted_transposed_x, updated_amax = \ + CastTransposePrimitive.inner_primitive.bind( + x, amax, scale, scale_inv, out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary) + return casted_x, casted_transposed_x, updated_amax - return out # out is iterable already + @staticmethod + def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, + transpose_axis_boundary): + _check_valid_batch_dims(batch_dims) + assert CastTransposePrimitive.outer_primitive is not None + assert static_axis_boundary < 0 + x, amax, scale, scale_inv = batched_args + x_bdim, amax_bdim, *_ = batch_dims -_scaled_softmax_bwd_p = register_primitive(ScaledSoftmaxBwdPrimitive) + # Minus batch dim. + transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1) + transpose_axis_boundary += 1 # Plus batch dim + out_bdims = x_bdim, x_bdim, amax_bdim + return CastTransposePrimitive.outer_primitive.bind( + x, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=x_bdim, + transpose_axis_boundary=transpose_axis_boundary), out_bdims -def scaled_softmax_bwd(grad_outputs: jnp.ndarray, softmax_outputs: jnp.ndarray, - scale_factor: float) -> jnp.ndarray: + @staticmethod + def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, + arg_infos, result_infos): + del out_dtype, result_infos + x_spec = get_padded_spec(arg_infos[0]) + casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) + casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) + return (casted_x_sharding, casted_transposed_x_sharding, amax_sharding) + + @staticmethod + def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, + result_infos): + del result_infos + x_spec = get_padded_spec(arg_infos[0]) + casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) + casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding) + + def sharded_impl(x, amax, scale, scale_inv): + local_cx, local_cxt, local_updated_amax = \ + CastTransposePrimitive.impl(x, amax, scale, scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax) + + return local_cx, local_cxt, global_updated_amax + + return mesh, sharded_impl, out_shardings, arg_shardings + + +register_primitive(CastTransposePrimitive) + + +def cast_transpose(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, + out_dtype: jnp.dtype, static_axis_boundary: int, + transpose_axis_boundary: int) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """ - scaled_softmax_backward wrapper - Return FP16/BF16 tensor + cast transpose wrapper + Return two tensors, FP8(inputs) and FP8(inputs.T), which are scaled by `scale` """ - return _scaled_softmax_bwd_p.bind(grad_outputs, softmax_outputs, scale_factor=scale_factor) + return CastTransposePrimitive.outer_primitive.bind( + x, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary) -class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): +class TransposePrimitive(BasePrimitive): """ - Scaled Masked Softmax Fwd Primitive + Transpose Primitive """ - name = "te_scaled_masked_softmax_forward" + name = "te_transpose" multiple_results = False + impl_static_args = (1, 2) + inner_primitive = None + outer_primitive = None @staticmethod - def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int, - dtype: jnp.dtype) -> bool: - """Check Softmax kernel availability based on size""" - attn_batches = batch * heads - - dtype = dtypes.canonicalize_dtype(dtype) - if (dtype in [jnp.float16, jnp.bfloat16] - and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported - # k_seqlen must be 16 ~ 4096 - and q_seqlen % 4 == 0 # q_seqlen must be divisor of 4 - and attn_batches % 4 == 0 # batch * heads must be divisor of 4 - ): - if 0 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported: - batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen) - return q_seqlen % batch_per_block == 0 - return False - - @staticmethod - def abstract(inputs, mask, *, scale_factor): # pylint: disable=unused-argument + def abstract(x_aval, *, static_axis_boundary, transpose_axis_boundary): """ - te_scaled_masked_softmax_forward abstract + _transpose abstract """ - shape_rank = 4 # batch, heads, q_seqlen and k_seqlen - - i_dtype = dtypes.canonicalize_dtype(inputs.dtype) - assert i_dtype in [jnp.float16, jnp.bfloat16] - i_shape = inputs.shape - assert len(i_shape) == shape_rank - batch = i_shape[0] - q_seqlen = i_shape[2] - k_seqlen = i_shape[3] - assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported - assert q_seqlen > 1 - - mask_dtype = dtypes.canonicalize_dtype(mask.dtype) - assert mask_dtype in [ - jnp.uint8, - ] - mask_shape = mask.shape - assert len(mask_shape) == shape_rank - pad_batch = mask_shape[0] - assert pad_batch in (1, batch) # 1 means broadcast - assert mask_shape[1] == 1 # 1 means broadcast - assert mask_shape[2] == q_seqlen - assert mask_shape[3] == k_seqlen + transposed_x_shape = _multidim_transpose(x_aval.shape, static_axis_boundary, + transpose_axis_boundary) + xt_aval = x_aval.update(shape=transposed_x_shape, dtype=x_aval.dtype) - return ShapedArray(inputs.shape, i_dtype, named_shape=inputs.named_shape) + return xt_aval @staticmethod - def lowering(ctx, inputs, mask, *, scale_factor): + def lowering(ctx, x, *, static_axis_boundary, transpose_axis_boundary): """ - te_scaled_masked_softmax_forward lowering rules + _transpose cuda lowering """ - shape_rank = 4 # batch, heads, q_seqlen and k_seqlen - i_aval, _ = ctx.avals_in - i_type = ir.RankedTensorType(inputs.type) - i_shape = i_type.shape - assert len(i_shape) == shape_rank - batch = i_shape[0] - heads = i_shape[1] - q_seqlen = i_shape[2] - k_seqlen = i_shape[3] + x_aval = ctx.avals_in[0] + assert x_aval.dtype in [ + jnp.float32, jnp.float16, jnp.bfloat16, jnp.float8_e4m3fn, jnp.float8_e5m2 + ] - mask_type = ir.RankedTensorType(mask.type) - mask_shape = mask_type.shape - assert len(mask_shape) == shape_rank - pad_batch = mask_shape[0] + ir_x_type = ir.RankedTensorType(x.type) + ir_x_shape = ir_x_type.shape + ir_out_dtype = jax_dtype_to_ir_dtype(x_aval.dtype) + if static_axis_boundary >= 0: + for i in range(static_axis_boundary + 1): + assert ir_x_shape[i] == 1 - out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)] - operands = [inputs, mask] - operand_shapes = [i_shape, mask_shape] + transposed_x_shape = _multidim_transpose(ir_x_shape, static_axis_boundary, + transpose_axis_boundary) + + out_types = [ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype)] + operands = [x] + operand_shapes = [ir_x_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - opaque = transformer_engine_jax.pack_softmax_descriptor(batch, pad_batch, heads, q_seqlen, - k_seqlen, - jax_dtype_to_te_dtype(i_aval.dtype), - scale_factor) + te_dtype = jax_dtype_to_te_dtype(x_aval.dtype) + contracted_x_shape = (reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]), + reduce(operator.mul, ir_x_shape[transpose_axis_boundary:])) + opaque = transformer_engine_jax.pack_common_descriptor(contracted_x_shape, te_dtype, + te_dtype) - out = custom_caller(ScaledMaskedSoftmaxFwdPrimitive.name, args, opaque, False) + out = custom_caller(TransposePrimitive.name, args, opaque, False) return [out] + @staticmethod + def impl(x, static_axis_boundary, transpose_axis_boundary): + """ + tcast_transpose implementation + """ + assert TransposePrimitive.inner_primitive is not None + transposed_x = \ + TransposePrimitive.inner_primitive.bind(x, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary) + return transposed_x -_scaled_masked_softmax_fwd_p = register_primitive(ScaledMaskedSoftmaxFwdPrimitive) - + @staticmethod + def batcher(batched_args, batch_dims, *, static_axis_boundary, transpose_axis_boundary): + _check_valid_batch_dims(batch_dims) + assert TransposePrimitive.outer_primitive is not None + assert static_axis_boundary < 0 -def scaled_masked_softmax_fwd(inputs: jnp.ndarray, mask: jnp.ndarray, - scale_factor: float) -> jnp.ndarray: - """ - scaled_masked_softmax_forward wrapper - Return FP16/BF16 tensor - """ - return _scaled_masked_softmax_fwd_p.bind(inputs, mask, scale_factor=scale_factor) + x, = batched_args + x_bdim, = batch_dims + # Minus batch dim. + transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1) + transpose_axis_boundary += 1 # Plus batch dim -class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive): - """ - Scaled Masked Softmax Bwd Primitive - """ - name = "te_scaled_masked_softmax_backward" - multiple_results = False + out_bdims = x_bdim + return TransposePrimitive.outer_primitive.bind( + x, static_axis_boundary=x_bdim, + transpose_axis_boundary=transpose_axis_boundary), out_bdims @staticmethod - def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int, - dtype: jnp.dtype) -> bool: - """Check Softmax kernel availability based on size""" - return ScaledSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen, - dtype) + def infer_sharding_from_operands(static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, + result_infos): + del result_infos + x_spec = get_padded_spec(arg_infos[0]) + xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) + transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) + return transposed_x_sharding @staticmethod - def abstract(grad_outputs, softmax_outputs, *, scale_factor): - """ - te_scaled_masked_softmax_backward abstract - """ - return SoftmaxPrimitive.softmax_backward_abstract(grad_outputs, softmax_outputs, - scale_factor) + def partition(static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos): + del result_infos + x_spec = get_padded_spec(arg_infos[0]) + xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) + transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_shardings = transposed_x_sharding - @staticmethod - def lowering(ctx, grad_outputs, softmax_outputs, *, scale_factor): - """ - te_scaled_masked_softmax_backward lowering rules - """ - out = SoftmaxPrimitive.softmax_backward_lowering(ScaledMaskedSoftmaxBwdPrimitive.name, ctx, - grad_outputs, softmax_outputs, - scale_factor) + impl = partial(TransposePrimitive.impl, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary) - return out # out is iterable already + return mesh, impl, out_shardings, arg_shardings -_scaled_masked_softmax_bwd_p = register_primitive(ScaledMaskedSoftmaxBwdPrimitive) +register_primitive(TransposePrimitive) -def scaled_masked_softmax_bwd(grad_outputs: jnp.ndarray, softmax_outputs: jnp.ndarray, - scale_factor: float) -> jnp.ndarray: +def transpose(x: jnp.ndarray, static_axis_boundary: int, + transpose_axis_boundary: int) -> jnp.ndarray: """ - scaled_masked_softmax_backward wrapper - Return FP16/BF16 tensor + transpose wrapper """ - return _scaled_masked_softmax_bwd_p.bind(grad_outputs, - softmax_outputs, - scale_factor=scale_factor) + return TransposePrimitive.outer_primitive.bind(x, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary) -class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): +class LayerNormFwdFp8Primitive(BasePrimitive): """ - Scaled Upper Triang Masked Softmax Fwd Primitive + Layer Normalization Forward FP8 Primitive """ - name = "te_scaled_upper_triang_masked_softmax_forward" - multiple_results = False - - @staticmethod - def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int, - dtype: jnp.dtype) -> bool: - """Check Softmax kernel availability based on size""" - attn_batches = batch * heads - - dtype = dtypes.canonicalize_dtype(dtype) - if (dtype in [jnp.float16, jnp.bfloat16] - and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported - # k_seqlen must be 16 ~ 4096 - and q_seqlen % 4 == 0 # q_seqlen must be divisor of 4 - and attn_batches % 4 == 0 # batch * heads must be divisor of 4 - ): - if 0 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported: - batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen) - return attn_batches % batch_per_block == 0 - return False + name = "te_layernorm_forward_fp8" + multiple_results = True + impl_static_args = (6, 7, 8) # out_type, zero_centered_gamma, epsilon + inner_primitive = None + outer_primitive = None @staticmethod - def abstract(inputs, *, scale_factor): # pylint: disable=unused-argument + def abstract(x_aval, gamma_aval, beta_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, + zero_centered_gamma, epsilon): """ - te_scaled_upper_triang_masked_softmax_forward abstract + LayerNorm fwd (fp8 out) abstract """ - shape_rank = 4 # batch, heads, q_seqlen and k_seqlen + del zero_centered_gamma, epsilon + x_dtype = dtypes.canonicalize_dtype(x_aval.dtype) - i_dtype = dtypes.canonicalize_dtype(inputs.dtype) - assert i_dtype in [jnp.float16, jnp.bfloat16] - i_shape = inputs.shape - assert len(i_shape) == shape_rank - q_seqlen = i_shape[2] - k_seqlen = i_shape[3] - assert q_seqlen == k_seqlen - assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported - assert q_seqlen > 1 + assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 - return ShapedArray(inputs.shape, i_dtype, named_shape=inputs.named_shape) + mu_rsigama_dtype = jnp.float32 + + assert gamma_aval.size == beta_aval.size + + out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) + mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype) + updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) + + return out_aval, mu_aval, rsigma_aval, updated_amax_aval @staticmethod - def lowering(ctx, inputs, *, scale_factor): + def lowering(ctx, x, gamma, beta, amax, scale, scale_inv, *, out_dtype, zero_centered_gamma, + epsilon): """ - te_scaled_upper_triang_masked_softmax_forward lowering rules + LayerNorm fwd (fp8 out) lowering rules """ - shape_rank = 4 # batch, heads, q_seqlen and k_seqlen - - i_aval, = ctx.avals_in - i_type = ir.RankedTensorType(inputs.type) - i_shape = i_type.shape - assert len(i_shape) == shape_rank - batch = i_shape[0] - pad_batch = batch - heads = i_shape[1] - q_seqlen = i_shape[2] - k_seqlen = i_shape[3] + x_aval, gamma_aval, beta_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in - out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)] - operands = [inputs] - operand_shapes = [i_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + # Currently only support casting to E4M3 only in C side. + assert out_dtype == jnp.float8_e4m3fn - opaque = transformer_engine_jax.pack_softmax_descriptor(batch, pad_batch, heads, q_seqlen, - k_seqlen, - jax_dtype_to_te_dtype(i_aval.dtype), - scale_factor) + assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert gamma_aval.dtype == beta_aval.dtype + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 - out = custom_caller(ScaledUpperTriangMaskedSoftmaxFwdPrimitive.name, args, opaque, False) + x_type = ir.RankedTensorType(x.type) + x_shape = x_type.shape + g_type = ir.RankedTensorType(gamma.type) + g_shape = g_type.shape + b_type = ir.RankedTensorType(beta.type) + b_shape = b_type.shape - return [out] + assert g_type == b_type + assert g_shape == b_shape -_scaled_upper_triang_masked_softmax_fwd_p = \ - register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive) + ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) + ir_mu_dtype = ir.F32Type.get() + ir_rsigma_dtype = ir.F32Type.get() + ir_amax_type = ir.RankedTensorType(amax.type) + ir_amax_dtype = ir_amax_type.element_type + ir_amax_shape = ir_amax_type.shape + ir_scale_shape = ir_amax_shape + ir_scale_inv_shape = ir_amax_shape + out_shape = x_shape + hidden_size = reduce(operator.mul, g_shape) + batch_shape = out_shape[:-1] + batch_size = reduce(operator.mul, x_shape) // hidden_size -def scaled_upper_triang_masked_softmax_fwd(inputs: jnp.ndarray, scale_factor: float) -> jnp.ndarray: - """ - scaled_upper_triang_masked_softmax_forward wrapper - Return FP16/BF16 tensor - """ - return _scaled_upper_triang_masked_softmax_fwd_p.bind(inputs, scale_factor=scale_factor) + out_types = [ + ir.RankedTensorType.get(out_shape, ir_out_dtype), + ir.RankedTensorType.get(batch_shape, ir_mu_dtype), + ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype), + ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), + ] + operands = [x, gamma, beta, amax, scale, scale_inv] + operand_shapes = [ + x_shape, g_shape, b_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape + ] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + opaque = transformer_engine_jax.pack_norm_descriptor( + batch_size, + hidden_size, + jax_dtype_to_te_dtype(x_aval.dtype), + jax_dtype_to_te_dtype(gamma_aval.dtype), + zero_centered_gamma, + epsilon, + ) -class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive): - """ - Scaled Upper Triang Masked Softmax Bwd Primitive - """ - name = "te_scaled_upper_triang_masked_softmax_backward" - multiple_results = False + out = custom_caller(LayerNormFwdFp8Primitive.name, + args, + opaque, + False, + operand_output_aliases={3: 3}) - @staticmethod - def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int, - dtype: jnp.dtype) -> bool: - """Check Softmax kernel availability based on size""" - return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available( - batch, heads, q_seqlen, k_seqlen, dtype) + return out @staticmethod - def abstract(grad_outputs, softmax_outputs, *, scale_factor): + def impl(x, gamma, beta, amax, scale, scale_inv, out_dtype, zero_centered_gamma, epsilon): """ - te_scaled_upper_triang_masked_softmax_backward abstract + to describe implementation """ - return SoftmaxPrimitive.softmax_backward_abstract(grad_outputs, softmax_outputs, - scale_factor) + assert LayerNormFwdFp8Primitive.inner_primitive is not None + out, mu, rsigma, updated_amax = LayerNormFwdFp8Primitive.inner_primitive.bind( + x, + gamma, + beta, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + zero_centered_gamma=zero_centered_gamma, + epsilon=epsilon) + return out, mu, rsigma, updated_amax @staticmethod - def lowering(ctx, grad_outputs, softmax_outputs, *, scale_factor): + def batcher(batched_args, batch_dims, *, out_dtype, zero_centered_gamma, epsilon): """ - te_scaled_upper_triang_masked_softmax_backward lowering rules + to describe batch rules for vmap """ - out = SoftmaxPrimitive.softmax_backward_lowering( - ScaledUpperTriangMaskedSoftmaxBwdPrimitive.name, ctx, grad_outputs, softmax_outputs, - scale_factor) + _check_valid_batch_dims(batch_dims) + assert LayerNormFwdFp8Primitive.outer_primitive is not None + x, gamma, beta, amax, scale, scale_inv = batched_args + x_bdim, _, _, amax_bdim, _, _ = batch_dims + + out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim + return LayerNormFwdFp8Primitive.outer_primitive.bind( + x, + gamma, + beta, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + zero_centered_gamma=zero_centered_gamma, + epsilon=epsilon), out_bdims + + @staticmethod + def infer_sharding_from_operands(out_dtype, zero_centered_gamma, epsilon, mesh, arg_infos, + result_infos): + del out_dtype, zero_centered_gamma, epsilon, result_infos + x_spec = get_padded_spec(arg_infos[0]) + if x_spec[-1] is not None: + warnings.warn( + f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! " \ + f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ + f"and hurt performance.") + + out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) + mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1])) + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[3]))) + return (out_sharding, mu_sharding, rsigma_sharding, amax_sharding) + + @staticmethod + def partition(out_dtype, zero_centered_gamma, epsilon, mesh, arg_infos, result_infos): + del result_infos + x_spec = get_padded_spec(arg_infos[0]) + if x_spec[-1] is not None: + warnings.warn( + f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! " \ + f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ + f"and hurt performance." + ) + x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) + g_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) + b_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) + out_sharding = x_sharding + mu_sharding = rsigma_sharding = NamedSharding( + mesh, PartitionSpec(*get_padded_spec(arg_infos[0])[:-1])) + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[3]))) + fp8_meta_sharding = amax_sharding + arg_shardings = (x_sharding, g_sharding, b_sharding) + (fp8_meta_sharding,) * 3 + out_shardings = (out_sharding, mu_sharding, rsigma_sharding, amax_sharding) - return out # out is iterable already + def sharded_impl(x, gamma, beta, amax, scale, scale_inv): + local_x, local_mu, local_rsigma, local_amax = \ + LayerNormFwdFp8Primitive.impl(x, gamma, beta, amax, scale, scale_inv, + out_dtype=out_dtype, + zero_centered_gamma=zero_centered_gamma, + epsilon=epsilon) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) -_scaled_upper_triang_masked_softmax_bwd_p = \ - register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive) + return local_x, local_mu, local_rsigma, global_updated_amax + return mesh, sharded_impl, out_shardings, arg_shardings -def scaled_upper_triang_masked_softmax_bwd(grad_outputs: jnp.ndarray, softmax_outputs: jnp.ndarray, - scale_factor: float) -> jnp.ndarray: + +register_primitive(LayerNormFwdFp8Primitive) + + +def layernorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, amax: jnp.ndarray, + scale: jnp.ndarray, scale_inv: jnp.ndarray, out_dtype: jnp.dtype, + zero_centered_gamma: bool, epsilon: float): """ - scaled_upper_triang_masked_softmax_backward wrapper - Return FP16/BF16 tensor + Wrapper for TE layernorm fwd (fp8 out) """ - return _scaled_upper_triang_masked_softmax_bwd_p.bind(grad_outputs, - softmax_outputs, - scale_factor=scale_factor) + return LayerNormFwdFp8Primitive.outer_primitive.bind(x, + gamma, + beta, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + zero_centered_gamma=zero_centered_gamma, + epsilon=epsilon) -@dataclass(frozen=True) -class _FusedAttnRNGStateChecker: +class RmsNormFwdFp8Primitive(BasePrimitive): """ - Checker for guarding the fused attention rng state. - The fused attention backend requires a 64 bits seed and a 64 bits offset. - However, JAX doesn't enable 64 bits by default, - so we have to emulate seed as two 32 bits array. - The offset calculation is maintained in the backend. + RMS Normalization Forward FP8 Primitive """ - rng_state_dtype: jnp.dtype = jnp.uint32 - # (seed,) with internal dtype int64 - seed_size: int = 2 - # (seed, offset) with internal dtype int64 - rng_state_size: int = 2 * 2 + name = "te_rmsnorm_forward_fp8" + multiple_results = True + impl_static_args = (5, 6) # out_dtype, epsilon + inner_primitive = None + outer_primitive = None - def check_seed(self, seed, dropout_probability, is_training): + @staticmethod + def abstract(x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval, out_dtype, epsilon): """ - Check the seed and convert the data type of seed if possible. + RMSNorm fwd (fp8 out) abstract """ - # Jax can't bind None, create a dummy tensor for None - if seed is None: - dropout_enabled = dropout_probability > 0 and is_training - assert not dropout_enabled, "seed is not allowed to be None when dropout is enabled." - seed = jnp.zeros(2, dtype=self.rng_state_dtype) + del epsilon + x_dtype = dtypes.canonicalize_dtype(x_aval.dtype) - if seed.dtype != self.rng_state_dtype: - warnings.warn( - f"Requested {seed.dtype=} is not available, and will be " - f"casted to dtype {self.rng_state_dtype}. " - f"Please use threefry/rbg/unsafe_rbg PRNG implementations to remove this warning.") - seed = seed.astype(self.rng_state_dtype) + assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 - assert seed.dtype == self.rng_state_dtype - # Backend takes an int64_t seed, so only the first two u32 elements are taken - assert seed.size >= self.seed_size + hidden_size = gamma_aval.size + assert x_aval.size % hidden_size == 0 - return seed + rsigama_dtype = jnp.float32 + out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) + rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigama_dtype) + amax_aval = out_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) -class SelfFusedAttnFwdPrimitive(BasePrimitive): - """ - Self Fused Attention Forward Primitive - """ - name = "te_self_fused_attn_forward" - multiple_results = True + return out_aval, rsigma_aval, amax_aval @staticmethod - def abstract( - qkv, - bias, - cu_seqlen, # pylint: disable=unused-argument - seed, # pylint: disable=unused-argument - *, - attn_bias_type, # pylint: disable=unused-argument - attn_mask_type, # pylint: disable=unused-argument - scaling_factor, # pylint: disable=unused-argument - dropout_probability, # pylint: disable=unused-argument - is_training # pylint: disable=unused-argument - ): + def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon): """ - Self fused attention fwd abstract + RMSNorm fwd (fp8 out) lowering rules """ - qkv_dtype = dtypes.canonicalize_dtype(qkv.dtype) - batch, max_seqlen, nqkv, num_head, head_dim = qkv.shape - assert nqkv == 3 - assert qkv.dtype == bias.dtype - output_shape = (batch, max_seqlen, num_head, head_dim) - output_dtype = qkv_dtype + # Currently only support casting to E4M3 only in C side. + assert out_dtype == jnp.float8_e4m3fn - backend = FusedAttnHelper(qkv_dtype, qkv_dtype, NVTE_QKV_Layout.NVTE_BS3HD, attn_bias_type, - attn_mask_type, dropout_probability, max_seqlen, max_seqlen, - head_dim).get_fused_attn_backend() + x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in - if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen: - softmax_aux_shape = (batch, num_head, max_seqlen, max_seqlen) - softmax_dtype = qkv_dtype - elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: - softmax_aux_shape = (batch, num_head, max_seqlen, 1) - softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) - else: - raise ValueError(f'Not supported {backend=}') + assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 - checker = _FusedAttnRNGStateChecker() - seed_dtype = dtypes.canonicalize_dtype(seed.dtype) - assert seed_dtype == checker.rng_state_dtype - rng_state_shape = (checker.rng_state_size,) - rng_state_dtype = seed_dtype + x_type = ir.RankedTensorType(x.type) + x_shape = x_type.shape + g_type = ir.RankedTensorType(gamma.type) + g_shape = g_type.shape + + ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) + ir_rsigma_dtype = ir.F32Type.get() + ir_amax_type = ir.RankedTensorType(amax.type) + ir_amax_dtype = ir_amax_type.element_type + ir_amax_shape = ir_amax_type.shape + ir_scale_shape = ir_amax_shape + ir_scale_inv_shape = ir_amax_shape + + out_shape = x_shape + hidden_size = reduce(operator.mul, g_shape) + batch_shape = out_shape[:-1] + batch_size = reduce(operator.mul, x_shape) // hidden_size - return ( - ShapedArray(output_shape, output_dtype, named_shape=qkv.named_shape), # output - ShapedArray(softmax_aux_shape, softmax_dtype, - named_shape=qkv.named_shape), # softmax_aux - ShapedArray(rng_state_shape, rng_state_dtype, - named_shape=seed.named_shape), # rng_state + out_types = [ + ir.RankedTensorType.get(out_shape, ir_out_dtype), + ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype), + ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), + ] + operands = [x, gamma, amax, scale, scale_inv] + operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + opaque = transformer_engine_jax.pack_norm_descriptor( + batch_size, + hidden_size, + jax_dtype_to_te_dtype(x_aval.dtype), + jax_dtype_to_te_dtype(gamma_aval.dtype), + False, # RMSNorm doesn't support zero_centered_gamma + epsilon, ) + out = custom_caller(RmsNormFwdFp8Primitive.name, + args, + opaque, + False, + operand_output_aliases={2: 2}) + + return out + @staticmethod - def lowering(ctx, qkv, bias, cu_seqlen, seed, *, attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training): + def impl(x, gamma, amax, scale, scale_inv, out_dtype, epsilon): """ - Self fused attention fwd lowering rules + to describe implementation """ - qkv_aval, _, _, _ = ctx.avals_in + assert RmsNormFwdFp8Primitive.inner_primitive is not None + out, rsigma, amax = RmsNormFwdFp8Primitive.inner_primitive.bind(x, + gamma, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + epsilon=epsilon) + return out, rsigma, amax - batch, max_seqlen, _, num_head, head_dim = qkv_aval.shape + @staticmethod + def batcher(batched_args, batch_dims, *, out_dtype, epsilon): + """ + to describe batch rules for vmap + """ + _check_valid_batch_dims(batch_dims) + assert RmsNormFwdFp8Primitive.outer_primitive is not None + x, gamma, amax, scale, scale_inv = batched_args + x_bdim, _, amax_bdim, _, _ = batch_dims + out_bdims = x_bdim, x_bdim, amax_bdim + return RmsNormFwdFp8Primitive.outer_primitive.bind(x, + gamma, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + epsilon=epsilon), out_bdims - operands = [qkv, bias, cu_seqlen, seed] - operand_shapes = map(lambda x: x.type.shape, operands) - out_types = [ - ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) - for output in ctx.avals_out - ] + @staticmethod + def infer_sharding_from_operands(out_dtype, epsilon, mesh, arg_infos, result_infos): + del out_dtype, epsilon, result_infos + x_spec = get_padded_spec(arg_infos[0]) + if x_spec[-1] is not None: + warnings.warn( + f"Does not support to shard hidden dim in {RmsNormFwdFp8Primitive.name}! " \ + f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ + f"and hurt performance." + ) + out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) + rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1])) + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) + return (out_sharding, rsigma_sharding, amax_sharding) - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - opaque = transformer_engine_jax.pack_fused_attn_descriptor( - batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability, - attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training) + @staticmethod + def partition(out_dtype, epsilon, mesh, arg_infos, result_infos): + del result_infos + x_spec = get_padded_spec(arg_infos[0]) + if x_spec[-1] is not None: + warnings.warn( + f"Does not support to shard hidden dim in {RmsNormFwdFp8Primitive.name}! " \ + f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ + f"and hurt performance." + ) + x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) + g_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) + out_sharding = x_sharding + rsigma_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])[:-1])) + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) + fp8_meta_sharding = amax_sharding + arg_shardings = (x_sharding, g_sharding) + (fp8_meta_sharding,) * 3 + out_shardings = (out_sharding, rsigma_sharding, amax_sharding) - out = custom_caller(SelfFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) + def sharded_impl(x, gamma, amax, scale, scale_inv): + local_x, local_rsigma, local_amax= \ + RmsNormFwdFp8Primitive.impl(x, gamma, amax, scale, scale_inv, + out_dtype=out_dtype, epsilon=epsilon) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) - return out + return local_x, local_rsigma, global_updated_amax + return mesh, sharded_impl, out_shardings, arg_shardings -_self_fused_attn_fwd_p = register_primitive(SelfFusedAttnFwdPrimitive) +register_primitive(RmsNormFwdFp8Primitive) -def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, cu_seqlen: jnp.ndarray, - seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type, - attn_mask_type: NVTE_Mask_Type, scaling_factor: float, - dropout_probability: float, is_training: bool): + +def rmsnorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, + scale_inv: jnp.ndarray, out_dtype: jnp.dtype, epsilon: float): """ - Wrapper for TE self fused attention fwd - Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 + Wrapper for TE rmsnorm fwd (fp8 out) """ - checker = _FusedAttnRNGStateChecker() - seed = checker.check_seed(seed, dropout_probability, is_training) - - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: - assert bias is None - bias = jnp.zeros(0, dtype=qkv.dtype) - return _self_fused_attn_fwd_p.bind(qkv, - bias, - cu_seqlen, - seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) + return RmsNormFwdFp8Primitive.outer_primitive.bind(x, + gamma, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + epsilon=epsilon) -class SelfFusedAttnBwdPrimitive(BasePrimitive): +class GatedGeluFp8Primitive(BasePrimitive): """ - Self Fused Attention Backward Primitive + Gated Gelu FP8 Primitive """ - name = "te_self_fused_attn_backward" + name = "te_gated_gelu_fp8" multiple_results = True + impl_static_args = (4,) #out_dtype + inner_primitive = None + outer_primitive = None @staticmethod - def abstract( - qkv, - softmax_aux, # pylint: disable=unused-argument - rng_state, # pylint: disable=unused-argument - output, # pylint: disable=unused-argument - doutput, - cu_seqlen, # pylint: disable=unused-argument - *, - attn_bias_type, # pylint: disable=unused-argument - attn_mask_type, # pylint: disable=unused-argument - scaling_factor, # pylint: disable=unused-argument - dropout_probability, # pylint: disable=unused-argument - is_training # pylint: disable=unused-argument - ): + def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype): """ - Self fused attention bwd abstract + te_gated_gelu_p abstract """ - qkv_dtype = dtypes.canonicalize_dtype(qkv.dtype) - assert qkv.dtype == doutput.dtype - - _, seqlen, _, num_head, _ = qkv.shape + dtype = dtypes.canonicalize_dtype(x_aval.dtype) + # Currently only support casting to E4M3 only in C side. + assert out_dtype == jnp.float8_e4m3fn + assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: - bias_shape = (0,) - else: - bias_shape = (1, num_head, seqlen, seqlen) - bias_dtype = qkv_dtype + assert x_aval.shape[-2] == 2 # Assume x in (....., 2, hidden) + hidden_size = x_aval.shape[-1] + batch_shape = x_aval.shape[:-2] + out_shape = (batch_shape) + (hidden_size,) + out_aval = x_aval.update(shape=out_shape, dtype=out_dtype) + updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) - return ( - ShapedArray(qkv.shape, qkv_dtype, named_shape=qkv.named_shape), # dqkv - ShapedArray(bias_shape, bias_dtype, named_shape=qkv.named_shape)) + return out_aval, updated_amax_aval @staticmethod - def lowering(ctx, qkv, softmax_aux, rng_state, output, doutput, cu_seqlen, *, attn_bias_type, - attn_mask_type, scaling_factor, dropout_probability, is_training): + def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype): """ - Self fused attention bwd lowering rules + te_gated_gelu_p lowering rules """ - qkv_aval, _, _, _, _, _ = ctx.avals_in - - batch, max_seqlen, _, num_head, head_dim = qkv_aval.shape + x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in + assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 + ir_x_type = ir.RankedTensorType(x.type) + ir_x_shape = ir_x_type.shape + ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) + ir_amax_type = ir.RankedTensorType(amax.type) + ir_amax_dtype = ir_amax_type.element_type + ir_amax_shape = ir_amax_type.shape + ir_scale_shape = ir_amax_shape + ir_scale_inv_shape = ir_amax_shape - operands = [qkv, softmax_aux, rng_state, output, doutput, cu_seqlen] - operand_shapes = map(lambda x: x.type.shape, operands) + hidden_size = ir_x_shape[-1] + batch_shape = ir_x_shape[:-2] + batch_size = reduce(operator.mul, batch_shape) + out_shape = batch_shape + [hidden_size] out_types = [ - ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) - for output in ctx.avals_out + ir.RankedTensorType.get(out_shape, ir_out_dtype), + ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), ] - + operands = [x, amax, scale, scale_inv] + operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - opaque = transformer_engine_jax.pack_fused_attn_descriptor( - batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability, - attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training) + opaque = transformer_engine_jax.pack_common_descriptor((batch_size, out_shape[-1]), + jax_dtype_to_te_dtype(x_aval.dtype), + jax_dtype_to_te_dtype(out_dtype)) - out = custom_caller(SelfFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) + out = custom_caller(GatedGeluFp8Primitive.name, + args, + opaque, + False, + operand_output_aliases={1: 1}) return out - -_self_fused_attn_bwd_p = register_primitive(SelfFusedAttnBwdPrimitive) - - -def self_fused_attn_bwd(qkv: jnp.ndarray, softmax_aux: jnp.ndarray, rng_state: jnp.ndarray, - output: jnp.ndarray, doutput: jnp.ndarray, cu_seqlen: jnp.ndarray, - attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type, - scaling_factor: float, dropout_probability: float, is_training: bool): - """ - Wrapper for TE self fused attention bwd - Return the gradients of self fused attention with packed qkv input - """ - return _self_fused_attn_bwd_p.bind(qkv, - softmax_aux, - rng_state, - output, - doutput, - cu_seqlen, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - - -class CrossFusedAttnFwdPrimitive(BasePrimitive): - """ - Cross Fused Attention Forward Primitive - """ - name = "te_cross_fused_attn_forward" - multiple_results = True - @staticmethod - def abstract( - q, - kv, - q_cu_seqlen, - kv_cu_seqlen, - seed, # pylint: disable=unused-argument - *, - attn_bias_type, # pylint: disable=unused-argument - attn_mask_type, # pylint: disable=unused-argument - scaling_factor, # pylint: disable=unused-argument - dropout_probability, # pylint: disable=unused-argument - is_training # pylint: disable=unused-argument - ): + def impl(x, amax, scale, scale_inv, out_dtype): """ - Cross fused attention fwd abstract + to describe implementation """ - q_dtype = dtypes.canonicalize_dtype(q.dtype) - batch_q, q_max_seqlen, num_head_q, head_dim_q = q.shape - kv_dtype = dtypes.canonicalize_dtype(kv.dtype) - batch_kv, kv_max_seqlen, nkv, num_head_kv, head_dim_kv = kv.shape - - assert q_dtype == kv_dtype - assert batch_q == batch_kv - assert num_head_q == num_head_kv - assert head_dim_q == head_dim_kv - assert nkv == 2 - assert q_cu_seqlen.dtype == kv_cu_seqlen.dtype - - output_shape = q.shape - output_dtype = q_dtype - softmax_aux_shape = (batch_q, num_head_q, q_max_seqlen, kv_max_seqlen) - softmax_aux_dtype = q_dtype - - return ( - ShapedArray(output_shape, output_dtype, named_shape=q.named_shape), # output - ShapedArray(softmax_aux_shape, softmax_aux_dtype, - named_shape=q.named_shape), # softmax_aux - ) + assert GatedGeluFp8Primitive.inner_primitive is not None + out, updated_amax = GatedGeluFp8Primitive.inner_primitive.bind(x, + amax, + scale, + scale_inv, + out_dtype=out_dtype) + return out, updated_amax @staticmethod - def lowering(ctx, q, kv, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type, attn_mask_type, - scaling_factor, dropout_probability, is_training): + def batcher(batched_args, batch_dims, *, out_dtype): """ - Cross fused attention fwd lowering rules + to describe batch rules for vmap """ - q_aval, kv_aval, _, _, _ = ctx.avals_in - assert q_aval.dtype == kv_aval.dtype + _check_valid_batch_dims(batch_dims) + assert GatedGeluFp8Primitive.outer_primitive is not None + x, amax, scale, scale_inv = batched_args + x_bdim, amax_bdim, _, _ = batch_dims - batch, q_max_seqlen, num_head, head_dim = q_aval.shape - kv_max_seqlen = kv_aval.shape[1] + out_bdims = x_bdim, amax_bdim + return GatedGeluFp8Primitive.outer_primitive.bind(x, + amax, + scale, + scale_inv, + out_dtype=out_dtype), out_bdims - operands = [q, kv, q_cu_seqlen, kv_cu_seqlen, seed] - operand_shapes = map(lambda x: x.type.shape, operands) - out_types = [ - ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) - for output in ctx.avals_out - ] + @staticmethod + def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos): + del out_dtype, result_infos + x_spec = get_padded_spec(arg_infos[0]) + out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1])) + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) + return (out_sharding, amax_sharding) - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - opaque = transformer_engine_jax.pack_fused_attn_descriptor( - batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, - scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(q_aval.dtype), is_training) + @staticmethod + def partition(out_dtype, mesh, arg_infos, result_infos): + del result_infos + x_spec = get_padded_spec(arg_infos[0]) + out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1])) + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_shardings = (out_sharding, amax_sharding) - out = custom_caller(CrossFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) + def sharded_impl(x, amax, scale, scale_inv): + local_x, local_amax = GatedGeluFp8Primitive.impl(x, + amax, + scale, + scale_inv, + out_dtype=out_dtype) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) - return out + return local_x, global_updated_amax + return mesh, sharded_impl, out_shardings, arg_shardings -_cross_fused_attn_fwd_p = register_primitive(CrossFusedAttnFwdPrimitive) +register_primitive(GatedGeluFp8Primitive) -def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, q_cu_seqlen: jnp.ndarray, - kv_cu_seqlen: jnp.ndarray, seed: jnp.ndarray, - attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type, - scaling_factor: float, dropout_probability: float, is_training: bool): + +def gated_gelu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, + out_dtype: jnp.dtype) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """ - Wrapper for TE cross fused attention fwd - Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 + gated gelu wrapper + Return FP8(geglu(x)) """ - checker = _FusedAttnRNGStateChecker() - seed = checker.check_seed(seed, dropout_probability, is_training) - - return _cross_fused_attn_fwd_p.bind(q, - kv, - q_cu_seqlen, - kv_cu_seqlen, - seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) + return GatedGeluFp8Primitive.outer_primitive.bind(x, + amax, + scale, + scale_inv, + out_dtype=out_dtype) -class CrossFusedAttnBwdPrimitive(BasePrimitive): +class DgatedGeluCastTransposePrimitive(BasePrimitive): """ - Cross Fused Attention Backward Primitive + Dgated Gelu Cast Transpose Primitive """ - name = "te_cross_fused_attn_backward" + name = "te_dgated_gelu_cast_transpose" multiple_results = True + impl_static_args = (5, 6) # out_dtype, static_axis_boundary + inner_primitive = None + outer_primitive = None @staticmethod - def abstract( - q, - kv, - softmax_aux, - doutput, - q_cu_seqlen, - kv_cu_seqlen, - *, - attn_bias_type, # pylint: disable=unused-argument - attn_mask_type, # pylint: disable=unused-argument - scaling_factor, # pylint: disable=unused-argument - dropout_probability, # pylint: disable=unused-argument - is_training # pylint: disable=unused-argument - ): + def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, + static_axis_boundary): """ - Cross fused attention bwd abstract + te_dgated_gelu_cast_transpose_p abstract """ - q_dtype = dtypes.canonicalize_dtype(q.dtype) - kv_dtype = dtypes.canonicalize_dtype(kv.dtype) - softmax_aux_dtype = dtypes.canonicalize_dtype(softmax_aux.dtype) - doutput_dtype = dtypes.canonicalize_dtype(doutput.dtype) - assert q_dtype == kv_dtype == softmax_aux_dtype == doutput_dtype - assert q_cu_seqlen.dtype == kv_cu_seqlen.dtype + dtype = dtypes.canonicalize_dtype(dz_aval.dtype) + assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert x_aval.dtype == dtype + assert x_aval.shape[-2] == 2 # Linear + GeLU + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 + ir_hidden_szie = dz_aval.shape[-1] + gi_hidden_size = x_aval.shape[-1] + assert ir_hidden_szie == gi_hidden_size + t_shape = _multidim_transpose(x_aval.shape, static_axis_boundary, -2) + out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype) + t_out = dz_aval.update(shape=t_shape, dtype=out_dtype) + updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) + return out, t_out, updated_amax_aval - return ( - ShapedArray(q.shape, q_dtype, named_shape=q.named_shape), # dq - ShapedArray(kv.shape, kv_dtype, named_shape=kv.named_shape), # dkv - ) + @staticmethod + def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary): + """ + te_dgated_gelu_cast_transpose_p lowering rules + """ + dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in + assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert x_aval.dtype == dz_aval.dtype + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 + ir_dz_type = ir.RankedTensorType(dz.type) + ir_dz_shape = ir_dz_type.shape + x_type = ir.RankedTensorType(x.type) + x_shape = x_type.shape + dz_batch_szie = reduce(operator.mul, ir_dz_shape[:-1]) + x_batch_size = reduce(operator.mul, x_shape[:-2]) + assert dz_batch_szie == x_batch_size + assert x_shape[-2] == 2 # Linear + GeLU + ir_hidden_szie = ir_dz_shape[-1] + gi_hidden_size = x_shape[-1] + assert ir_hidden_szie == gi_hidden_size + ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) + ir_amax_type = ir.RankedTensorType(amax.type) + ir_amax_dtype = ir_amax_type.element_type + ir_amax_shape = ir_amax_type.shape + ir_scale_shape = ir_amax_shape + ir_scale_inv_shape = ir_amax_shape + transposed_x_shape = _multidim_transpose(x_shape, static_axis_boundary, -2) + out_types = [ + ir.RankedTensorType.get(x_shape, ir_out_dtype), + ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype), + ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), + ] + operands = [dz, x, amax, scale, scale_inv] + operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + contracted_x_shape = (x_batch_size, x_shape[-1]) + opaque = transformer_engine_jax.pack_common_descriptor(contracted_x_shape, + jax_dtype_to_te_dtype(dz_aval.dtype), + jax_dtype_to_te_dtype(out_dtype)) + + out = custom_caller(DgatedGeluCastTransposePrimitive.name, + args, + opaque, + False, + operand_output_aliases={2: 2}) + + return out @staticmethod - def lowering(ctx, q, kv, softmax_aux, doutput, q_cu_seqlen, kv_cu_seqlen, *, attn_bias_type, - attn_mask_type, scaling_factor, dropout_probability, is_training): + def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary): """ - Cross fused attention bwd lowering rules + to describe implementation """ - q_aval, kv_aval, _, _, _, _ = ctx.avals_in - assert q_aval.dtype == kv_aval.dtype + assert DgatedGeluCastTransposePrimitive.inner_primitive is not None + out, t_out, updated_amax = DgatedGeluCastTransposePrimitive.inner_primitive.bind( + dz, + x, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary) + return out, t_out, updated_amax - batch, q_max_seqlen, num_head, head_dim = q_aval.shape - kv_max_seqlen = kv_aval.shape[1] + @staticmethod + def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary): + """ + to describe batch rules for vmap + """ + del static_axis_boundary + _check_valid_batch_dims(batch_dims) + assert DgatedGeluCastTransposePrimitive.outer_primitive is not None + dz, x, amax, scale, scale_inv = batched_args + x_bdim, _, amax_bdim, _, _ = batch_dims - operands = [q, kv, softmax_aux, doutput, q_cu_seqlen, kv_cu_seqlen] - operand_shapes = map(lambda x: x.type.shape, operands) - out_types = [ - ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) - for output in ctx.avals_out - ] + out_bdims = x_bdim, x_bdim, amax_bdim + return DgatedGeluCastTransposePrimitive.outer_primitive.bind( + dz, x, amax, scale, scale_inv, out_dtype=out_dtype, + static_axis_boundary=x_bdim), out_bdims - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + @staticmethod + def infer_sharding_from_operands(out_dtype, static_axis_boundary, mesh, arg_infos, + result_infos): + del out_dtype, result_infos + x_spec = get_padded_spec(arg_infos[1]) + out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + xt_spec = _multidim_transpose(x_spec, static_axis_boundary, -2) + tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) + return (out_sharding, tranposed_out_sharding, amax_sharding) - # the dropout elements are encoded in the forward auxiliary tensor - # so seed is not needed in backward - opaque = transformer_engine_jax.pack_fused_attn_descriptor( - batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, - scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(q_aval.dtype), is_training) + @staticmethod + def partition(out_dtype, static_axis_boundary, mesh, arg_infos, result_infos): + del result_infos + x_spec = get_padded_spec(arg_infos[1]) + casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + xt_spec = _multidim_transpose(x_spec, static_axis_boundary, -2) + casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) - out = custom_caller(CrossFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding) - return out + def sharded_impl(dz, x, amax, scale, scale_inv): + local_out, local_t_out, local_amax = DgatedGeluCastTransposePrimitive.impl( + dz, + x, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) + return local_out, local_t_out, global_updated_amax + return mesh, sharded_impl, out_shardings, arg_shardings -_cross_fused_attn_bwd_p = register_primitive(CrossFusedAttnBwdPrimitive) +register_primitive(DgatedGeluCastTransposePrimitive) -def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, softmax_aux: jnp.ndarray, - doutput: jnp.ndarray, q_cu_seqlen: jnp.ndarray, kv_cu_seqlen: jnp.ndarray, - attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type, - scaling_factor: float, dropout_probability: float, is_training: bool): + +def dgated_gelu_cast_transpose( + dz: jnp.ndarray, x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, + scale_inv: jnp.ndarray, out_dtype: TEDType, + static_axis_boundary: int) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """ - Wrapper for TE cross fused attention bwd - Return the gradients of cross fused attention with packed kv input + cast transpose d_gated_gelu fusion wrapper + Return FP8(dgeglu(inputs)) """ - return _cross_fused_attn_bwd_p.bind(q, - kv, - softmax_aux, - doutput, - q_cu_seqlen, - kv_cu_seqlen, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) + return DgatedGeluCastTransposePrimitive.outer_primitive.bind( + dz, + x, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary) diff --git a/transformer_engine/jax/dot.py b/transformer_engine/jax/dot.py index 153732dc55..520b2df67b 100644 --- a/transformer_engine/jax/dot.py +++ b/transformer_engine/jax/dot.py @@ -4,227 +4,167 @@ """JAX te modules""" from typing import Tuple, Sequence -from functools import partial, reduce -import operator +from functools import partial import jax import jax.numpy as jnp -from transformer_engine_jax import DType as TEDType -from .cpp_extensions import cast_transpose, gemm, jax_dtype_to_te_dtype -from .fp8 import FP8Helper, FP8GemmPackage -from .sharding import ShardingType, get_dot_sharding_meta, get_fp8_meta_sharding_meta -from .sharding import is_dp_enabled, is_tp_enabled, merge_axis_resources -from .sharding import xmap_runner, extend_fsdp_sharding_meta +from .cpp_extensions import cast_transpose +from .fp8 import FP8Helper, FP8MetaPackage -jax.config.update('experimental_xmap_spmd_lowering', True) -jax.config.update('experimental_xmap_spmd_lowering_manual', True) + +def type_safe_dot_general( + x, + kernel, + fp8_meta_pkg: FP8MetaPackage = None, + contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)) +) -> jnp.ndarray: + """ + Type safe dot_general, including FP8. + """ + + if fp8_meta_pkg is None: + kernel = jnp.asarray(kernel, x.dtype) + return jax.lax.dot_general(x, kernel, (contracting_dims, ((), ()))) + + fp8_max = fp8_meta_pkg.fp8_max + amax = fp8_meta_pkg.amax + scale = fp8_meta_pkg.scale + scale_inv = fp8_meta_pkg.scale_inv + fwd_dtype = FP8Helper.FWD_DTYPE + bwd_dtype = FP8Helper.BWD_DTYPE + return _fp8_dot(x, kernel, fp8_max, amax, scale, scale_inv, fwd_dtype, bwd_dtype, + contracting_dims) + + +def quantize(x, q_dtype, scale): + """ + Quantize with scale. + """ + dtype_max = (jnp.finfo(q_dtype).max).astype(x.dtype) + scale = scale.astype(x.dtype) + clipped_scaled_x = jnp.clip((x * scale), -dtype_max, dtype_max) + return clipped_scaled_x.astype(q_dtype) -def fp8_dot(fp8_gemm_pkg: FP8GemmPackage, - fwd_dtype: TEDType, - bwd_dtype: TEDType, - contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), - sharding_type: ShardingType = ShardingType.SINGLE, - dp_dim_index: int = 0) -> jnp.ndarray: +def dequantize(x, dq_dtype, scale_inv): + """ + Dequantize with scale_inv. + """ + return x.astype(dq_dtype) * scale_inv.astype(dq_dtype) + + +# Apply jit to guarantee correctness of FP8 GEMM. +@partial(jax.jit, static_argnums=(4, 5)) +def fp8_dot_impl( + q_lhs: jnp.ndarray, + q_rhs: jnp.ndarray, + lhs_scale_inv: jnp.ndarray, + rhs_scale_inv: jnp.ndarray, + ctype: jnp.dtype, # computing type + contracting_dims: Tuple[Sequence[int], Sequence[int]]): """ - FP8 dot wrapper + FP8 GEMM for XLA pattern match """ - assert fp8_gemm_pkg.num_of_gemm == 1 - inputs = fp8_gemm_pkg.inputs - kernel = fp8_gemm_pkg.kernels[0] - fp8_max = fp8_gemm_pkg.fp8_max - amax = fp8_gemm_pkg.amax - scale = fp8_gemm_pkg.scale - scale_inv = fp8_gemm_pkg.scale_inv - - if sharding_type is ShardingType.SINGLE: - res = _fp8_dot(inputs, - kernel, - fp8_max, - amax, - scale, - scale_inv, - fwd_dtype=fwd_dtype, - bwd_dtype=bwd_dtype, - contracting_dims=contracting_dims, - sharding_type=sharding_type, - dp_axis_name="", - tp_axis_name="", - fsdp_axis_name="") - else: - dp_axis_name = "batch" - tp_axis_name = "model" - kernel_tp_index = None - # TODO (Ming Huang): Should we add a new argument to support general sharding to kernel? # pylint: disable=fixme - if sharding_type in (ShardingType.TP_COL, ShardingType.DP_TP_COL): - kernel_tp_index = len(kernel.shape) - 1 - elif sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW): - kernel_tp_index = 0 - - input_tp_index = len(inputs.shape) - 1 - sharding_meta = get_dot_sharding_meta(sharding_type, inputs.shape, kernel.shape, - dp_dim_index, input_tp_index, kernel_tp_index, - contracting_dims, dp_axis_name, tp_axis_name) - sharding_meta, fsdp_axis_name = extend_fsdp_sharding_meta(sharding_meta, {0: dp_dim_index}) - inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0]) # 0 for input - kernel_ = jnp.reshape(kernel, sharding_meta.input_shapes[1]) # 1 for kernel - - num_of_fp8_meta_kind = 4 # fp8_max, amax, scale, scale_inv - fp8_sharding_meta = get_fp8_meta_sharding_meta(sharding_type, num_of_fp8_meta_kind, - dp_axis_name, tp_axis_name) - - axis_resources = merge_axis_resources( - [sharding_meta.axis_resources, fp8_sharding_meta.axis_resources]) - - partial_fp8_dot = partial(_fp8_dot, - fwd_dtype=fwd_dtype, - bwd_dtype=bwd_dtype, - contracting_dims=contracting_dims, - sharding_type=sharding_type, - dp_axis_name=dp_axis_name, - tp_axis_name=tp_axis_name, - fsdp_axis_name=fsdp_axis_name) - res = xmap_runner(partial_fp8_dot, (*sharding_meta.in_axes, *fp8_sharding_meta.in_axes), - sharding_meta.out_axes, axis_resources, - (inputs_, kernel_, fp8_max, amax, scale, scale_inv)) - - res = jnp.reshape(res, sharding_meta.output_shapes[0]) - - return res - - -@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11, 12)) -def _fp8_dot(inputs: jnp.ndarray, kernel: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray, - scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: TEDType, bwd_dtype: TEDType, - contracting_dims: Tuple[Sequence[int], Sequence[int]], sharding_type: ShardingType, - dp_axis_name: str, tp_axis_name: str, fsdp_axis_name: str): - res, _ = _fp8_dot_fwd(inputs, - kernel, - fp8_maxs, - amax, - scale, - scale_inv, - fwd_dtype, - bwd_dtype, - contracting_dims=contracting_dims, - sharding_type=sharding_type, - dp_axis_name=dp_axis_name, - tp_axis_name=tp_axis_name, - fsdp_axis_name=fsdp_axis_name) - return res - - -def _fp8_dot_fwd( - inputs, + dim_nums = (contracting_dims, ((), ())) + + lhs = dequantize(q_lhs, ctype, lhs_scale_inv) + rhs = dequantize(q_rhs, ctype, rhs_scale_inv) + + return jax.lax.dot_general(lhs, rhs, dim_nums) + + +@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8)) +def _fp8_dot(x: jnp.ndarray, kernel: jnp.ndarray, fp8_max: jnp.ndarray, amax: jnp.ndarray, + scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype, + contracting_dims: Tuple[Sequence[int], Sequence[int]]): + output, _ = _fp8_dot_fwd_rule(x, kernel, fp8_max, amax, scale, scale_inv, fwd_dtype, bwd_dtype, + contracting_dims) + return output + + +def _fp8_dot_fwd_rule( + x, kernel, - fp8_maxs, + fp8_max, amax, scale, scale_inv, fwd_dtype, bwd_dtype, # pylint: disable=unused-argument - contracting_dims, - sharding_type, - dp_axis_name, # pylint: disable=unused-argument - tp_axis_name, - fsdp_axis_name): # pylint: disable=unused-argument + contracting_dims): lhs_contracting_dims, rhs_contracting_dims = contracting_dims - input_shape_pre = inputs.shape[:min(lhs_contracting_dims)] - input_shape_suf = inputs.shape[min(lhs_contracting_dims):] + + x_shape_suf = x.shape[min(lhs_contracting_dims):] kernel_shape_pre = kernel.shape[:max(rhs_contracting_dims) + 1] - kernel_shape_suf = kernel.shape[max(rhs_contracting_dims) + 1:] - input_contracting_size = reduce(operator.mul, input_shape_suf) - kernel_contracting_size = reduce(operator.mul, kernel_shape_pre) - assert input_contracting_size == kernel_contracting_size - inputs_ = jnp.reshape(inputs, (-1, input_contracting_size)) - kernel_ = jnp.reshape(kernel, (kernel_contracting_size, -1)) + assert x_shape_suf == kernel_shape_pre amax = FP8Helper.update_amax_history(amax) - gemm_input_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0) + gemm_x_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0) - input_amax = amax[gemm_input_idx, 0:1] - input_scale = scale[gemm_input_idx] - input_scale_inv = scale_inv[gemm_input_idx] - input_cast, input_cast_trans, input_amax = cast_transpose(inputs_, input_amax, input_scale, - input_scale_inv, fwd_dtype) + x_amax = amax[gemm_x_idx, 0:1] + x_scale = scale[gemm_x_idx] + x_scale_inv = scale_inv[gemm_x_idx] + + casted_x, casted_xt, updated_x_amax = \ + cast_transpose(x, x_amax, x_scale, x_scale_inv, fwd_dtype, static_axis_boundary=-1, + transpose_axis_boundary=min(lhs_contracting_dims)) kernel_amax = amax[gemm_kernel_idx, 0:1] kernel_scale = scale[gemm_kernel_idx] kernel_scale_inv = scale_inv[gemm_kernel_idx] - kernel_cast, kernel_cast_trans, kernel_amax = cast_transpose(kernel_, kernel_amax, kernel_scale, - kernel_scale_inv, fwd_dtype) - res = gemm(kernel_cast_trans, kernel_scale_inv, fwd_dtype, True, input_cast, input_scale_inv, - fwd_dtype, False, jax_dtype_to_te_dtype(inputs.dtype), FP8Helper.FP8_2X_ACC_FPROP) - if sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW): - res = jax.lax.psum(res, tp_axis_name) + casted_kerenl, casted_kerenl_t, updated_kernel_amax = \ + cast_transpose(kernel, kernel_amax, kernel_scale, kernel_scale_inv, + fwd_dtype, static_axis_boundary=-1, + transpose_axis_boundary=(max(rhs_contracting_dims) + 1)) - # (input_shape_pre, input_shape_suf) - # x (kernel_shape_pre, kernel_shape_suf) - # = (input_shape_pre, kernel_shape_suf) - output_shape = input_shape_pre + kernel_shape_suf - res = jnp.reshape(res, output_shape) + rhs_t_contracting_dims = tuple(range(kernel.ndim - len(rhs_contracting_dims), kernel.ndim)) + output = fp8_dot_impl(casted_x, casted_kerenl_t, x_scale_inv, kernel_scale_inv, x.dtype, + (lhs_contracting_dims, rhs_t_contracting_dims)) - ctx = (input_cast_trans, kernel_cast, fp8_maxs, amax, scale, scale_inv, input_amax, kernel_amax, - inputs.shape, kernel.shape) - return res, ctx + ctx = (casted_xt, casted_kerenl, fp8_max, amax, scale, scale_inv, updated_x_amax, + updated_kernel_amax, x.shape, kernel.shape) + return output, ctx -def _fp8_dot_bwd( - fwd_dtype, - bwd_dtype, - contracting_dims, # pylint: disable=unused-argument - sharding_type, - dp_axis_name, - tp_axis_name, - fsdp_axis_name, - ctx, - g): - input_cast_trans, kernel_cast, \ - fp8_maxs, amax, scale, scale_inv, \ - input_amax, kernel_amax, \ - inputs_shape, kernel_shape = ctx - - gemm_input_idx, gemm_kernel_idx, gemm_grad_idx = FP8Helper.get_fp8_meta_indices(0) +def _fp8_dot_bwd_rule(fwd_dtype, bwd_dtype, contracting_dims, ctx, grad): # pylint: disable=unused-argument + lhs_contracting_dims, rhs_contracting_dims = contracting_dims + + casted_xt, casted_kerenl, fp8_max, amax, scale, scale_inv, \ + updated_x_amax, updated_kernel_amax, x_shape, kernel_shape = ctx + + gemm_x_idx, gemm_kernel_idx, gemm_grad_idx = FP8Helper.get_fp8_meta_indices(0) grad_amax = amax[gemm_grad_idx, 0:1] grad_scale = scale[gemm_grad_idx] grad_scale_inv = scale_inv[gemm_grad_idx] - g = jnp.reshape(g, (input_cast_trans.shape[1], -1)) - grad_cast, grad_cast_trans, grad_amax = cast_transpose(g, grad_amax, grad_scale, grad_scale_inv, - bwd_dtype) - - input_scale_inv = scale_inv[gemm_input_idx] - wgrad = gemm(grad_cast_trans, grad_scale_inv, bwd_dtype, - True, input_cast_trans, input_scale_inv, fwd_dtype, False, - jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_WGRAD) - kernel_scale_inv = scale_inv[gemm_kernel_idx] - dgrad = gemm(kernel_cast, kernel_scale_inv, fwd_dtype, True, grad_cast, grad_scale_inv, - bwd_dtype, False, jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_DGRAD) - - amax = amax.at[gemm_input_idx, 0].set(input_amax[0]) - amax = amax.at[gemm_kernel_idx, 0].set(kernel_amax[0]) - amax = amax.at[gemm_grad_idx, 0].set(grad_amax[0]) + casted_grad, casted_grad_t, updated_grad_amax = \ + cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, + bwd_dtype, static_axis_boundary=-1, + transpose_axis_boundary=min(lhs_contracting_dims)) - if is_dp_enabled(sharding_type.value[0]): - wgrad = jax.lax.psum(wgrad, dp_axis_name) - amax = jax.lax.pmax(amax, dp_axis_name) + xt_constracting_dim = tuple(range(len(lhs_contracting_dims), len(x_shape))) + gt_constracting_dim = tuple(range(grad.ndim - len(xt_constracting_dim), grad.ndim)) + x_scale_inv = scale_inv[gemm_x_idx] + wgrad = fp8_dot_impl(casted_xt, casted_grad_t, x_scale_inv, grad_scale_inv, grad.dtype, + (xt_constracting_dim, gt_constracting_dim)) - if len(fsdp_axis_name) > 0: - wgrad = jax.lax.psum(wgrad, fsdp_axis_name) - amax = jax.lax.pmax(amax, fsdp_axis_name) + g_constracting_dim = tuple( + range(grad.ndim - len(kernel_shape) + len(rhs_contracting_dims), grad.ndim)) + k_constracting_dim = tuple(range(len(rhs_contracting_dims), len(kernel_shape))) + kernel_scale_inv = scale_inv[gemm_kernel_idx] + dgrad = fp8_dot_impl(casted_grad, casted_kerenl, grad_scale_inv, kernel_scale_inv, grad.dtype, + (g_constracting_dim, k_constracting_dim)) - if is_tp_enabled(sharding_type.value[0]): - amax = jax.lax.pmax(amax, tp_axis_name) + amax = amax.at[gemm_x_idx, 0].set(updated_x_amax[0]) + amax = amax.at[gemm_kernel_idx, 0].set(updated_kernel_amax[0]) + amax = amax.at[gemm_grad_idx, 0].set(updated_grad_amax[0]) - if sharding_type in (ShardingType.TP_COL, ShardingType.DP_TP_COL): - dgrad = jax.lax.psum(dgrad, tp_axis_name) + scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale) - dgrad = jnp.reshape(dgrad, inputs_shape) - wgrad = jnp.reshape(wgrad, kernel_shape) - return dgrad, wgrad, fp8_maxs, amax, scale, scale_inv + return dgrad, wgrad, fp8_max, amax, scale, scale_inv -_fp8_dot.defvjp(_fp8_dot_fwd, _fp8_dot_bwd) +_fp8_dot.defvjp(_fp8_dot_fwd_rule, _fp8_dot_bwd_rule) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 6ef520a8c2..7d80be5878 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -6,6 +6,7 @@ """ import functools import operator +import warnings from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union import jax.numpy as jnp @@ -16,14 +17,12 @@ from jax import nn as jax_nn from jax import random as jax_random -from ..dot import fp8_dot -from ..fp8 import FP8GemmPackage, FP8Helper +from ..dot import type_safe_dot_general +from ..fp8 import FP8Helper, FP8MetaPackage from ..layernorm import canonicalize_layernorm_type from ..layernorm import layernorm, layernorm_fp8_dot -from ..mlp import fp8_ln_mlp, geglu -from ..sharding import infer_sharding_type +from ..mlp import layernrom_geglu_fp8_mlp, geglu from ..softmax import is_softmax_kernel_available -from ..sharding import MajorShardingType, ShardingType from ..softmax import softmax, SoftmaxType PRNGKey = Any @@ -119,16 +118,10 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods Scalar for the input to softmax. softmax_type : SoftmaxType, default = SoftmaxType.SCALED Indicate the type of softmax. - - Optimization parameters - ----------------------- - sharding_type : ShardingType, default = ShardingType.SINGLE - Indicate the sharding pattern. """ scale_factor: float = 1.0 softmax_type: SoftmaxType = SoftmaxType.SCALED - sharding_type: ShardingType = ShardingType.SINGLE @nn.compact def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray: @@ -149,8 +142,7 @@ def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp if self.softmax_type is not SoftmaxType.SCALED_MASKED: mask_ = None - outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type, - self.sharding_type) + outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type) else: attention_bias = None if mask is not None: @@ -168,8 +160,7 @@ def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp # and kernel is unavailable, then try on pure scaled softmax custom calls. if is_softmax_kernel_available(SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, dtype): - outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED, - self.sharding_type) + outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED) else: outputs = jax_nn.softmax(logits * self.scale_factor) @@ -242,8 +233,6 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods Indicate whether the input tensors were switched axis of batch and sequence length dimension. If set to True, the input tensors should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden). - sharding_type : ShardingType, default = ShardingType.SINGLE - Indicate the sharding pattern. """ epsilon: float = 1e-6 layernorm_type: str = 'layernorm' @@ -254,7 +243,7 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods bias_axes: Tuple[str, ...] = ('embed',) dtype: DType = jnp.float32 transpose_batch_sequence: bool = False - sharding_type: ShardingType = ShardingType.SINGLE + sharding_type = None def __post_init__(self): self.scale_init = _obtain_default_layernorm_scale_init_if_need( @@ -276,6 +265,8 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: outputs : jax.numpy.ndarray Output tensors. """ + warnings.warn("sharding_type of LayerNorm would be removed in the near feature", + DeprecationWarning) features = x.shape[-1] scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,), @@ -286,9 +277,7 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: ln_bias, layernorm_type=self.layernorm_type, zero_centered_gamma=self.zero_centered_gamma, - epsilon=self.epsilon, - sharding_type=self.sharding_type, - dp_dim_index=1 if self.transpose_batch_sequence else 0) + epsilon=self.epsilon) class TransformerEngineBase(nn.Module): @@ -329,17 +318,15 @@ def get_fp8_metas(num_of_gemm: int) -> List[jnp.ndarray]: return fp8_max.value, fp8_metas_amax.value, fp8_metas_scale.value, fp8_metas_scale_inv.value @staticmethod - def get_fp8_gemm_package(num_of_gemm: int, inputs: jnp.ndarray, - kernels: List[jnp.ndarray]) -> FP8GemmPackage: + def get_fp8_meta_package(num_of_gemm: int) -> FP8MetaPackage: """ Get the FP8 metas """ - assert num_of_gemm == len(kernels) fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv = \ TransformerEngineBase.get_fp8_metas(num_of_gemm) - return FP8GemmPackage(num_of_gemm, inputs, kernels, fp8_max, fp8_metas_amax, - fp8_metas_scale, fp8_metas_scale_inv) + return FP8MetaPackage(num_of_gemm, fp8_max, fp8_metas_amax, fp8_metas_scale, + fp8_metas_scale_inv) class DenseGeneral(TransformerEngineBase): @@ -376,8 +363,6 @@ class DenseGeneral(TransformerEngineBase): Indicate whether the input tensors were switched axis of batch and sequence length dimension. If set to True, the input tensors should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden). - sharding_type : ShardingType, default = ShardingType.SINGLE - Indicate the sharding pattern. """ features: Union[Iterable[int], int] @@ -389,7 +374,7 @@ class DenseGeneral(TransformerEngineBase): axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 transpose_batch_sequence: bool = False - sharding_type: ShardingType = ShardingType.SINGLE + sharding_type = None def __post_init__(self): if self.kernel_init is None: @@ -411,6 +396,9 @@ def __call__(self, inputs: Array) -> Array: outputs : jax.numpy.ndarray Output tensors. """ + warnings.warn("sharding_type of DenseGeneral would be removed in the near feature", + DeprecationWarning) + features = _canonicalize_tuple(self.features) axis = _canonicalize_tuple(self.axis) @@ -438,18 +426,15 @@ def __call__(self, inputs: Array) -> Array: bias = None contract_ind = tuple(range(0, len(axis))) - + fp8_gemm_pkg = None if FP8Helper.is_fp8_enabled(): - fp8_gemm_package = \ - TransformerEngineBase.get_fp8_gemm_package(1, inputs, [kernel]) - y = fp8_dot(fp8_gemm_package, - FP8Helper.FWD_DTYPE, - FP8Helper.BWD_DTYPE, (axis, contract_ind), - sharding_type=self.sharding_type, - dp_dim_index=1 if self.transpose_batch_sequence else 0) - else: - kernel = jnp.asarray(kernel, self.dtype) - y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ()))) + fp8_gemm_pkg = \ + TransformerEngineBase.get_fp8_meta_package(1) + + y = type_safe_dot_general(inputs, + kernel, + fp8_meta_pkg=fp8_gemm_pkg, + contracting_dims=(axis, contract_ind)) if bias is not None: bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape @@ -528,8 +513,6 @@ class LayerNormDenseGeneral(TransformerEngineBase): depth_scaling: float, default = None The factor to scale the output from `DenseGeneral`. It should be a float value or None. When None is set, then no scaling is applied. - sharding_type : ShardingType, default = ShardingType.SINGLE - Indicate the sharding pattern. """ features: Union[Iterable[int], int] @@ -551,7 +534,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): dtype: DType = jnp.float32 transpose_batch_sequence: bool = True depth_scaling: float = None - sharding_type: ShardingType = ShardingType.SINGLE + sharding_type = None def __post_init__(self): if self.kernel_init is None: @@ -578,12 +561,16 @@ def __call__(self, inputs: Array) -> Array: The output tensors of layer normalization. If :attr:`return_layernorm_output=False`, then this would be None. """ + warnings.warn("sharding_type of LayerNormDenseGeneral would be removed in the near feature", + DeprecationWarning) + ln_output = None fuse_layernorm = FP8Helper.is_fp8_enabled( ) and not self.return_layernorm_output and self.enable_layernorm if self.enable_layernorm: + assert self.axis == -1 # Only support axis = =-1 at this moment features = inputs.shape[-1] scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,), @@ -597,9 +584,7 @@ def __call__(self, inputs: Array) -> Array: ln_bias, layernorm_type=self.layernorm_type, zero_centered_gamma=self.zero_centered_gamma, - epsilon=self.epsilon, - sharding_type=self.sharding_type, - dp_dim_index=1 if self.transpose_batch_sequence else 0) + epsilon=self.epsilon) else: assert not self.return_layernorm_output y = inputs @@ -627,30 +612,25 @@ def __call__(self, inputs: Array) -> Array: contract_ind = tuple(range(0, len(axis))) + fp8_meta_package = None if FP8Helper.is_fp8_enabled(): - fp8_gemm_package = \ - TransformerEngineBase.get_fp8_gemm_package(1, y, [kernel]) - - if not fuse_layernorm: - z = fp8_dot(fp8_gemm_package, - FP8Helper.FWD_DTYPE, - FP8Helper.BWD_DTYPE, (axis, contract_ind), - sharding_type=self.sharding_type, - dp_dim_index=1 if self.transpose_batch_sequence else 0) - else: - z = layernorm_fp8_dot(fp8_gemm_package, - scale, - ln_bias, - self.layernorm_type, - FP8Helper.FWD_DTYPE, - FP8Helper.BWD_DTYPE, (axis, contract_ind), - zero_centered_gamma=self.zero_centered_gamma, - epsilon=self.epsilon, - sharding_type=self.sharding_type, - dp_dim_index=1 if self.transpose_batch_sequence else 0) + fp8_meta_package = \ + TransformerEngineBase.get_fp8_meta_package(1) + + if fuse_layernorm: + z = layernorm_fp8_dot(y, + kernel, + scale, + ln_bias, + fp8_meta_package, + self.layernorm_type, + zero_centered_gamma=self.zero_centered_gamma, + epsilon=self.epsilon) else: - kernel = jnp.asarray(kernel, self.dtype) - z = lax.dot_general(y, kernel, ((axis, contract_ind), ((), ()))) + z = type_safe_dot_general(y, + kernel, + fp8_meta_pkg=fp8_meta_package, + contracting_dims=(axis, contract_ind)) bias = None if self.use_bias: @@ -758,8 +738,6 @@ class LayerNormMLP(TransformerEngineBase): Indicate whether the input tensors were switched axis of batch and sequence length dimension. If set to True, the input tensors should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden). - major_sharding_type : MajorShardingType, default = MajorShardingType.SINGLE - Indicate the sharding pattern. """ intermediate_dim: int = 2048 @@ -776,10 +754,7 @@ class LayerNormMLP(TransformerEngineBase): kernel_axes_2: Tuple[str, ...] = ('mlp', 'embed') use_bias: bool = False bias_init: Initializer = nn.initializers.zeros - bias_axes_1: Tuple[str, ...] = ( - 'act', - 'mlp', - ) + bias_axes_1: Tuple[str, ...] = ('act', 'mlp') bias_axes_2: Tuple[str, ...] = ('embed',) return_layernorm_output: bool = True activations: Sequence[Union[str, Callable]] = ('relu',) @@ -789,7 +764,7 @@ class LayerNormMLP(TransformerEngineBase): axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 transpose_batch_sequence: bool = True - major_sharding_type: MajorShardingType = MajorShardingType.SINGLE + major_sharding_type = None def __post_init__(self): if self.kernel_init is None: @@ -818,19 +793,32 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: The output tensors of layer normalization. If :attr:`return_layernorm_output=False`, then this would be None. """ + warnings.warn("major_sharding_type of LayerNormMLP would be removed in the near feature", + DeprecationWarning) + ln_output = None fuse_layernorm = FP8Helper.is_fp8_enabled( ) and not self.return_layernorm_output and self.enable_layernorm + def is_geglu(acts): + geglu_act_pool = [('gelu', 'linear'), ('linear', 'gelu')] + + normalize_acts = [] + for act in acts: + if not isinstance(act, str): + return False + normalize_acts.append(act.lower()) + return normalize_acts in geglu_act_pool + use_fused_ln_mlp = fuse_layernorm \ - and (not self.use_bias) and self.activations == ('gelu', 'linear') \ + and (not self.use_bias) and is_geglu(self.activations) \ and (self.intermediate_dropout_rate < 1e-3) - first_sharding_type, second_sharding_type = infer_sharding_type(self.major_sharding_type) - # LayerNorm if self.enable_layernorm: + assert self.axis == -1 # Only support axis == -1 at this moment + features = inputs.shape[-1] scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,), @@ -844,9 +832,7 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: ln_bias, layernorm_type=self.layernorm_type, zero_centered_gamma=self.zero_centered_gamma, - epsilon=self.epsilon, - sharding_type=first_sharding_type, - dp_dim_index=1 if self.transpose_batch_sequence else 0) + epsilon=self.epsilon) else: assert not self.return_layernorm_output y = inputs @@ -864,107 +850,67 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): return jnp.stack(kernels, axis=stack_axis, dtype=jnp.float32) num_of_gemm = 2 - if use_fused_ln_mlp: - num_activations = len(self.activations) - axis = _canonicalize_tuple(self.axis) - axis = _normalize_axes(axis, inputs.ndim) - - intermediate_dim = _canonicalize_tuple((num_activations, self.intermediate_dim)) - kernel_1_shape = tuple(inputs.shape[ax] for ax in axis) + intermediate_dim - kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim) - kernel_1 = nn_partitioning.param_with_axes('wi_kernel', - kernel_1_init, - num_activations, - -2, - kernel_1_each_shape, - jnp.float32, - axes=self.kernel_axes_1) - kernel_1 = jnp.reshape(kernel_1, kernel_1_shape) - hidden_size = inputs.shape[-1] - hidden_size_tuple = _canonicalize_tuple(hidden_size) - kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple - kernel_2_param_shape = (self.intermediate_dim, np.prod(hidden_size_tuple)) - kernel_2 = nn_partitioning.param_with_axes('wo_kernel', - self.kernel_init, - kernel_2_param_shape, - jnp.float32, - axes=self.kernel_axes_2) - kernel_2 = jnp.reshape(kernel_2, kernel_2_shape) - contract_ind = tuple(range(0, len(axis))) - - fp8_gemm_package = \ - TransformerEngineBase.get_fp8_gemm_package(num_of_gemm, y, [kernel_1, kernel_2]) - out = fp8_ln_mlp(fp8_gemm_package, - scale, - ln_bias, - self.layernorm_type, - FP8Helper.FWD_DTYPE, - FP8Helper.BWD_DTYPE, - zero_centered_gamma=self.zero_centered_gamma, - epsilon=self.epsilon, - contracting_dims=(axis, contract_ind), - major_sharding_type=self.major_sharding_type, - dp_dim_index=1 if self.transpose_batch_sequence else 0, - activations=self.activations) - else: # not use_fused_ln_mlp + fp8_meta_package = None + if FP8Helper.is_fp8_enabled(): + fp8_meta_package = \ + TransformerEngineBase.get_fp8_meta_package(num_of_gemm) - def fp8_meta_generator(): - fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv = (None, None, None, - None) - if FP8Helper.is_fp8_enabled(): - fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv = \ - TransformerEngineBase.get_fp8_metas(num_of_gemm) - return fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv + num_activations = len(self.activations) + axis = _canonicalize_tuple(self.axis) + axis = _normalize_axes(axis, y.ndim) - fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv = \ - fp8_meta_generator() + intermediate_dim = _canonicalize_tuple((num_activations, self.intermediate_dim)) + kernel_1_shape = tuple(y.shape[ax] for ax in axis) + intermediate_dim + kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim) + kernel_1 = nn_partitioning.param_with_axes('wi_kernel', + kernel_1_init, + num_activations, + -2, + kernel_1_each_shape, + jnp.float32, + axes=self.kernel_axes_1) + kernel_1 = jnp.reshape(kernel_1, kernel_1_shape) + hidden_size = inputs.shape[-1] + hidden_size_tuple = _canonicalize_tuple(hidden_size) + kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple + kernel_2_param_shape = (self.intermediate_dim, np.prod(hidden_size_tuple)) + kernel_2 = nn_partitioning.param_with_axes('wo_kernel', + self.kernel_init, + kernel_2_param_shape, + jnp.float32, + axes=self.kernel_axes_2) + kernel_2 = jnp.reshape(kernel_2, kernel_2_shape) + contract_ind = tuple(range(0, len(axis))) - # DenseGeneral 1 - activations = [] - num_activations = len(self.activations) - axis = _canonicalize_tuple(self.axis) - axis = _normalize_axes(axis, y.ndim) - - intermediate_dim = _canonicalize_tuple((num_activations, self.intermediate_dim)) - kernel_shape = tuple(y.shape[ax] for ax in axis) + intermediate_dim - kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim) - kernel = nn_partitioning.param_with_axes('wi_kernel', - kernel_1_init, - num_activations, - -2, - kernel_1_each_shape, - jnp.float32, - axes=self.kernel_axes_1) - kernel = jnp.reshape(kernel, kernel_shape) - contract_ind = tuple(range(0, len(axis))) - - if FP8Helper.is_fp8_enabled(): - fp8_gemm_package = FP8GemmPackage( - 1, y, [kernel], fp8_max[:FP8Helper.NUM_META_PER_GEMM, :], - fp8_metas_amax[:FP8Helper.NUM_META_PER_GEMM, :], - fp8_metas_scale[:FP8Helper.NUM_META_PER_GEMM, :], - fp8_metas_scale_inv[:FP8Helper.NUM_META_PER_GEMM, :]) - - if not fuse_layernorm: - x = fp8_dot(fp8_gemm_package, - FP8Helper.FWD_DTYPE, - FP8Helper.BWD_DTYPE, (axis, contract_ind), - sharding_type=first_sharding_type, - dp_dim_index=1 if self.transpose_batch_sequence else 0) - else: - x = layernorm_fp8_dot(fp8_gemm_package, + if use_fused_ln_mlp: + assert self.axis == -1 # Only support axis = =-1 at this moment + + out = layernrom_geglu_fp8_mlp(y, scale, - ln_bias, + ln_bias, [kernel_1, kernel_2], + fp8_meta_package, self.layernorm_type, - FP8Helper.FWD_DTYPE, - FP8Helper.BWD_DTYPE, (axis, contract_ind), zero_centered_gamma=self.zero_centered_gamma, - epsilon=self.epsilon, - sharding_type=first_sharding_type, - dp_dim_index=1 if self.transpose_batch_sequence else 0) - else: # not enable fp8 - kernel = jnp.asarray(kernel, self.dtype) - x = lax.dot_general(y, kernel, ((axis, contract_ind), ((), ()))) + epsilon=self.epsilon) + else: # not use_fused_ln_mlp + + # DenseGeneral 1 + gemm1_fp8_meta_package = None if fp8_meta_package is None \ + else fp8_meta_package.get_package_by_gemm_idx(0) + if fuse_layernorm: + x = layernorm_fp8_dot(y, + kernel_1, + scale, + ln_bias, + gemm1_fp8_meta_package, + self.layernorm_type, + zero_centered_gamma=self.zero_centered_gamma, + epsilon=self.epsilon) + else: + x = type_safe_dot_general(y, + kernel_1, + fp8_meta_pkg=gemm1_fp8_meta_package, + contracting_dims=(axis, contract_ind)) bias = None if self.use_bias: @@ -977,11 +923,9 @@ def fp8_meta_generator(): bias_shape = (1,) * (x.ndim - bias.ndim) + bias.shape x += jnp.reshape(bias, bias_shape) - if self.activations == ('gelu', 'linear'): - z = geglu(x, - contracting_dims=(-2, -1), - sharding_type=second_sharding_type, - dp_dim_index=1 if self.transpose_batch_sequence else 0) + activations = [] + if is_geglu(self.activations): + z = geglu(x) else: x = jnp.split(x, num_activations, axis=-2) for idx, act_fn in enumerate(self.activations): @@ -996,37 +940,13 @@ def fp8_meta_generator(): z, deterministic=deterministic) # DenseGeneral 2 - hidden_size = inputs.shape[-1] - hidden_size_tuple = _canonicalize_tuple(hidden_size) - axis = _canonicalize_tuple(self.axis) - axis = _normalize_axes(axis, z.ndim) - - kernel_shape = tuple(z.shape[ax] for ax in axis) + hidden_size_tuple - kernel_param_shape = (np.prod([z.shape[ax] for ax in axis]), np.prod(hidden_size_tuple)) - kernel = nn_partitioning.param_with_axes('wo_kernel', - self.kernel_init, - kernel_param_shape, - jnp.float32, - axes=self.kernel_axes_2) - kernel = jnp.reshape(kernel, kernel_shape) - - contract_ind = tuple(range(0, len(axis))) - - if FP8Helper.is_fp8_enabled(): - fp8_gemm_package = FP8GemmPackage( - 1, z, [kernel], fp8_max[FP8Helper.NUM_META_PER_GEMM:, :], - fp8_metas_amax[FP8Helper.NUM_META_PER_GEMM:, :], - fp8_metas_scale[FP8Helper.NUM_META_PER_GEMM:, :], - fp8_metas_scale_inv[FP8Helper.NUM_META_PER_GEMM:, :]) - - out = fp8_dot(fp8_gemm_package, - FP8Helper.FWD_DTYPE, - FP8Helper.BWD_DTYPE, (axis, contract_ind), - sharding_type=second_sharding_type, - dp_dim_index=1 if self.transpose_batch_sequence else 0) - else: - kernel = jnp.asarray(kernel, self.dtype) - out = lax.dot_general(z, kernel, ((axis, contract_ind), ((), ()))) + gemm2_fp8_meta_package = None if fp8_meta_package is None \ + else fp8_meta_package.get_package_by_gemm_idx(1) + + out = type_safe_dot_general(z, + kernel_2, + fp8_meta_pkg=gemm2_fp8_meta_package, + contracting_dims=(axis, contract_ind)) bias = None if self.use_bias: diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index a21b9901ea..989b060696 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -27,9 +27,8 @@ from ..fused_attn import is_fused_attn_kernel_available from ..fused_attn import self_fused_attn, cross_fused_attn from ..softmax import SoftmaxType -from ..sharding import infer_major_sharding_type, infer_sharding_type -from ..sharding import global_shard_resource, with_sharding_constraint -from ..sharding import ShardingType +from ..sharding import global_mesh_resource, num_of_devices +from ..sharding import with_sharding_constraint PRNGKey = Any Shape = Tuple[int, ...] @@ -102,7 +101,7 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules: else: rules_map[key] = [val] - gsr = global_shard_resource() + gsr = global_mesh_resource() batch_dim_rule = [] if gsr.dp_resource is not None: @@ -186,7 +185,6 @@ def core_attention(query: Array, scale_factor: float, transpose_batch_sequence: bool, softmax_type: SoftmaxType = SoftmaxType.SCALED, - softmax_sharding_type: ShardingType = ShardingType.SINGLE, mask: Optional[Array] = None, bias: Optional[Array] = None, dropout_rng: Optional[PRNGKey] = None, @@ -226,9 +224,7 @@ def core_attention(query: Array, fused_scale_factor = scale_factor attn_weights = Softmax(softmax_type=softmax_type, - scale_factor=fused_scale_factor, - sharding_type=softmax_sharding_type)(attn_weights, mask, - bias).astype(dtype) + scale_factor=fused_scale_factor)(attn_weights, mask, bias).astype(dtype) if not deterministic and dropout_rate > 0.: keep_prob = 1.0 - dropout_rate @@ -482,8 +478,6 @@ def _check_head_dim(head_dim): f"Fused attention is not enabled. Because " \ f"{reason}fall back to unfused attention.") - first_sharding_type, second_sharding_type = infer_sharding_type() - residual = inputs_q if self.fuse_qkv: if is_self_attn: @@ -494,7 +488,6 @@ def _check_head_dim(head_dim): epsilon=self.layernorm_epsilon, axis=-1, features=(3, self.num_heads * self.head_dim), - sharding_type=first_sharding_type, transpose_batch_sequence=self.transpose_batch_sequence, return_layernorm_output=self.apply_residual_connection_post_layernorm, scale_axes=(W_NO_SHARD_AXES,), @@ -516,7 +509,6 @@ def _check_head_dim(head_dim): epsilon=self.layernorm_epsilon, axis=-1, features=self.num_heads * self.head_dim, - sharding_type=first_sharding_type, transpose_batch_sequence=self.transpose_batch_sequence, return_layernorm_output=self.apply_residual_connection_post_layernorm, scale_axes=(W_NO_SHARD_AXES,), @@ -530,7 +522,6 @@ def _check_head_dim(head_dim): name='query')(inputs_q) kv_proj = DenseGeneral(axis=-1, features=(2, self.num_heads * self.head_dim), - sharding_type=first_sharding_type, transpose_batch_sequence=self.transpose_batch_sequence, kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), kernel_init=kv_init, @@ -546,7 +537,6 @@ def _check_head_dim(head_dim): DenseGeneral, axis=-1, features=self.num_heads * self.head_dim, - sharding_type=first_sharding_type, transpose_batch_sequence=self.transpose_batch_sequence, kernel_axes=(W_FSDP_AXES, W_TP_AXES), use_bias=self.use_bias, @@ -560,7 +550,6 @@ def _check_head_dim(head_dim): epsilon=self.layernorm_epsilon, axis=-1, features=self.num_heads * self.head_dim, - sharding_type=first_sharding_type, transpose_batch_sequence=self.transpose_batch_sequence, return_layernorm_output=True, scale_axes=(W_NO_SHARD_AXES,), @@ -648,7 +637,7 @@ def _check_head_dim(head_dim): seed = None if dropout_rng is not None: - seed = jax.random.split(dropout_rng, len(jax.devices())) + seed = jax.random.split(dropout_rng, num_of_devices()) # ensure the old key never used del dropout_rng @@ -665,8 +654,7 @@ def _check_head_dim(head_dim): attn_mask_type=attn_mask_type, scaling_factor=scale_factor, dropout_probability=self.dropout_rate, - is_training=not deterministic, - sharding_type=first_sharding_type) + is_training=not deterministic) else: assert bias is None query = query.reshape((*query.shape[:-1], self.num_heads, self.head_dim)) @@ -685,8 +673,7 @@ def _check_head_dim(head_dim): attn_mask_type=attn_mask_type, scaling_factor=scale_factor, dropout_probability=self.dropout_rate, - is_training=not deterministic, - sharding_type=first_sharding_type) + is_training=not deterministic) else: def convert_to_softmax_type(attn_mask_type, mask): @@ -710,7 +697,6 @@ def convert_to_softmax_type(attn_mask_type, mask): scale_factor=scale_factor, transpose_batch_sequence=self.transpose_batch_sequence, softmax_type=softmax_type, - softmax_sharding_type=first_sharding_type, mask=mask, bias=bias, dropout_rng=dropout_rng, @@ -728,7 +714,6 @@ def convert_to_softmax_type(attn_mask_type, mask): x = _with_sharding_constraint(x, attn_context_sharding_constraint) out = DenseGeneral(features=inputs_q.shape[-1], - sharding_type=second_sharding_type, transpose_batch_sequence=self.transpose_batch_sequence, axis=-1, kernel_init=self.kernel_init, @@ -1175,7 +1160,6 @@ def hidden_dropout(x, deterministic): layernorm_type=self.layernorm_type, zero_centered_gamma=self.zero_centered_gamma, epsilon=self.layernorm_epsilon, - major_sharding_type=infer_major_sharding_type(), transpose_batch_sequence=self.transpose_batch_sequence, return_layernorm_output=self.apply_residual_connection_post_layernorm, intermediate_dim=self.mlp_hidden_size, @@ -1208,7 +1192,6 @@ def hidden_dropout(x, deterministic): z = z + residual if self.output_layernorm: - ln_sharding_type, _ = infer_sharding_type() z = LayerNorm(layernorm_type=self.layernorm_type, zero_centered_gamma=self.zero_centered_gamma, epsilon=self.layernorm_epsilon, @@ -1216,7 +1199,6 @@ def hidden_dropout(x, deterministic): bias_axes=(W_NO_SHARD_AXES,), transpose_batch_sequence=self.transpose_batch_sequence, dtype=self.dtype, - sharding_type=ln_sharding_type, name="output_layer_norm")(z) return z diff --git a/transformer_engine/jax/fp8.py b/transformer_engine/jax/fp8.py index c64bcbd6d0..183519c995 100644 --- a/transformer_engine/jax/fp8.py +++ b/transformer_engine/jax/fp8.py @@ -6,7 +6,7 @@ """ from contextlib import contextmanager from enum import Enum -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union import jax import jax.numpy as jnp @@ -17,7 +17,7 @@ from transformer_engine_jax import get_cuda_version, get_device_compute_capability from transformer_engine.common.recipe import DelayedScaling, Format from transformer_engine.jax.sharding import global_shard_guard -from transformer_engine.jax.sharding import ShardingResource +from transformer_engine.jax.sharding import MeshResource _is_fp8_available = None _reason_for_no_fp8 = "" @@ -59,37 +59,29 @@ def is_fp8_available(gpu_id=None) -> Tuple[bool, str]: def _format2dtypes(format_: Format): if format_ == Format.E4M3: - return DType.kFloat8E4M3, DType.kFloat8E4M3 + return jnp.float8_e4m3fn, jnp.float8_e4m3fn if format_ == Format.E5M2: - return DType.kFloat8E5M2, DType.kFloat8E5M2 + return jnp.float8_e5m2, jnp.float8_e5m2 if format_ == Format.HYBRID: - return DType.kFloat8E4M3, DType.kFloat8E5M2 - return DType.kBFloat16, DType.kBFloat16 + return jnp.float8_e4m3fn, jnp.float8_e5m2 + return jnp.bfloat16, jnp.bfloat16 -class FP8GemmPackage: +class FP8MetaPackage: """ - A container that contains all required data for - FP8 GEMM + A container that contains all required meta data for FP8 """ def __init__( self, num_of_gemm: int, - inputs: jnp.ndarray, - kernels: List[jnp.ndarray], fp8_max: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, ) -> None: + total_num_of_meta = num_of_gemm * FP8Helper.NUM_META_PER_GEMM self._num_of_gemm = num_of_gemm - self._inputs = inputs - - assert len(kernels) == self._num_of_gemm - self._kernels = kernels - - total_num_of_meta = self._num_of_gemm * FP8Helper.NUM_META_PER_GEMM assert fp8_max.shape[0] == total_num_of_meta self._fp8_max = fp8_max assert amax.shape[0] == total_num_of_meta @@ -106,20 +98,6 @@ def num_of_gemm(self) -> int: """ return self._num_of_gemm - @property - def inputs(self) -> jnp.ndarray: - """ - inputs of this package - """ - return self._inputs - - @property - def kernels(self) -> List[jnp.ndarray]: - """ - kernels of this package - """ - return self._kernels - @property def fp8_max(self) -> jnp.ndarray: """ @@ -148,6 +126,19 @@ def scale_inv(self) -> jnp.ndarray: """ return self._scale_inv + def get_package_by_gemm_idx(self, gemm_idx): + """ + Get a sub package by gemm_idx + """ + assert self.num_of_gemm > gemm_idx + + meta_start_idx = gemm_idx * FP8Helper.NUM_META_PER_GEMM + meta_end_idx = (gemm_idx + 1) * FP8Helper.NUM_META_PER_GEMM + return FP8MetaPackage(1, self.fp8_max[meta_start_idx:meta_end_idx], + self.amax[meta_start_idx:meta_end_idx], + self.scale[meta_start_idx:meta_end_idx], + self.scale_inv[meta_start_idx:meta_end_idx]) + class AmaxComputeAlgo(Enum): """AmaxComputeAlgo.""" @@ -155,6 +146,9 @@ class AmaxComputeAlgo(Enum): MOST_RECENT = "most_recent" +NVTE_FP8_COLLECTION_NAME = "fp8_meta_collection" + + class FP8Helper: """ FP8 helper to manage the FP8 meta @@ -162,8 +156,8 @@ class FP8Helper: INITIALIZED = False MARGIN: float = 0.0 FP8_FORMAT: Format = Format.HYBRID - FWD_DTYPE: DType = DType.kFloat8E4M3 - BWD_DTYPE: DType = DType.kFloat8E5M2 + FWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[0] + BWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[1] UPDATE_FP8META_INTERVAL: int = 1 AMAX_HISTORY_LEN: int = 1024 AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX @@ -171,7 +165,7 @@ class FP8Helper: INPUT_META_IDX_PER_GEMM: int = 0 KERNEL_META_IDX_PER_GEMM: int = 1 GRAD_META_IDX_PER_GEMM: int = 2 - FP8_COLLECTION_NAME: str = "fp8_meta_collection" + FP8_COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME FP8_AMAX_NAME: str = "fp8_meta_amax" FP8_SCALE_NAME: str = "fp8_meta_scale" FP8_SCALE_INV_NAME: str = "fp8_meta_scale_inv" @@ -216,21 +210,12 @@ def finalize() -> None: FP8Helper.INITIALIZED = False FP8Helper.MARGIN = 0.0 FP8Helper.FP8_FORMAT = Format.HYBRID - FP8Helper.FWD_DTYPE = DType.kFloat8E4M3 - FP8Helper.BWD_DTYPE = DType.kFloat8E5M2 + FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = \ + _format2dtypes(FP8Helper.FP8_FORMAT) FP8Helper.UPDATE_FP8META_INTERVAL = 1 FP8Helper.AMAX_HISTORY_LEN = 1024 FP8Helper.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX - @staticmethod - def update_amax_history(amax_buffers: jnp.ndarray) -> jnp.ndarray: - """ - Update the amax history - """ - updated_amax_buffers = jnp.roll(amax_buffers, -1, 1) - updated_amax_buffers = updated_amax_buffers.at[:, 0].set(0) - return updated_amax_buffers - @staticmethod def update_collections(new: Collection, original: Collection) -> Collection: """ @@ -270,8 +255,8 @@ def generate_fp8_max_array(num_of_meta): Generate the FP8 max array """ num_of_gemm = num_of_meta // FP8Helper.NUM_META_PER_GEMM - fp8_max_fwd = FP8Helper.FP8_FORMAT.value.max_fwd - fp8_max_bwd = FP8Helper.FP8_FORMAT.value.max_bwd + fp8_max_fwd = jnp.finfo(FP8Helper.FWD_DTYPE).max + fp8_max_bwd = jnp.finfo(FP8Helper.BWD_DTYPE).max fp8_max_per_gemm = [] for i in range(FP8Helper.NUM_META_PER_GEMM): val = fp8_max_bwd if i == FP8Helper.GRAD_META_IDX_PER_GEMM \ @@ -318,11 +303,40 @@ def _update_fp8_metas_impl(fp8_metas: Collection) -> Collection: return jax.tree_util.tree_unflatten(treedef, fp8_meta_arrays) + @staticmethod + def update_amax_history(amax: jnp.ndarray) -> jnp.ndarray: + """ + Update the amax history + """ + updated_amax = jnp.roll(amax, -1, -1) + updated_amax = updated_amax.at[..., 0].set(0) + return updated_amax + + @staticmethod + @jax.jit + def update_fp8_scale(fp8_max: jnp.ndarray, amax: jnp.ndarray, + scale: jnp.ndarray) -> jnp.ndarray: + """ + Calculate fp8 scale and scale_inv based on given amax. + """ + if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX: + amax = jnp.max(amax, axis=-1, keepdims=True) + else: + amax = amax[..., 0:1] + + sf = (fp8_max / amax) / (2**FP8Helper.MARGIN) + sf = jnp.where(amax > 0.0, sf, scale) + sf = jnp.where(jnp.isfinite(amax), sf, scale) + scale = sf + scale_inv = 1 / sf + + return scale, scale_inv + @contextmanager def fp8_autocast(enabled: bool = False, fp8_recipe: Optional[DelayedScaling] = None, - sharding_resource: Optional[ShardingResource] = None) -> None: + mesh_resource: Optional[MeshResource] = None) -> None: r""" Context manager for FP8 usage. @@ -334,9 +348,9 @@ def fp8_autocast(enabled: bool = False, devices = np.asarray(jax.devices()).reshape(*mesh_shape) with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)): - sharding_resource=ShardingResource(dp_mesh_axis_name, tp_mesh_axis_name) + mesh_resource=MeshResource(dp_mesh_axis_name, tp_mesh_axis_name) - with fp8_autocast(enabled=True, sharding_resource=sharding_resource): + with fp8_autocast(enabled=True, mesh_resource=mesh_resource): rules = extend_logical_axis_rules(tuple()) transformer = TransformerLayer() @@ -356,7 +370,7 @@ def fp8_autocast(enabled: bool = False, Whether or not to enable fp8 fp8_recipe: recipe.DelayedScaling, default = None Recipe used for FP8 training. - sharding_resource: ShardingResource, default = None + mesh_resource: MeshResource, default = None Specify the mesh axes for data and tensor parallelism to shard along. If set to None, then no data or tensor parallelism will be used. @@ -373,11 +387,11 @@ def fp8_autocast(enabled: bool = False, "DelayedScaling override_linear_precision isn't supported by TE/JAX.") assert fp8_recipe.reduce_amax, ("DelayedScaling reduce_amax should be enabled for TE/JAX.") - if sharding_resource is None: - sharding_resource = ShardingResource() + if mesh_resource is None: + mesh_resource = MeshResource() try: - with global_shard_guard(sharding_resource): + with global_shard_guard(mesh_resource): if enabled: fp8_available, reason_for_no_fp8 = is_fp8_available() assert fp8_available, reason_for_no_fp8 diff --git a/transformer_engine/jax/fused_attn.py b/transformer_engine/jax/fused_attn.py index a8a6421a89..8c0a31556d 100644 --- a/transformer_engine/jax/fused_attn.py +++ b/transformer_engine/jax/fused_attn.py @@ -15,12 +15,6 @@ from .cpp_extensions import FusedAttnHelper from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd -from .sharding import get_fused_attn_sharding_meta -from .sharding import ShardingType -from .sharding import xmap_runner, extend_fsdp_sharding_meta - -jax.config.update('experimental_xmap_spmd_lowering', True) -jax.config.update('experimental_xmap_spmd_lowering_manual', True) class AttnBiasType(Enum): @@ -54,62 +48,24 @@ def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type, head_dim).is_fused_attn_kernel_available() -def self_fused_attn(qkv: jnp.ndarray, - bias: jnp.ndarray, - mask: jnp.ndarray, - seed: jnp.ndarray, - attn_bias_type: AttnBiasType, - attn_mask_type: AttnMaskType, - scaling_factor: float, - dropout_probability: float, - is_training: bool, - sharding_type: ShardingType = ShardingType.SINGLE): +def self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray, + attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType, + scaling_factor: float, dropout_probability: float, is_training: bool): """ Self fused attention wrapper """ - assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \ - "self_fused_attn does not support row-split tensor parallelism currently." - - if sharding_type is ShardingType.SINGLE: - output = _self_fused_attn(qkv, - bias, - mask, - seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - else: - dp_axis_name = "batch" - tp_axis_name = "model" - - inputs = [qkv, bias, mask, seed] - batch, seqlen, _, num_head, head_dim = qkv.shape - output_shape = [batch, seqlen, num_head, head_dim] - sharding_meta = get_fused_attn_sharding_meta( - sharding_type, [x.shape if x is not None else None for x in inputs], [output_shape], - dp_dims=([0, None, 0, 0], [0]), - tp_dims=([3, 1, None, 0], [2]), - dp_axis_name=dp_axis_name, - tp_axis_name=tp_axis_name) - sharding_meta, _ = extend_fsdp_sharding_meta(sharding_meta, {0: 0, 2: 0}) - - inputs_ = tuple( - jnp.reshape(x, new_shape) if x is not None else None - for x, new_shape in zip(inputs, sharding_meta.input_shapes)) - - partial_self_fused_attn = partial(_self_fused_attn, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - - output_ = xmap_runner(partial_self_fused_attn, sharding_meta.in_axes, - sharding_meta.out_axes, sharding_meta.axis_resources, inputs_) - - output = jnp.reshape(output_, sharding_meta.output_shapes) + assert attn_mask_type is not AttnMaskType.NO_MASK, \ + "Currently not support AttnMaskType.NO_MASK." + + output = _self_fused_attn(qkv, + bias, + mask, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) return output @@ -118,119 +74,70 @@ def self_fused_attn(qkv: jnp.ndarray, def _self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType, scaling_factor: float, dropout_probability: float, is_training: bool): - output, _ = _self_fused_attn_fwd(qkv, - bias, - mask, - seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - return output + output, _ = _self_fused_attn_fwd_rule(qkv, bias, mask, seed, attn_bias_type, attn_mask_type, + scaling_factor, dropout_probability, is_training) + return output -def _self_fused_attn_fwd(qkv, bias, mask, seed, attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training): - - seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32) - cu_seqlen = jnp.cumsum(seqlen) - cu_seqlen = jnp.hstack((0, cu_seqlen)) +def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, + seed: jnp.ndarray, attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, scaling_factor: float, + dropout_probability: float, is_training: bool): + squeezed_mask = mask[:, :, :, 0] output, softmax_aux, rng_state = self_fused_attn_fwd(qkv, bias, - cu_seqlen, + squeezed_mask, seed, attn_bias_type=attn_bias_type.value, attn_mask_type=attn_mask_type.value, scaling_factor=scaling_factor, dropout_probability=dropout_probability, is_training=is_training) - return output, (qkv, softmax_aux, rng_state, output, cu_seqlen) + return output, (qkv, softmax_aux, rng_state, output, squeezed_mask) -def _self_fused_attn_bwd(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, - is_training, ctx, grad): - qkv, softmax_aux, rng_state, output, cu_seqlen = ctx - - doutput = grad +def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, + is_training, ctx, dz): + qkv, softmax_aux, rng_state, output, squeezed_mask = ctx grad_qkv, grad_bias = self_fused_attn_bwd(qkv, softmax_aux, rng_state, output, - doutput, - cu_seqlen, + dz, + squeezed_mask, attn_bias_type=attn_bias_type.value, attn_mask_type=attn_mask_type.value, scaling_factor=scaling_factor, dropout_probability=dropout_probability, is_training=is_training) - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + if attn_bias_type == AttnBiasType.NO_BIAS: grad_bias = None return grad_qkv, grad_bias, None, None -_self_fused_attn.defvjp(_self_fused_attn_fwd, _self_fused_attn_bwd) +_self_fused_attn.defvjp(_self_fused_attn_fwd_rule, _self_fused_attn_bwd_rule) -def cross_fused_attn(q: jnp.ndarray, - kv: jnp.ndarray, - mask: jnp.ndarray, - seed: jnp.ndarray, - attn_bias_type: AttnBiasType, - attn_mask_type: AttnMaskType, - scaling_factor: float, - dropout_probability: float, - is_training: bool, - sharding_type: ShardingType = ShardingType.SINGLE): +def cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray, + attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType, + scaling_factor: float, dropout_probability: float, is_training: bool): """ Cross multi-head attention wrapper """ - assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \ - "cross_fused_attn does not support row-split tensor parallelism currently." - - if sharding_type is ShardingType.SINGLE: - output = _cross_fused_attn(q, - kv, - mask, - seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - else: - dp_axis_name = "batch" - tp_axis_name = "model" - - inputs = [q, kv, mask, seed] - output_shape = q.shape - sharding_meta = get_fused_attn_sharding_meta( - sharding_type, [x.shape if x is not None else None for x in inputs], [output_shape], - dp_dims=([0, 0, 0, None], [0]), - tp_dims=([2, 3, None, None], [2]), - dp_axis_name=dp_axis_name, - tp_axis_name=tp_axis_name) - sharding_meta, _ = extend_fsdp_sharding_meta(sharding_meta, {0: 0, 2: 0}) - - inputs_ = tuple( - jnp.reshape(x, new_shape) if x is not None else None - for x, new_shape in zip(inputs, sharding_meta.input_shapes)) - - partial_cross_fused_attn = partial(_cross_fused_attn, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - output_ = xmap_runner(partial_cross_fused_attn, sharding_meta.in_axes, - sharding_meta.out_axes, sharding_meta.axis_resources, inputs_) - - output = jnp.reshape(output_, sharding_meta.output_shapes) + output = _cross_fused_attn(q, + kv, + mask, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) return output @@ -240,54 +147,40 @@ def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed: attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType, scaling_factor: float, dropout_probability: float, is_training: bool): - output, _ = _cross_fused_attn_fwd(q, - kv, - mask, - seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) + output, _ = _cross_fused_attn_fwd_rule(q, kv, mask, seed, attn_bias_type, attn_mask_type, + scaling_factor, dropout_probability, is_training) return output -def _cross_fused_attn_fwd(q, kv, mask, seed, attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training): - - q_seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32) - q_cu_seqlen = jnp.cumsum(q_seqlen) - q_cu_seqlen = jnp.hstack((0, q_cu_seqlen)) +def _cross_fused_attn_fwd_rule(q, kv, mask, seed, attn_bias_type, attn_mask_type, scaling_factor, + dropout_probability, is_training): - kv_seqlen = jnp.sum(mask[:, :, 0, :] == 0, axis=(-1, -2), dtype=jnp.int32) - kv_cu_seqlen = jnp.cumsum(kv_seqlen) - kv_cu_seqlen = jnp.hstack((0, kv_cu_seqlen)) + q_squeezed_mask = mask[:, :, :, 0] + kv_squeezed_mask = mask[:, :, 0, :] output, softmax_aux = cross_fused_attn_fwd(q, kv, - q_cu_seqlen, - kv_cu_seqlen, + q_squeezed_mask, + kv_squeezed_mask, seed, attn_bias_type=attn_bias_type.value, attn_mask_type=attn_mask_type.value, scaling_factor=scaling_factor, dropout_probability=dropout_probability, is_training=is_training) - return output, (softmax_aux, q, kv, q_cu_seqlen, kv_cu_seqlen) - + return output, (softmax_aux, q, kv, q_squeezed_mask, kv_squeezed_mask) -def _cross_fused_attn_bwd(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, - is_training, ctx, grad): - softmax_aux, q, kv, q_cu_seqlen, kv_cu_seqlen = ctx - doutput = grad +def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, + is_training, ctx, dz): + softmax_aux, q, kv, q_squeezed_mask, kv_squeezed_mask = ctx grad_q, grad_kv = cross_fused_attn_bwd(q, kv, softmax_aux, - doutput, - q_cu_seqlen, - kv_cu_seqlen, + dz, + q_squeezed_mask, + kv_squeezed_mask, attn_bias_type=attn_bias_type.value, attn_mask_type=attn_mask_type.value, scaling_factor=scaling_factor, @@ -297,4 +190,4 @@ def _cross_fused_attn_bwd(attn_bias_type, attn_mask_type, scaling_factor, dropou return grad_q, grad_kv, None, None -_cross_fused_attn.defvjp(_cross_fused_attn_fwd, _cross_fused_attn_bwd) +_cross_fused_attn.defvjp(_cross_fused_attn_fwd_rule, _cross_fused_attn_bwd_rule) diff --git a/transformer_engine/jax/layernorm.py b/transformer_engine/jax/layernorm.py index ebac282c1a..0d71bcd7bc 100644 --- a/transformer_engine/jax/layernorm.py +++ b/transformer_engine/jax/layernorm.py @@ -3,25 +3,15 @@ # See LICENSE for license information. """JAX layernorm modules""" -from typing import Tuple, Sequence -from functools import partial, reduce -import operator +from functools import partial import jax import jax.numpy as jnp -from transformer_engine_jax import DType as TEDType -from .cpp_extensions import cast_transpose, gemm, jax_dtype_to_te_dtype -from .cpp_extensions import transpose +from .cpp_extensions import cast_transpose, transpose from .cpp_extensions import rmsnorm_fwd, rmsnorm_fwd_fp8, rmsnorm_bwd from .cpp_extensions import layernorm_fwd, layernorm_fwd_fp8, layernorm_bwd -from .fp8 import FP8Helper, FP8GemmPackage -from .sharding import ShardingType, get_elementwise_sharding_meta -from .sharding import get_dot_sharding_meta, get_fp8_meta_sharding_meta -from .sharding import is_dp_enabled, is_tp_enabled, merge_axis_resources -from .sharding import xmap_runner, extend_fsdp_sharding_meta - -jax.config.update('experimental_xmap_spmd_lowering', True) -jax.config.update('experimental_xmap_spmd_lowering_manual', True) +from .dot import fp8_dot_impl +from .fp8 import FP8Helper, FP8MetaPackage def canonicalize_layernorm_type(x): @@ -38,421 +28,241 @@ def layernorm(inputs: jnp.ndarray, beta: jnp.ndarray, layernorm_type: str, zero_centered_gamma: bool = False, - epsilon: float = 1e-6, - sharding_type: ShardingType = ShardingType.SINGLE, - dp_dim_index: int = 0): + epsilon: float = 1e-6): """ - Layernorm wrapper + LN/RMSNorm wrapper + Only support layernorm_type in ['layernorm', 'rmsnorm'] """ - assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \ - "layernorm does not support row-split tensor parallelism currently." - - layernorm_type = canonicalize_layernorm_type(layernorm_type) - if layernorm_type == 'rmsnorm': - assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'" - assert not zero_centered_gamma, "zero_centered_gamma is not supported " \ - "if layernorm_type is 'rmsnorm'" - - if sharding_type is ShardingType.SINGLE: - output = _layernorm(inputs, - gamma, - beta, - layernorm_type=layernorm_type, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon, - sharding_type=sharding_type, - dp_axis_name="", - fsdp_axis_name="") - else: - dp_axis_name = "batch" - tp_axis_name = "model" - sharding_meta = get_elementwise_sharding_meta(sharding_type, inputs.shape, gamma.shape, - dp_dim_index, dp_axis_name, tp_axis_name) - - sharding_meta, fsdp_axis_name = extend_fsdp_sharding_meta(sharding_meta, {0: dp_dim_index}) - inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0]) # 0 for input - gamma_ = jnp.reshape(gamma, sharding_meta.input_shapes[1]) # 1 for gamma - beta_ = beta - beta_in_axis = {} - if beta_ is not None: - beta_ = jnp.reshape(beta_, sharding_meta.input_shapes[1]) # 1 for beta - beta_in_axis = sharding_meta.in_axes[1] - - in_axes = (*sharding_meta.in_axes, beta_in_axis) - - partial_ln = partial(_layernorm, - layernorm_type=layernorm_type, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon, - sharding_type=sharding_type, - dp_axis_name=dp_axis_name, - fsdp_axis_name=fsdp_axis_name) - - output = xmap_runner(partial_ln, in_axes, sharding_meta.out_axes, - sharding_meta.axis_resources, (inputs_, gamma_, beta_)) - - output = jnp.reshape(output, sharding_meta.output_shapes[0]) - + output = _layernorm(inputs, + gamma, + beta, + layernorm_type=layernorm_type, + zero_centered_gamma=zero_centered_gamma, + epsilon=epsilon) return output -@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8)) -def _layernorm(x, gamma, beta, layernorm_type, zero_centered_gamma, epsilon, sharding_type, - dp_axis_name, fsdp_axis_name): - output, _ = _layernorm_fwd(x, gamma, beta, layernorm_type, zero_centered_gamma, epsilon, - sharding_type, dp_axis_name, fsdp_axis_name) +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5)) +def _layernorm(x, + gamma, + beta, + layernorm_type: str, + zero_centered_gamma: bool = False, + epsilon: float = 1e-6): + output, _ = _layernorm_fwd_rule(x, gamma, beta, layernorm_type, zero_centered_gamma, epsilon) return output -def _layernorm_fwd( - x, - gamma, - beta, - layernorm_type, - zero_centered_gamma, - epsilon, - sharding_type, # pylint: disable=unused-argument - dp_axis_name, # pylint: disable=unused-argument - fsdp_axis_name # pylint: disable=unused-argument -): +def _layernorm_fwd_rule(x, + gamma, + beta, + layernorm_type: str, + zero_centered_gamma: bool = False, + epsilon: float = 1e-6): + layernorm_type = canonicalize_layernorm_type(layernorm_type) if layernorm_type == 'layernorm': output, mu, rsigma = layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon) - else: + elif layernorm_type == 'rmsnorm': assert not zero_centered_gamma, "zero_centered_gamma is not supported " \ "if layernorm_type is 'rmsnorm'" output, rsigma = rmsnorm_fwd(x, gamma, epsilon) mu = None - return output, (mu, rsigma, x, gamma) - + else: + raise ValueError(f"{layernorm_type=} is not supported.") + return output, (x, mu, rsigma, gamma) -def _layernorm_bwd(layernorm_type, zero_centered_gamma, epsilon, sharding_type, dp_axis_name, - fsdp_axis_name, ctx, g): - mu, rsigma, x, gamma = ctx +def _layernorm_bwd_rule(layernorm_type, zero_centered_gamma, epsilon, ctx, dz): + x, mu, rsigma, gamma = ctx if layernorm_type == 'layernorm': - grad_input, grad_gamma, grad_beta = layernorm_bwd(g, - mu, - rsigma, - x, - gamma, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon) - else: + dx, dgamma, dbeta = layernorm_bwd(dz, + x, + mu, + rsigma, + gamma, + zero_centered_gamma=zero_centered_gamma, + epsilon=epsilon) + elif layernorm_type == 'rmsnorm': assert not zero_centered_gamma, "zero_centered_gamma is not supported " \ "if layernorm_type is 'rmsnorm'" - grad_input, grad_gamma = rmsnorm_bwd(g, rsigma, x, gamma, epsilon=epsilon) - grad_beta = None - - if is_dp_enabled(sharding_type.value[0]): - grad_gamma = jax.lax.psum(grad_gamma, dp_axis_name) - if grad_beta is not None: - grad_beta = jax.lax.psum(grad_beta, dp_axis_name) - if len(fsdp_axis_name) > 0: - grad_gamma = jax.lax.psum(grad_gamma, fsdp_axis_name) - if grad_beta is not None: - grad_beta = jax.lax.psum(grad_beta, fsdp_axis_name) + dx, dgamma = rmsnorm_bwd(dz, x, rsigma, gamma, epsilon=epsilon) + dbeta = None + else: + raise ValueError(f"{layernorm_type=} is not supported.") - return grad_input, grad_gamma, grad_beta + return dx, dgamma, dbeta -_layernorm.defvjp(_layernorm_fwd, _layernorm_bwd) +_layernorm.defvjp(_layernorm_fwd_rule, _layernorm_bwd_rule) -def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage, +def layernorm_fp8_dot(x: jnp.ndarray, + kernel: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, + fp8_meta_pkg: FP8MetaPackage, layernorm_type: str, - fwd_dtype: TEDType, - bwd_dtype: TEDType, - contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), zero_centered_gamma: bool = False, - epsilon: float = 1e-6, - sharding_type: ShardingType = ShardingType.SINGLE, - dp_dim_index: int = 0) -> jnp.ndarray: + epsilon: float = 1e-6) -> jnp.ndarray: """ - LN + fp8 dot fusion wrapper + Layernorm + FP8 GEMM """ - assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \ - "layernorm_fp8_dot does not support row-split tensor parallelism currently." - - layernorm_type = canonicalize_layernorm_type(layernorm_type) - if layernorm_type == 'rmsnorm': - assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'" - assert not zero_centered_gamma, "zero_centered_gamma is not supported " \ - "if layernorm_type is 'rmsnorm'" - - assert fp8_gemm_pkg.num_of_gemm == 1 - inputs = fp8_gemm_pkg.inputs - kernel = fp8_gemm_pkg.kernels[0] - fp8_max = fp8_gemm_pkg.fp8_max - amax = fp8_gemm_pkg.amax - scale = fp8_gemm_pkg.scale - scale_inv = fp8_gemm_pkg.scale_inv - - if sharding_type is ShardingType.SINGLE: - output = _layernorm_fp8_dot(inputs, - kernel, - gamma, - beta, - fp8_max, - amax, - scale, - scale_inv, - layernorm_type, - fwd_dtype, - bwd_dtype, - contracting_dims, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon, - sharding_type=sharding_type, - dp_axis_name="", - tp_axis_name="", - fsdp_axis_name="") - else: - dp_axis_name = "batch" - tp_axis_name = "model" - - ln_sharding_meta = get_elementwise_sharding_meta(sharding_type, inputs.shape, gamma.shape, - dp_dim_index, dp_axis_name, tp_axis_name) - ln_sharding_meta, _ = extend_fsdp_sharding_meta(ln_sharding_meta, {0: dp_dim_index}) - inputs_ = jnp.reshape(inputs, ln_sharding_meta.input_shapes[0]) # 0 for input - gamma_ = jnp.reshape(gamma, ln_sharding_meta.input_shapes[1]) # 1 for gamma - beta_ = beta - beta_in_axis = {} - if beta_ is not None: - beta_ = jnp.reshape(beta_, ln_sharding_meta.input_shapes[1]) # 1 for beta - beta_in_axis = ln_sharding_meta.in_axes[1] - - kernel_tp_index = None - # TODO (Ming Huang): Should we add a new argument to support general sharding to kernel? # pylint: disable=fixme - if sharding_type in (ShardingType.TP_COL, ShardingType.DP_TP_COL): - kernel_tp_index = len(kernel.shape) - 1 - elif sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW): - kernel_tp_index = 0 - - input_tp_index = len(inputs.shape) - 1 - dot_sharding_meta = get_dot_sharding_meta(sharding_type, inputs.shape, kernel.shape, - dp_dim_index, input_tp_index, kernel_tp_index, - contracting_dims, dp_axis_name, tp_axis_name) - dot_sharding_meta, fsdp_axis_name = extend_fsdp_sharding_meta(dot_sharding_meta, - {0: dp_dim_index}) - kernel_ = jnp.reshape(kernel, dot_sharding_meta.input_shapes[1]) # 1 for kernel - - num_of_fp8_meta_kind = 4 # fp8_max, amax, scale, scale_inv - fp8_sharding_meta = get_fp8_meta_sharding_meta(sharding_type, num_of_fp8_meta_kind, - dp_axis_name, tp_axis_name) - - axis_resource = merge_axis_resources([ - ln_sharding_meta.axis_resources, dot_sharding_meta.axis_resources, - fp8_sharding_meta.axis_resources - ]) - - partial_ln_fp8_dot = partial(_layernorm_fp8_dot, - layernorm_type=layernorm_type, - fwd_dtype=fwd_dtype, - bwd_dtype=bwd_dtype, - contracting_dims=contracting_dims, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon, - sharding_type=sharding_type, - dp_axis_name=dp_axis_name, - tp_axis_name=tp_axis_name, - fsdp_axis_name=fsdp_axis_name) - - # input, kernel, gamma, beta, fp8_metas - in_axes = (ln_sharding_meta.in_axes[0], dot_sharding_meta.in_axes[1], - ln_sharding_meta.in_axes[1], beta_in_axis, *fp8_sharding_meta.in_axes) - - output = xmap_runner(partial_ln_fp8_dot, in_axes, dot_sharding_meta.out_axes, axis_resource, - (inputs_, kernel_, gamma_, beta_, fp8_max, amax, scale, scale_inv)) - - output = jnp.reshape(output, dot_sharding_meta.output_shapes[0]) + fp8_max = fp8_meta_pkg.fp8_max + amax = fp8_meta_pkg.amax + scale = fp8_meta_pkg.scale + scale_inv = fp8_meta_pkg.scale_inv + fwd_dtype = FP8Helper.FWD_DTYPE + bwd_dtype = FP8Helper.BWD_DTYPE + output = _layernorm_fp8_dot(x, kernel, gamma, beta, fp8_max, amax, scale, scale_inv, + layernorm_type, fwd_dtype, bwd_dtype, zero_centered_gamma, epsilon) return output -@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15, 16, 17)) -def _layernorm_fp8_dot(inputs: jnp.ndarray, kernel: jnp.ndarray, gamma: jnp.ndarray, - beta: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray, - scale: jnp.ndarray, scale_inv: jnp.ndarray, layernorm_type: str, - fwd_dtype: TEDType, bwd_dtype: TEDType, - contracting_dims: Tuple[Sequence[int], Sequence[int]], - zero_centered_gamma: bool, epsilon: float, sharding_type: ShardingType, - dp_axis_name: str, tp_axis_name: str, fsdp_axis_name: str) -> jnp.ndarray: - output, _ = _layernorm_fp8_dot_fwd(inputs, kernel, gamma, beta, fp8_maxs, amax, scale, - scale_inv, layernorm_type, fwd_dtype, bwd_dtype, - contracting_dims, zero_centered_gamma, epsilon, - sharding_type, dp_axis_name, tp_axis_name, fsdp_axis_name) +@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12)) +def _layernorm_fp8_dot(x: jnp.ndarray, kernel: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, + fp8_max: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, + scale_inv: jnp.ndarray, layernorm_type: str, fwd_dtype: jnp.dtype, + bwd_dtype: jnp.dtype, zero_centered_gamma: bool, epsilon: float): + output, _ = _layernorm_fp8_dot_fwd_rule(x, kernel, gamma, beta, fp8_max, amax, scale, scale_inv, + layernorm_type, fwd_dtype, bwd_dtype, + zero_centered_gamma, epsilon) return output -def _layernorm_fp8_dot_fwd( - inputs, +def _layernorm_fp8_dot_fwd_rule( + x, kernel, gamma, beta, - fp8_maxs, + fp8_max, amax, scale, scale_inv, layernorm_type, fwd_dtype, bwd_dtype, # pylint: disable=unused-argument - contracting_dims, zero_centered_gamma, - epsilon, - sharding_type, - dp_axis_name, # pylint: disable=unused-argument - tp_axis_name, - fsdp_axis_name): # pylint: disable=unused-argument - - lhs_contracting_dims, rhs_contracting_dims = contracting_dims - input_shape_pre = inputs.shape[:min(lhs_contracting_dims)] - input_shape_suf = inputs.shape[min(lhs_contracting_dims):] - kernel_shape_pre = kernel.shape[:max(rhs_contracting_dims) + 1] - kernel_shape_suf = kernel.shape[max(rhs_contracting_dims) + 1:] - input_contracting_size = reduce(operator.mul, input_shape_suf) - kernel_contracting_size = reduce(operator.mul, kernel_shape_pre) - assert input_contracting_size == kernel_contracting_size + epsilon): + + x_contracting_dims = (len(x.shape) - 1,) + k_contracting_dims = (0,) + assert x.shape[-1] == kernel.shape[0] amax = FP8Helper.update_amax_history(amax) - gemm_input_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0) + gemm_x_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0) + + x_amax = amax[gemm_x_idx, 0:1] + x_scale = scale[gemm_x_idx] + x_scale_inv = scale_inv[gemm_x_idx] - input_amax = amax[gemm_input_idx, 0:1] - input_scale = scale[gemm_input_idx] - input_scale_inv = scale_inv[gemm_input_idx] if layernorm_type == 'layernorm': - ln_out, mu, rsigma, input_amax = layernorm_fwd_fp8(inputs, - gamma, - beta, - input_amax, - input_scale, - input_scale_inv, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon) + ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8( + x, + gamma, + beta, + x_amax, + x_scale, + x_scale_inv, + out_dtype=fwd_dtype, + zero_centered_gamma=zero_centered_gamma, + epsilon=epsilon) else: assert not zero_centered_gamma, "zero_centered_gamma is not supported " \ "if layernorm_type is 'rmsnorm'" - ln_out, rsigma, input_amax = rmsnorm_fwd_fp8(inputs, - gamma, - input_amax, - input_scale, - input_scale_inv, - epsilon=epsilon) + ln_out, rsigma, updated_x_amax = rmsnorm_fwd_fp8(x, + gamma, + x_amax, + x_scale, + x_scale_inv, + out_dtype=fwd_dtype, + epsilon=epsilon) mu = None - assert inputs.shape == ln_out.shape - ln_out_ = jnp.reshape(ln_out, (-1, input_contracting_size)) - kernel_ = jnp.reshape(kernel, (kernel_contracting_size, -1)) + assert x.shape == ln_out.shape kernel_amax = amax[gemm_kernel_idx, 0:1] kernel_scale = scale[gemm_kernel_idx] kernel_scale_inv = scale_inv[gemm_kernel_idx] - kernel_cast, kernel_cast_trans, kernel_amax = cast_transpose(kernel_, kernel_amax, kernel_scale, - kernel_scale_inv, fwd_dtype) - output = gemm(kernel_cast_trans, kernel_scale_inv, fwd_dtype, True, ln_out_, input_scale_inv, - fwd_dtype, False, jax_dtype_to_te_dtype(inputs.dtype), FP8Helper.FP8_2X_ACC_FPROP) + # Kernel in (hidden_in, hidden_out...) + casted_kerenl, casted_kerenl_t, updated_kernel_amax = \ + cast_transpose(kernel, kernel_amax, kernel_scale, kernel_scale_inv, fwd_dtype, + static_axis_boundary=-1, transpose_axis_boundary=1) - if sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW): - output = jax.lax.psum(output, tp_axis_name) + # (batch..., hidden_in) x (hidden_in, hidden_out...) + kt_contracting_dims = (kernel.ndim - 1,) + output = fp8_dot_impl(ln_out, casted_kerenl_t, x_scale_inv, kernel_scale_inv, x.dtype, + (x_contracting_dims, kt_contracting_dims)) - # (input_shape_pre, input_shape_suf) - # x (kernel_shape_pre, kernel_shape_suf) - # = (input_shape_pre, kernel_shape_suf) - output_shape = input_shape_pre + kernel_shape_suf - output = jnp.reshape(output, output_shape) + ctx = (ln_out, casted_kerenl, fp8_max, amax, scale, scale_inv, updated_x_amax, + updated_kernel_amax, x.shape, kernel.shape, mu, rsigma, x, gamma, x_contracting_dims, + k_contracting_dims) - ctx = (ln_out_, kernel_cast, fp8_maxs, amax, scale, scale_inv, input_amax, kernel_amax, - inputs.shape, kernel.shape, mu, rsigma, inputs, gamma) return output, ctx -def _layernorm_fp8_dot_bwd( +def _layernorm_fp8_dot_bwd_rule( layernorm_type, - fwd_dtype, + fwd_dtype, # pylint: disable=unused-argument bwd_dtype, - contracting_dims, # pylint: disable=unused-argument zero_centered_gamma, epsilon, - sharding_type, - dp_axis_name, - tp_axis_name, - fsdp_axis_name, ctx, - g): - ln_out_, kernel_cast, \ - fp8_maxs, amax, scale, scale_inv, \ - input_amax, kernel_amax, \ - inputs_shape, kernel_shape, \ - mu, rsigma, inputs, gamma = ctx + grad): + ln_out_, casted_kerenl, fp8_max, amax, scale, scale_inv, \ + updated_x_amax, updated_kernel_amax, \ + x_shape, kernel_shape, mu, rsigma, x, gamma, \ + x_contracting_dims, k_contracting_dims = ctx - gemm_input_idx, gemm_kernel_idx, gemm_grad_idx = \ - FP8Helper.get_fp8_meta_indices(0) + ln_out_t = transpose(ln_out_, static_axis_boundary=-1, transpose_axis_boundary=-1) + + gemm_x_idx, gemm_kernel_idx, gemm_grad_idx = FP8Helper.get_fp8_meta_indices(0) grad_amax = amax[gemm_grad_idx, 0:1] grad_scale = scale[gemm_grad_idx] grad_scale_inv = scale_inv[gemm_grad_idx] - ln_out_trans = transpose(ln_out_, fwd_dtype) - g = jnp.reshape(g, (ln_out_trans.shape[1], -1)) - - # cast and transpose the grad_output - grad_cast, grad_cast_trans, grad_amax = cast_transpose(g, grad_amax, grad_scale, grad_scale_inv, - bwd_dtype) + casted_grad, casted_grad_t, updated_grad_amax = \ + cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype, + static_axis_boundary=-1, transpose_axis_boundary=min(x_contracting_dims)) - input_scale_inv = scale_inv[gemm_input_idx] - wgrad = gemm(grad_cast_trans, grad_scale_inv, bwd_dtype, True, ln_out_trans, input_scale_inv, - fwd_dtype, False, jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_WGRAD) + xt_constracting_dim = tuple(range(len(x_contracting_dims), len(x_shape))) + gt_constracting_dim = tuple(range(grad.ndim - len(xt_constracting_dim), grad.ndim)) + x_scale_inv = scale_inv[gemm_x_idx] + wgrad = fp8_dot_impl(ln_out_t, casted_grad_t, x_scale_inv, grad_scale_inv, grad.dtype, + (xt_constracting_dim, gt_constracting_dim)) + g_constracting_dim = tuple( + range(grad.ndim - len(kernel_shape) + len(k_contracting_dims), grad.ndim)) + k_constracting_dim = tuple(range(len(k_contracting_dims), len(kernel_shape))) kernel_scale_inv = scale_inv[gemm_kernel_idx] - dgrad = gemm(kernel_cast, kernel_scale_inv, fwd_dtype, True, grad_cast, grad_scale_inv, - bwd_dtype, False, jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_DGRAD) - - dgrad = jnp.reshape(dgrad, inputs_shape) - - if sharding_type in (ShardingType.TP_COL, ShardingType.DP_TP_COL): - dgrad = jax.lax.psum(dgrad, tp_axis_name) + dgrad = fp8_dot_impl(casted_grad, casted_kerenl, grad_scale_inv, kernel_scale_inv, grad.dtype, + (g_constracting_dim, k_constracting_dim)) if layernorm_type == 'layernorm': - grad_input, grad_gamma, grad_beta = layernorm_bwd(dgrad, - mu, - rsigma, - inputs, - gamma, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon) + dx, dgamma, dbeta = layernorm_bwd(dgrad, + x, + mu, + rsigma, + gamma, + zero_centered_gamma=zero_centered_gamma, + epsilon=epsilon) else: assert not zero_centered_gamma, "zero_centered_gamma is not supported " \ "if layernorm_type is 'rmsnorm'" - grad_input, grad_gamma = rmsnorm_bwd(dgrad, rsigma, inputs, gamma, epsilon=epsilon) - grad_beta = None - - amax = amax.at[gemm_input_idx, 0].set(input_amax[0]) - amax = amax.at[gemm_kernel_idx, 0].set(kernel_amax[0]) - amax = amax.at[gemm_grad_idx, 0].set(grad_amax[0]) - - if is_dp_enabled(sharding_type.value[0]): - wgrad = jax.lax.psum(wgrad, dp_axis_name) - grad_gamma = jax.lax.psum(grad_gamma, dp_axis_name) - if grad_beta is not None: - grad_beta = jax.lax.psum(grad_beta, dp_axis_name) - amax = jax.lax.pmax(amax, dp_axis_name) + dx, dgamma = rmsnorm_bwd(dgrad, x, rsigma, gamma, epsilon=epsilon) + dbeta = None - if len(fsdp_axis_name) > 0: - wgrad = jax.lax.psum(wgrad, fsdp_axis_name) - grad_gamma = jax.lax.psum(grad_gamma, fsdp_axis_name) - if grad_beta is not None: - grad_beta = jax.lax.psum(grad_beta, fsdp_axis_name) - amax = jax.lax.pmax(amax, fsdp_axis_name) + amax = amax.at[gemm_x_idx, 0].set(updated_x_amax[0]) + amax = amax.at[gemm_kernel_idx, 0].set(updated_kernel_amax[0]) + amax = amax.at[gemm_grad_idx, 0].set(updated_grad_amax[0]) - if is_tp_enabled(sharding_type.value[0]): - amax = jax.lax.pmax(amax, tp_axis_name) + scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale) - wgrad = jnp.reshape(wgrad, kernel_shape) - return grad_input, wgrad, \ - grad_gamma, grad_beta, \ - fp8_maxs, amax, scale, scale_inv + return dx, wgrad, \ + dgamma, dbeta, \ + fp8_max, amax, scale, scale_inv -_layernorm_fp8_dot.defvjp(_layernorm_fp8_dot_fwd, _layernorm_fp8_dot_bwd) +_layernorm_fp8_dot.defvjp(_layernorm_fp8_dot_fwd_rule, _layernorm_fp8_dot_bwd_rule) diff --git a/transformer_engine/jax/mlp.py b/transformer_engine/jax/mlp.py index f28855ee39..deeaab7901 100644 --- a/transformer_engine/jax/mlp.py +++ b/transformer_engine/jax/mlp.py @@ -3,462 +3,307 @@ # See LICENSE for license information. """JAX MLP modules""" -from typing import Tuple, Sequence, Union, Callable -from functools import partial, reduce -import operator +from typing import List +from functools import partial import jax import jax.numpy as jnp -from jax.interpreters import pxla -from transformer_engine_jax import DType as TEDType -from .cpp_extensions import jax_dtype_to_te_dtype from .cpp_extensions import transpose, cast_transpose from .cpp_extensions import gated_gelu, gated_gelu_fp8 from .cpp_extensions import dgated_gelu, dgated_gelu_cast_transpose from .cpp_extensions import rmsnorm_fwd_fp8, rmsnorm_bwd from .cpp_extensions import layernorm_fwd_fp8, layernorm_bwd -from .cpp_extensions import gemm -from .sharding import MajorShardingType, ShardingType -from .sharding import get_elementwise_sharding_meta -from .sharding import get_dot_sharding_meta, get_fp8_meta_sharding_meta -from .sharding import merge_axis_resources, infer_sharding_type -from .sharding import xmap_runner, extend_fsdp_sharding_meta +from .dot import fp8_dot_impl from .layernorm import canonicalize_layernorm_type -from .fp8 import FP8Helper, FP8GemmPackage +from .fp8 import FP8Helper, FP8MetaPackage -jax.config.update('experimental_xmap_spmd_lowering', True) -jax.config.update('experimental_xmap_spmd_lowering_manual', True) -thread_resources = pxla.thread_resources - - -def geglu( - inputs: jnp.ndarray, - contracting_dims: Sequence[int] = (-1,), - sharding_type: ShardingType = ShardingType.SINGLE, - dp_dim_index: int = 0, # pylint: disable=unused-argument -): +def geglu(x: jnp.ndarray): """ Gated gelu """ - input_shape_suf_size = reduce(operator.mul, inputs.shape[min(contracting_dims):]) - assert input_shape_suf_size % 2 == 0 - output_shape = (*inputs.shape[:min(contracting_dims)], input_shape_suf_size // 2) - - if sharding_type is ShardingType.SINGLE: - output = _geglu(inputs, contracting_dims) - else: - dp_axis_name = "batch" - tp_axis_name = "model" - - sharding_meta = get_elementwise_sharding_meta(sharding_type, inputs.shape, None, - dp_dim_index, dp_axis_name, tp_axis_name) - sharding_meta, _ = extend_fsdp_sharding_meta(sharding_meta, {0: dp_dim_index}) - - inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0]) # 0 for input - - partial_geglu = partial(_geglu, contracting_dims=contracting_dims) + assert x.shape[-2] == 2 # Linear + GeLU - output = xmap_runner(partial_geglu, sharding_meta.in_axes, sharding_meta.out_axes, - sharding_meta.axis_resources, (inputs_,)) + output = _geglu(x) - output = jnp.reshape(output, output_shape) return output -@partial(jax.custom_vjp, nondiff_argnums=(1,)) -def _geglu(inputs: jnp.ndarray, contracting_dims: Sequence[int] = (-1,)): +@partial(jax.custom_vjp) +def _geglu(x: jnp.ndarray): - geglu_output, _ = _geglu_fwd(inputs, contracting_dims) + geglu_output, _ = _geglu_fwd_rule(x) return geglu_output -def _geglu_fwd(inputs, contracting_dims): - inputs_real_shape = (*inputs.shape[:min(contracting_dims)], - reduce(operator.mul, inputs.shape[min(contracting_dims):])) - inputs_ = jnp.reshape(inputs, inputs_real_shape) - geglu_output = gated_gelu(inputs_) - geglu_output = jnp.expand_dims(geglu_output, min(contracting_dims)) - return geglu_output, (inputs_, inputs.shape) +def _geglu_fwd_rule(x): + geglu_output = gated_gelu(x) + return geglu_output, (x,) -def _geglu_bwd(contracting_dims, ctx, g): - inputs_, inputs_shape = ctx - g = jnp.squeeze(g, min(contracting_dims)) - assert inputs_.dtype == g.dtype +def _geglu_bwd_rule(ctx, g): + x, = ctx + assert x.dtype == g.dtype - dgelu = dgated_gelu(g, inputs_) - dgelu = jnp.reshape(dgelu, inputs_shape) + dgelu = dgated_gelu(g, x) + dgelu = jnp.reshape(dgelu, x.shape) return (dgelu,) -_geglu.defvjp(_geglu_fwd, _geglu_bwd) +_geglu.defvjp(_geglu_fwd_rule, _geglu_bwd_rule) -def fp8_ln_mlp( - fp8_gemm_pkg: FP8GemmPackage, - ln_scale: jnp.ndarray, - ln_bias: jnp.ndarray, - layernorm_type: str, - fwd_dtype: TEDType, - bwd_dtype: TEDType, - zero_centered_gamma: bool = False, - epsilon: float = 1e-6, - contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), - major_sharding_type: MajorShardingType = MajorShardingType.SINGLE, - dp_dim_index: int = 0, # pylint: disable=unused-argument - activations: Sequence[Union[str, Callable]] = ('gelu', 'linear') -) -> jnp.ndarray: +def layernrom_geglu_fp8_mlp(x: jnp.ndarray, + gamma: jnp.ndarray, + beta: jnp.ndarray, + kernels: List[jnp.ndarray], + fp8_gemm_pkg: FP8MetaPackage, + layernorm_type: str, + zero_centered_gamma: bool = False, + epsilon: float = 1e-6) -> jnp.ndarray: """ - FP8 layernorm MLP wrapper - (LN + Dense + act + Dense) + Layernorm + GEMM1 + GeGLU + GEMM2 """ - assert fp8_gemm_pkg.num_of_gemm == 2 - inputs = fp8_gemm_pkg.inputs - kernel_1 = fp8_gemm_pkg.kernels[0] - kernel_2 = fp8_gemm_pkg.kernels[1] + + assert len(kernels) == 2 + assert fp8_gemm_pkg.num_of_gemm == len(kernels) + + kernel_1 = kernels[0] + kernel_2 = kernels[1] fp8_max = fp8_gemm_pkg.fp8_max amax = fp8_gemm_pkg.amax scale = fp8_gemm_pkg.scale scale_inv = fp8_gemm_pkg.scale_inv + fwd_dtype = FP8Helper.FWD_DTYPE + bwd_dtype = FP8Helper.BWD_DTYPE + layernorm_type = canonicalize_layernorm_type(layernorm_type) if layernorm_type == 'rmsnorm': - assert ln_bias is None, "ln_bias should be None if layernorm_type is 'rmsnorm'" + assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'" assert not zero_centered_gamma, "zero_centered_gamma is not supported " \ "if layernorm_type is 'rmsnorm'" - assert activations == ('gelu', 'linear') - if major_sharding_type is MajorShardingType.SINGLE: - res = _fp8_mlp(inputs, ln_scale, ln_bias, kernel_1, kernel_2, fp8_max, amax, scale, - scale_inv, layernorm_type, activations, zero_centered_gamma, epsilon, - fwd_dtype, bwd_dtype, contracting_dims, major_sharding_type, "", "", "") - else: - dp_axis_name = "batch" - tp_axis_name = "model" - - first_part_st, second_part_st = infer_sharding_type(major_sharding_type) - - ln_sharding_meta = get_elementwise_sharding_meta(first_part_st, inputs.shape, - ln_scale.shape, dp_dim_index, dp_axis_name, - tp_axis_name) - ln_sharding_meta, _ = extend_fsdp_sharding_meta(ln_sharding_meta, {0: dp_dim_index}) - - input_tp_index = len(inputs.shape) - 1 - first_dot_sharding_meta = get_dot_sharding_meta(first_part_st, inputs.shape, kernel_1.shape, - dp_dim_index, input_tp_index, 2, - contracting_dims, dp_axis_name, - tp_axis_name) - first_dot_sharding_meta, fsdp_axis_name = extend_fsdp_sharding_meta( - first_dot_sharding_meta, {0: dp_dim_index}) - second_input_shape = (*first_dot_sharding_meta.output_shapes[0][:-2], - first_dot_sharding_meta.output_shapes[0][-1]) - second_dot_sharding_meta = get_dot_sharding_meta(second_part_st, second_input_shape, - kernel_2.shape, dp_dim_index, - len(second_input_shape) - 1, 0, - contracting_dims, dp_axis_name, - tp_axis_name) - second_dot_sharding_meta, _ = extend_fsdp_sharding_meta(second_dot_sharding_meta, - {0: dp_dim_index}) - - num_of_fp8_meta_kind = 4 # fp8_max, amax, scale, scale_inv - fp8_sharding_meta = get_fp8_meta_sharding_meta(first_part_st, num_of_fp8_meta_kind, - dp_axis_name, tp_axis_name) - - inputs_ = jnp.reshape(inputs, ln_sharding_meta.input_shapes[0]) # 0 for input - ln_scale_ = jnp.reshape(ln_scale, ln_sharding_meta.input_shapes[1]) # 1 for gamma - ln_bias_ = ln_bias - ln_bias_in_axis = {} - if ln_bias_ is not None: - ln_bias_ = jnp.reshape(ln_bias_, ln_sharding_meta.input_shapes[1]) # 1 for beta - ln_bias_in_axis = ln_sharding_meta.in_axes[1] - kernel_1_ = jnp.reshape(kernel_1, first_dot_sharding_meta.input_shapes[1]) # 1 for kernel - kernel_2_ = jnp.reshape(kernel_2, - second_dot_sharding_meta.input_shapes[1]) # 1 for kernel - - axis_resource = merge_axis_resources([ - ln_sharding_meta.axis_resources, first_dot_sharding_meta.axis_resources, - second_dot_sharding_meta.axis_resources, fp8_sharding_meta.axis_resources - ]) - - partial_fp8_mlp = partial(_fp8_mlp, - layernorm_type=layernorm_type, - activations=activations, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon, - fwd_dtype=fwd_dtype, - bwd_dtype=bwd_dtype, - contracting_dims=contracting_dims, - major_sharding_type=major_sharding_type, - dp_axis_name=dp_axis_name, - tp_axis_name=tp_axis_name, - fsdp_axis_name=fsdp_axis_name) - in_axes = (ln_sharding_meta.in_axes[0], ln_sharding_meta.in_axes[1], ln_bias_in_axis, - first_dot_sharding_meta.in_axes[1], second_dot_sharding_meta.in_axes[1], - *fp8_sharding_meta.in_axes) - - res = xmap_runner( - partial_fp8_mlp, in_axes, second_dot_sharding_meta.out_axes, axis_resource, - (inputs_, ln_scale_, ln_bias_, kernel_1_, kernel_2_, fp8_max, amax, scale, scale_inv)) - res = jnp.reshape(res, second_dot_sharding_meta.output_shapes[0]) - - return res - - -@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19)) -def _fp8_mlp(inputs: jnp.ndarray, ln_scale: jnp.ndarray, ln_bias: jnp.ndarray, - kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray, - scale: jnp.ndarray, scale_inv: jnp.ndarray, layernorm_type: str, - activations: Sequence[Union[str, Callable]], zero_centered_gamma: bool, epsilon: float, - fwd_dtype: TEDType, bwd_dtype: TEDType, contracting_dims: Tuple[Sequence[int], - Sequence[int]], - major_sharding_type: MajorShardingType, dp_axis_name: str, tp_axis_name: str, - fsdp_axis_name: str): - res, _ = _fp8_mlp_fwd(inputs, - ln_scale, - ln_bias, - kernel_1, - kernel_2, - fp8_maxs, - amax, - scale, - scale_inv, - layernorm_type, - activations, - zero_centered_gamma, - epsilon, - fwd_dtype, - bwd_dtype, - contracting_dims=contracting_dims, - major_sharding_type=major_sharding_type, - dp_axis_name=dp_axis_name, - tp_axis_name=tp_axis_name, - fsdp_axis_name=fsdp_axis_name) - return res - - -def _fp8_mlp_fwd( - inputs, + output = _layernrom_geglu_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax, scale, + scale_inv, fwd_dtype, bwd_dtype, layernorm_type, + zero_centered_gamma, epsilon) + return output + + +@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13)) +def _layernrom_geglu_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, + kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_max: jnp.ndarray, + amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, + fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype, layernorm_type: str, + zero_centered_gamma: bool, epsilon: float): + output, _ = _layernrom_geglu_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax, + scale, scale_inv, fwd_dtype, bwd_dtype, + layernorm_type, zero_centered_gamma, epsilon) + return output + + +def _layernrom_geglu_fp8_mlp_fwd_rule( + x, gamma, beta, kernel_1, kernel_2, - fp8_maxs, + fp8_max, amax, scale, scale_inv, - layernorm_type, - activations, - zero_centered_gamma, - epsilon, fwd_dtype, bwd_dtype, # pylint: disable=unused-argument - contracting_dims, - major_sharding_type, - dp_axis_name, # pylint: disable=unused-argument - tp_axis_name, - fsdp_axis_name): # pylint: disable=unused-argument - if activations != ('gelu', 'linear'): - raise NotImplementedError("activations only support ('gelu', 'linear') for now.") - lhs_contracting_dims, rhs_contracting_dims = contracting_dims - input_shape_pre = inputs.shape[:min(lhs_contracting_dims)] - input_shape_suf = inputs.shape[min(lhs_contracting_dims):] - kernel_1_shape_pre = kernel_1.shape[:max(rhs_contracting_dims) + 1] - kernel_1_shape_suf = kernel_1.shape[max(rhs_contracting_dims) + 1:] - kernel_2_shape_pre = kernel_2.shape[:max(rhs_contracting_dims) + 1] - kernel_2_shape_suf = kernel_2.shape[max(rhs_contracting_dims) + 1:] - input_contracting_size = reduce(operator.mul, input_shape_suf) - kernel_1_pre_size = reduce(operator.mul, kernel_1_shape_pre) - kernel_1_suf_size = reduce(operator.mul, kernel_1_shape_suf) - kernel_2_pre_size = reduce(operator.mul, kernel_2_shape_pre) - assert input_contracting_size == kernel_1_pre_size - assert kernel_1_suf_size == kernel_2_pre_size * len(activations) - inputs_ = jnp.reshape(inputs, (-1, input_contracting_size)) - kernel_1_ = jnp.reshape(kernel_1, (kernel_1_pre_size, -1)) - kernel_2_ = jnp.reshape(kernel_2, (kernel_2_pre_size, -1)) + layernorm_type, + zero_centered_gamma, + epsilon): + + # x should be in shape of (batch..., hidden) + # Kernel_1 should be in shape of (Hidden_in, 2, Hidden_out) + # Kernel_2 should be in shape of (Hidden_in, Hidden_out) + assert len(kernel_1.shape) == 3 + assert kernel_1.shape[-2] == 2 + assert len(kernel_2.shape) == 2 + + x_contracting_dims = (len(x.shape) - 1,) + xt_batch_dims = tuple(range(1, x.ndim)) + + assert x.shape[x_contracting_dims[0]] == kernel_1.shape[0] + assert kernel_1.shape[-1] == kernel_2.shape[0] amax = FP8Helper.update_amax_history(amax) - gemm1_input_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0) + gemm1_x_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0) + + x_amax = amax[gemm1_x_idx, 0:1] + x_scale = scale[gemm1_x_idx] + x_scale_inv = scale_inv[gemm1_x_idx] - input_amax = amax[gemm1_input_idx, 0:1] - input_scale = scale[gemm1_input_idx] - input_scale_inv = scale_inv[gemm1_input_idx] if layernorm_type == 'layernorm': - ln_out, mu, rsigma, ln_out_amax = layernorm_fwd_fp8(inputs_, - gamma, - beta, - input_amax, - input_scale, - input_scale_inv, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon) + ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8( + x, + gamma, + beta, + x_amax, + x_scale, + x_scale_inv, + out_dtype=fwd_dtype, + zero_centered_gamma=zero_centered_gamma, + epsilon=epsilon) else: assert not zero_centered_gamma, "zero_centered_gamma is not supported " \ "if layernorm_type is 'rmsnorm'" - ln_out, rsigma, ln_out_amax = rmsnorm_fwd_fp8(inputs_, - gamma, - input_amax, - input_scale, - input_scale_inv, - epsilon=epsilon) + ln_out, rsigma, updated_x_amax = rmsnorm_fwd_fp8(x, + gamma, + x_amax, + x_scale, + x_scale_inv, + out_dtype=fwd_dtype, + epsilon=epsilon) mu = None + assert x.shape == ln_out.shape + kernel_1_amax = amax[gemm1_kernel_idx, 0:1] kernel_1_scale = scale[gemm1_kernel_idx] kernel_1_scale_inv = scale_inv[gemm1_kernel_idx] - kernel_1_cast, kernel_1_cast_trans, kernel_1_amax = cast_transpose( - kernel_1_, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype) - dense_1_output = gemm(kernel_1_cast_trans, kernel_1_scale_inv, fwd_dtype, True, ln_out, - scale_inv[gemm1_input_idx], fwd_dtype, False, - jax_dtype_to_te_dtype(inputs.dtype), FP8Helper.FP8_2X_ACC_FPROP) - gemm2_input_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1) + casted_kerenl_1, casted_kerenl_1_t, updated_kernel_1_amax = \ + cast_transpose(kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype, + static_axis_boundary=-1, transpose_axis_boundary=-2) + + # (batch..., hidden_in) x (2, hidden_out, hidden_in) + dot_1_output = fp8_dot_impl(ln_out, casted_kerenl_1_t, x_scale_inv, kernel_1_scale_inv, x.dtype, + (x_contracting_dims, (2,))) + + gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1) + + geglu_out_amax = amax[gemm2_x_idx, 0:1] + geglu_out_scale = scale[gemm2_x_idx] + geglu_out_scale_inv = scale_inv[gemm2_x_idx] + + # (batch..., hidden_in) -> (batch..., hidden) + casted_geglu_out, updated_geglu_amax = gated_gelu_fp8(dot_1_output, geglu_out_amax, + geglu_out_scale, geglu_out_scale_inv, + fwd_dtype) kernel_2_amax = amax[gemm2_kernel_idx, 0:1] kernel_2_scale = scale[gemm2_kernel_idx] kernel_2_scale_inv = scale_inv[gemm2_kernel_idx] - kernel_2_cast, kernel_2_cast_trans, kernel_2_amax = cast_transpose( - kernel_2_, kernel_2_amax, kernel_2_scale, kernel_2_scale_inv, fwd_dtype) - - dense_1_out_amax = amax[gemm2_input_idx, 0:1] - dense_1_out_scale = scale[gemm2_input_idx] - dense_1_out_scale_inv = scale_inv[gemm2_input_idx] - gated_gelu_output_cast, gated_gelu_amax = gated_gelu_fp8(dense_1_output, dense_1_out_amax, - dense_1_out_scale, - dense_1_out_scale_inv, fwd_dtype) - res = gemm(kernel_2_cast_trans, kernel_2_scale_inv, fwd_dtype, True, - gated_gelu_output_cast, dense_1_out_scale_inv, fwd_dtype, False, - jax_dtype_to_te_dtype(inputs.dtype), FP8Helper.FP8_2X_ACC_FPROP) - - if major_sharding_type in (MajorShardingType.TP, MajorShardingType.DPTP): - res = jax.lax.psum(res, tp_axis_name) - - # (input_shape_pre, input_shape_suf) - # x (kernel_1_shape_pre, kernel_1_shape_suf) - # x (kernel_2_shape_pre, kernel_2_shape_suf) - # = (input_shape_pre, kernel_2_shape_suf) - output_shape = input_shape_pre + kernel_2_shape_suf - res = jnp.reshape(res, output_shape) - - ctx = (inputs_, ln_out, mu, rsigma, gamma, dense_1_output, gated_gelu_output_cast, - kernel_1_cast, kernel_2_cast, fp8_maxs, amax, scale, scale_inv, ln_out_amax, - gated_gelu_amax, kernel_1_amax, kernel_2_amax, inputs.shape, kernel_1.shape, - kernel_2.shape) - - return res, ctx - - -def _fp8_mlp_bwd( + + casted_kerenl_2, casted_kerenl_2_t, updated_kernel_2_amax = \ + cast_transpose(kernel_2, kernel_2_amax, kernel_2_scale, kernel_2_scale_inv, fwd_dtype, + static_axis_boundary=-1, transpose_axis_boundary=-1) + + # (batch..., hidden_in) x (hidden_out, hidden_in) + dot_2_output = fp8_dot_impl(casted_geglu_out, casted_kerenl_2_t, geglu_out_scale_inv, + kernel_2_scale_inv, x.dtype, (x_contracting_dims, (1,))) + + ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, casted_kerenl_1, + casted_kerenl_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_geglu_amax, + updated_kernel_1_amax, updated_kernel_2_amax, x_contracting_dims, xt_batch_dims) + + return dot_2_output, ctx + + +def _layernrom_geglu_fp8_mlp_bwd_rule( + fwd_dtype, # pylint: disable=unused-argument + bwd_dtype, layernorm_type, - activations, # pylint: disable=unused-argument zero_centered_gamma, epsilon, - fwd_dtype, - bwd_dtype, - contracting_dims, # pylint: disable=unused-argument - major_sharding_type, - dp_axis_name, - tp_axis_name, - fsdp_axis_name, ctx, - g): - inputs_, ln_out, mu, rsigma, gamma, \ - dense_1_output, gated_gelu_output_cast, \ - kernel_1_cast, kernel_2_cast, \ - fp8_maxs, amax, scale, scale_inv, \ - ln_out_amax, gated_gelu_amax, kernel_1_amax, kernel_2_amax, \ - input_shape, kernel_1_shape, kernel_2_shape = ctx - - g = jnp.reshape(g, (ln_out.shape[0], -1)) + grad): + x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, \ + casted_kerenl_1, casted_kerenl_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \ + updated_geglu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \ + x_contracting_dims, xt_batch_dims = ctx - gemm2_input_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1) + gemm2_x_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1) grad_amax = amax[gemm2_grad_idx, 0:1] grad_scale = scale[gemm2_grad_idx] grad_scale_inv = scale_inv[gemm2_grad_idx] - grad_cast, grad_cast_trans, grad_amax = cast_transpose(g, grad_amax, grad_scale, grad_scale_inv, - bwd_dtype) - gated_gelu_output_cast_trans = transpose(gated_gelu_output_cast, fwd_dtype) + casted_grad, casted_grad_t, updated_grad_amax = \ + cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype, + static_axis_boundary=-1, transpose_axis_boundary=-1) + + casted_geglu_out_t = transpose(casted_geglu_out, + static_axis_boundary=-1, + transpose_axis_boundary=-1) - gemm2_input_scale_inv = scale_inv[gemm2_input_idx] - wgrad_2 = gemm(grad_cast_trans, grad_scale_inv, bwd_dtype, True, - gated_gelu_output_cast_trans, gemm2_input_scale_inv, fwd_dtype, False, - jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_WGRAD) + # (hidden, batch...,) x (hidden, batch...) + gemm2_x_scale_inv = scale_inv[gemm2_x_idx] + wgrad_2 = fp8_dot_impl(casted_geglu_out_t, casted_grad_t, gemm2_x_scale_inv, grad_scale_inv, + grad.dtype, (xt_batch_dims, xt_batch_dims)) + + # (batch..., hidden_out) x (hidden_in, hidden_out) kernel_2_scale_inv = scale_inv[gemm2_kernel_idx] - dgrad_2 = gemm(kernel_2_cast, kernel_2_scale_inv, fwd_dtype, True, grad_cast, grad_scale_inv, - bwd_dtype, False, jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_DGRAD) + dgrad_2 = fp8_dot_impl(casted_grad, casted_kerenl_2, grad_scale_inv, kernel_2_scale_inv, + grad.dtype, (x_contracting_dims, (1,))) + + gemm1_x_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0) + + dgeglu_amax = amax[gemm1_grad_idx, 0:1] + dgeglu_scale = scale[gemm1_grad_idx] + dgeglu_scale_inv = scale_inv[gemm1_grad_idx] - gemm1_input_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0) + casted_dgeglu, casted_dgeglu_t, updated_dgeglu_amax = dgated_gelu_cast_transpose( + dgrad_2, + dot_1_output, + dgeglu_amax, + dgeglu_scale, + dgeglu_scale_inv, + bwd_dtype, + static_axis_boundary=-1) - dgrad_2_amax = amax[gemm1_grad_idx, 0:1] - dgrad_2_scale = scale[gemm1_grad_idx] - dgrad_2_scale_inv = scale_inv[gemm1_grad_idx] - dgelu, dgelu_trans, dgelu_amax = dgated_gelu_cast_transpose(dgrad_2, dense_1_output, - dgrad_2_amax, dgrad_2_scale, - dgrad_2_scale_inv, bwd_dtype) - ln_out_trans = transpose(ln_out, fwd_dtype) + ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1) - gemm1_input_scale_inv = scale_inv[gemm1_input_idx] - wgrad_1 = gemm(dgelu_trans, dgrad_2_scale_inv, bwd_dtype, True, - ln_out_trans, gemm1_input_scale_inv, fwd_dtype, False, - jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_WGRAD) + # (hidden, batch...) x (2, hidden, batch...) + xt_batch_dims_plus_act_dim = tuple(i + 1 for i in xt_batch_dims) + gemm1_x_scale_inv = scale_inv[gemm1_x_idx] + wgrad_1 = fp8_dot_impl(ln_out_t, casted_dgeglu_t, gemm1_x_scale_inv, dgeglu_scale_inv, + grad.dtype, (xt_batch_dims, xt_batch_dims_plus_act_dim)) + # (batch..., 2, hidden_out) x (hidden_in, 2, hidden_out) + x_contracting_dims_plus_act_dim = (min(x_contracting_dims),) + tuple( + i + 1 for i in x_contracting_dims) kernel_1_scale_inv = scale_inv[gemm1_kernel_idx] - dgrad_1 = gemm(kernel_1_cast, kernel_1_scale_inv, fwd_dtype, True, dgelu, dgrad_2_scale_inv, - bwd_dtype, False, jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_DGRAD) - if major_sharding_type in (MajorShardingType.TP, MajorShardingType.DPTP): - dgrad_1 = jax.lax.psum(dgrad_1, tp_axis_name) + dgrad_1 = fp8_dot_impl(casted_dgeglu, casted_kerenl_1, dgeglu_scale_inv, kernel_1_scale_inv, + grad.dtype, (x_contracting_dims_plus_act_dim, ( + 1, + 2, + ))) if layernorm_type == 'layernorm': - grad_input, grad_gamma, grad_beta = layernorm_bwd(dgrad_1, - mu, - rsigma, - inputs_, - gamma, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon) + dx, dgamma, dbeta = layernorm_bwd(dgrad_1, + x, + mu, + rsigma, + gamma, + zero_centered_gamma=zero_centered_gamma, + epsilon=epsilon) else: assert not zero_centered_gamma, "zero_centered_gamma is not supported " \ "if layernorm_type is 'rmsnorm'" - grad_input, grad_gamma = rmsnorm_bwd(dgrad_1, rsigma, inputs_, gamma, epsilon=epsilon) - grad_beta = None - - amax = amax.at[gemm1_input_idx, 0].set(ln_out_amax[0]) - amax = amax.at[gemm1_kernel_idx, 0].set(kernel_1_amax[0]) - amax = amax.at[gemm1_grad_idx, 0].set(dgelu_amax[0]) - amax = amax.at[gemm2_input_idx, 0].set(gated_gelu_amax[0]) - amax = amax.at[gemm2_kernel_idx, 0].set(kernel_2_amax[0]) - amax = amax.at[gemm2_grad_idx, 0].set(grad_amax[0]) - - if major_sharding_type in (MajorShardingType.DP, MajorShardingType.DPTP): - wgrad_1 = jax.lax.psum(wgrad_1, dp_axis_name) - wgrad_2 = jax.lax.psum(wgrad_2, dp_axis_name) - grad_gamma = jax.lax.psum(grad_gamma, dp_axis_name) - if grad_beta is not None: - grad_beta = jax.lax.psum(grad_beta, dp_axis_name) - amax = jax.lax.pmax(amax, dp_axis_name) - - if len(fsdp_axis_name) > 0: - wgrad_1 = jax.lax.psum(wgrad_1, fsdp_axis_name) - wgrad_2 = jax.lax.psum(wgrad_2, fsdp_axis_name) - grad_gamma = jax.lax.psum(grad_gamma, fsdp_axis_name) - if grad_beta is not None: - grad_beta = jax.lax.psum(grad_beta, fsdp_axis_name) - amax = jax.lax.pmax(amax, fsdp_axis_name) - - if major_sharding_type in (MajorShardingType.TP, MajorShardingType.DPTP): - amax = jax.lax.pmax(amax, tp_axis_name) - - grad_input = jnp.reshape(grad_input, input_shape) - wgrad_1 = jnp.reshape(wgrad_1, kernel_1_shape) - wgrad_2 = jnp.reshape(wgrad_2, kernel_2_shape) - return grad_input, grad_gamma, grad_beta, \ - wgrad_1, wgrad_2, \ - fp8_maxs, amax, scale, scale_inv - - -_fp8_mlp.defvjp(_fp8_mlp_fwd, _fp8_mlp_bwd) + dx, dgamma = rmsnorm_bwd(dgrad_1, x, rsigma, gamma, epsilon=epsilon) + dbeta = None + + amax = amax.at[gemm1_x_idx, 0].set(updated_x_amax[0]) + amax = amax.at[gemm1_kernel_idx, 0].set(updated_kernel_1_amax[0]) + amax = amax.at[gemm1_grad_idx, 0].set(updated_dgeglu_amax[0]) + amax = amax.at[gemm2_x_idx, 0].set(updated_geglu_amax[0]) + amax = amax.at[gemm2_kernel_idx, 0].set(updated_kernel_2_amax[0]) + amax = amax.at[gemm2_grad_idx, 0].set(updated_grad_amax[0]) + + scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale) + + return dx, dgamma, dbeta, wgrad_1, wgrad_2, \ + fp8_max, amax, scale, scale_inv + + +_layernrom_geglu_fp8_mlp.defvjp(_layernrom_geglu_fp8_mlp_fwd_rule, + _layernrom_geglu_fp8_mlp_bwd_rule) diff --git a/transformer_engine/jax/praxis/module.py b/transformer_engine/jax/praxis/module.py index 2d07ee4b75..a7997939bd 100644 --- a/transformer_engine/jax/praxis/module.py +++ b/transformer_engine/jax/praxis/module.py @@ -49,7 +49,7 @@ def create_layer(self, name, flax_module_cls): fp8_collection_map = { FP8Helper.FP8_COLLECTION_NAME: [ WeightHParamsCollection.SKIP_LP_REGULARIZATION, - WeightHParamsCollection.NON_TRAINABLE, + WeightHParamsCollection.OVERWRITE_WITH_GRADIENT, WeightHParamsCollection.DISALLOW_BFLOAT16_CONVERSION ] } @@ -92,8 +92,7 @@ def setup(self) -> None: "ln_bias", self.bias_init), bias_axes=self.bias_axes, dtype=self.dtype, - transpose_batch_sequence=self.transpose_batch_sequence, - sharding_type=self.sharding_type) + transpose_batch_sequence=self.transpose_batch_sequence) self.create_layer("layer_norm", ln_cls) @@ -115,8 +114,7 @@ def setup(self) -> None: fused_softmax_cls = partial(Softmax, scale_factor=self.scale_factor, - softmax_type=self.softmax_type, - sharding_type=self.sharding_type) + softmax_type=self.softmax_type) self.create_layer("fused_softmax", fused_softmax_cls) @@ -151,8 +149,7 @@ def setup(self) -> None: bias_axes=self.bias_axes, axis=self.axis, dtype=self.dtype, - transpose_batch_sequence=self.transpose_batch_sequence, - sharding_type=self.sharding_type) + transpose_batch_sequence=self.transpose_batch_sequence) self.create_layer("linear", dense_general_cls) @@ -208,8 +205,7 @@ def setup(self) -> None: axis=self.axis, dtype=self.dtype, transpose_batch_sequence=self.transpose_batch_sequence, - depth_scaling=self.depth_scaling, - sharding_type=self.sharding_type) + depth_scaling=self.depth_scaling) self.create_layer("ln_linear", ln_dense_general_cls) @@ -273,8 +269,7 @@ def setup(self) -> None: intermediate_hidden_dropout_dims=self.intermediate_hidden_dropout_dims, axis=self.axis, dtype=self.dtype, - transpose_batch_sequence=self.transpose_batch_sequence, - major_sharding_type=self.major_sharding_type) + transpose_batch_sequence=self.transpose_batch_sequence) self.create_layer("ln_mlp", ln_mlp_cls) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index a60e2f57ff..00ed5b2aac 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -8,17 +8,12 @@ from contextlib import contextmanager from dataclasses import dataclass from enum import Enum -from itertools import repeat -from typing import Union, Tuple, Dict, Callable, Sequence +from typing import Callable from jax.interpreters import pxla import jax import jax.numpy as jnp -from jax.experimental.maps import xmap from jax.sharding import PartitionSpec -jax.config.update('experimental_xmap_spmd_lowering', True) -jax.config.update('experimental_xmap_spmd_lowering_manual', True) - _PXLA_THREAD_RESOURCES = pxla.thread_resources @@ -29,6 +24,24 @@ def _get_mesh_info(resource: str): return mesh.shape[resource], resource +def get_all_mesh_axes(): + """ + Get all name of mesh axes + """ + mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh + return mesh.axis_names + + +def get_padded_spec(spec, ndim): + """ + Get padded spec for partitioning from arguments' information + """ + if spec is None: + return (None,) * ndim + assert len(spec) <= ndim + return spec + (None,) * (ndim - len(spec)) + + def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec): """ A wrapper function to jax.lax.with_sharding_constraint to @@ -40,8 +53,25 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec): return jax.lax.with_sharding_constraint(x, pspec) +def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str): + """ + A wrapper function to invoke lax.p* operations, like psum. + """ + if mesh_resource is not None: + _, resource = _get_mesh_info(mesh_resource) + return ops(x, resource) + return x + + +def num_of_devices(): + """ + Get total number of detected devices + """ + return len(jax.devices()) + + @dataclass -class ShardingResource: +class MeshResource: """ A data container to indicate which axis in Mesh for data parallelism and which for tensor parallelism. @@ -54,39 +84,73 @@ class ShardingResource: tp_resource : str, default = None The axis name in Mesh used to split the hidden dimensions along. If it is None, then tensor parallelism is disabled. + fsdp_resource : str, default = None + The axis name in Mesh used to split the batch and weights along. + If it is None, then full-sharded data parallelism is disabled. + pp_resource : str, default = None + The axis name in Mesh used to split model layers. along. + If it is None, then pipeline parallelism is disabled. """ dp_resource: str = None tp_resource: str = None fsdp_resource: str = None + pp_resource: str = None -_GLOBAL_SHARD_RESOURCE = ShardingResource() +_GLOBAL_MESH_RESOURCE = MeshResource() @contextmanager -def global_shard_guard(resource: ShardingResource): +def global_shard_guard(resource: MeshResource): """ - A context manager to switch the global ShardingResource + A context manager to switch the global MeshResource """ - global _GLOBAL_SHARD_RESOURCE - prev_gsr = _GLOBAL_SHARD_RESOURCE + global _GLOBAL_MESH_RESOURCE + prev_gmr = _GLOBAL_MESH_RESOURCE try: - _GLOBAL_SHARD_RESOURCE = resource + _GLOBAL_MESH_RESOURCE = resource yield finally: - _GLOBAL_SHARD_RESOURCE = prev_gsr + _GLOBAL_MESH_RESOURCE = prev_gmr + +def global_mesh_resource() -> MeshResource: + """ + A getter of the global MeshResource + """ + return _GLOBAL_MESH_RESOURCE -def global_shard_resource() -> ShardingResource: + +def all_reduce_sum_along_dp_fsdp(x: jnp.array): """ - A getter of the global ShardingResource + All-Reduce (Sum) along DP and FSDP mesh axes. """ - return _GLOBAL_SHARD_RESOURCE + x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().dp_resource) + return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource) + + +def all_reduce_max_along_all_axes_except_PP(x: jnp.array): + """ + All-Reduce (Max) along all mesh axes. + """ + all_axes = get_all_mesh_axes() + for axis in all_axes: + if axis != global_mesh_resource().pp_resource: + x = lax_paral_op(x, jax.lax.pmax, axis) + return x + + +# Deprecating Items --------------------------------------------------------------- +ShardingResource = MeshResource + +global_shard_resource = global_mesh_resource class MajorShardingType(Enum): r""" The major sharding type to indicate sharding pattern. + .. warning:: + MajorShardingType is deprecating in the near feature. Values ---------- @@ -108,6 +172,8 @@ class MajorShardingType(Enum): class ShardingType(Enum): """ The sharding type to indicate sharding pattern. + .. warning:: + ShardingType is deprecating in the near feature. Values ---------- @@ -130,1058 +196,3 @@ class ShardingType(Enum): TP_ROW = (MajorShardingType.TP, "tp_row") DP_TP_COL = (MajorShardingType.DPTP, "dp_tp_col") DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row") - - -def infer_major_sharding_type() -> MajorShardingType: - """ - Infer MajorShardingType from _GLOBAL_SHARD_RESOURCE - """ - gsr = global_shard_resource() - - resources = [gsr.dp_resource, gsr.tp_resource, gsr.fsdp_resource] - for idx, rs in enumerate(resources): - try: - size, _ = _get_mesh_info(rs) - if size <= 1: - resources[idx] = None - except AssertionError as _: - resources[idx] = None - - dp_resource = resources[0] - tp_resource = resources[1] - fsdp_resource = resources[2] - - def dp_enabled(): - return (fsdp_resource is not None) or (dp_resource is not None) - - if dp_enabled() and tp_resource is not None: - return MajorShardingType.DPTP - - if dp_enabled(): - return MajorShardingType.DP - - if tp_resource is not None: - return MajorShardingType.TP - - return MajorShardingType.SINGLE - - -def infer_sharding_type(major_st: MajorShardingType = None) -> Tuple[ShardingType, ShardingType]: - """ - Infer ShardingType via given MajorShardingType - """ - if major_st is None: - major_st = infer_major_sharding_type() - - if major_st is MajorShardingType.DP: - return ShardingType.DP, ShardingType.DP - if major_st is MajorShardingType.TP: - return ShardingType.TP_COL, ShardingType.TP_ROW - if major_st is MajorShardingType.DPTP: - return ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW - return ShardingType.SINGLE, ShardingType.SINGLE - - -def is_dp_enabled(mst: MajorShardingType) -> bool: - """ - is_dp_enabled - """ - return mst in (MajorShardingType.DP, MajorShardingType.DPTP) - - -def is_tp_enabled(mst: MajorShardingType) -> bool: - """ - is_tp_enabled - """ - return mst in (MajorShardingType.TP, MajorShardingType.DPTP) - - -def merge_axis_resources(ars: Tuple[Dict]) -> Dict: - """ - merge_axis_resources - """ - output = {} - for ar in ars: - for key in ar: - if key not in output: - output[key] = ar[key] - else: - assert output[key] == ar[key] - return output - - -@dataclass -class ShardingMeta: - """ShardingMeta""" - in_axes: Union[Dict, Tuple[str, ...], Tuple[Union[Dict, Tuple], ...]] - out_axes: Union[Dict, Tuple[str, ...], Tuple[Union[Dict, Tuple], ...]] - axis_resources: Dict - input_shapes: Tuple[Tuple[int, ...]] - output_shapes: Tuple[Tuple[int, ...]] - - -class ShardingMetaGenerator: - """ - ShardingMetaGenerator - """ - - def __init__(self): - - def get_single_sharding_meta(*argv, **kwargs) -> ShardingMeta: # pylint: disable=unused-argument - return None - - self.sharding_type_meta_map = { - ShardingType.SINGLE: get_single_sharding_meta, - ShardingType.DP: self.get_dp_sharding_meta, - ShardingType.TP_COL: self.get_tp_col_sharding_meta, - ShardingType.TP_ROW: self.get_tp_row_sharding_meta, - ShardingType.DP_TP_COL: self.get_dp_tp_col_sharding_meta, - ShardingType.DP_TP_ROW: self.get_dp_tp_row_sharding_meta - } - - def get_sharding_meta(self, stype: ShardingType, *argv, **kwargs) -> ShardingMeta: - """get_sharding_meta""" - return self.sharding_type_meta_map[stype](*argv, **kwargs) - - def get_dp_sharding_meta(self, *argv, **kwargs) -> ShardingMeta: - """get_dp_sharding_meta""" - raise NotImplementedError - - def get_tp_col_sharding_meta(self, *argv, **kwargs) -> ShardingMeta: - """get_tp_col_sharding_meta""" - raise NotImplementedError - - def get_tp_row_sharding_meta(self, *argv, **kwargs) -> ShardingMeta: - """get_tp_row_sharding_meta""" - raise NotImplementedError - - def get_dp_tp_col_sharding_meta(self, *argv, **kwargs) -> ShardingMeta: - """get_dp_tp_col_sharding_meta""" - raise NotImplementedError - - def get_dp_tp_row_sharding_meta(self, *argv, **kwargs) -> ShardingMeta: - """get_dp_tp_row_sharding_meta""" - raise NotImplementedError - - -class FP8MetaShardingMetaGenerator(ShardingMetaGenerator): - """ - FP8MetaShardingMetaGenerator - """ - - def get_dp_sharding_meta(self, - num_of_meta: int, - dp_axis_name: str = 'data', - tp_axis_name: str = 'model') -> ShardingMeta: - return FP8MetaShardingMetaGenerator._generate_sharding_meta(MajorShardingType.DP, - num_of_meta, dp_axis_name, - tp_axis_name) - - def get_tp_col_sharding_meta(self, - num_of_meta: int, - dp_axis_name: str = 'data', - tp_axis_name: str = 'model') -> ShardingMeta: - return FP8MetaShardingMetaGenerator._generate_sharding_meta(MajorShardingType.TP, - num_of_meta, dp_axis_name, - tp_axis_name) - - def get_tp_row_sharding_meta(self, - num_of_meta: int, - dp_axis_name: str = 'data', - tp_axis_name: str = 'model') -> ShardingMeta: - return FP8MetaShardingMetaGenerator._generate_sharding_meta(MajorShardingType.TP, - num_of_meta, dp_axis_name, - tp_axis_name) - - def get_dp_tp_col_sharding_meta(self, - num_of_meta: int, - dp_axis_name: str = 'data', - tp_axis_name: str = 'model') -> ShardingMeta: - return FP8MetaShardingMetaGenerator._generate_sharding_meta(MajorShardingType.DPTP, - num_of_meta, dp_axis_name, - tp_axis_name) - - def get_dp_tp_row_sharding_meta(self, - num_of_meta: int, - dp_axis_name: str = 'data', - tp_axis_name: str = 'model') -> ShardingMeta: - return FP8MetaShardingMetaGenerator._generate_sharding_meta(MajorShardingType.DPTP, - num_of_meta, dp_axis_name, - tp_axis_name) - - @staticmethod - def _stack_axes_meta(num_of_meta: int, mapping: Dict) -> Tuple: - return tuple(mapping for _ in range(num_of_meta)) - - @staticmethod - def _generate_sharding_meta(type_: MajorShardingType, - num_of_meta: int, - dp_axis_name: str = 'data', - tp_axis_name: str = 'model') -> ShardingMeta: - - axis_resource = {} - - if is_dp_enabled(type_): - axis_resource[dp_axis_name] = global_shard_resource().dp_resource - - if is_tp_enabled(type_): - axis_resource[tp_axis_name] = global_shard_resource().tp_resource - - return ShardingMeta(FP8MetaShardingMetaGenerator._stack_axes_meta(num_of_meta, {}), - FP8MetaShardingMetaGenerator._stack_axes_meta(num_of_meta, {}), - axis_resource, (), ()) - - -class FusedAttnShardingMetaGenerator(ShardingMetaGenerator): - """ - FusedAttnShardingMetaGenerator - """ - - def get_dp_sharding_meta( - self, - input_shapes: Tuple[Tuple[int, ...]], - output_shapes: Tuple[Tuple[int, ...]], - dp_dims: Tuple[Tuple[int, ...]], - tp_dims: Tuple[Tuple[int, ...]], # pylint: disable=unused-argument - dp_axis_name: str = 'data', - tp_axis_name: str = 'model' # pylint: disable=unused-argument - ) -> ShardingMeta: - """get_dp_sharding_meta""" - dummy_tp_dims = [repeat(None), repeat(None)] - return FusedAttnShardingMetaGenerator._get_dptp_sharding_meta(input_shapes, output_shapes, - dp_dims, dummy_tp_dims, - dp_axis_name, None) - - def get_tp_col_sharding_meta(self, *argv, **kwargs) -> ShardingMeta: - """get_tp_col_sharding_meta""" - return FusedAttnShardingMetaGenerator._get_tp_sharding_meta(*argv, **kwargs) - - def get_tp_row_sharding_meta(self, *argv, **kwargs) -> ShardingMeta: - """get_tp_row_sharding_meta""" - return FusedAttnShardingMetaGenerator._get_tp_sharding_meta(*argv, **kwargs) - - def get_dp_tp_col_sharding_meta(self, *argv, **kwargs) -> ShardingMeta: - """get_dp_tp_col_sharding_meta""" - return FusedAttnShardingMetaGenerator._get_dptp_sharding_meta(*argv, **kwargs) - - def get_dp_tp_row_sharding_meta(self, *argv, **kwargs) -> ShardingMeta: - """get_dp_tp_row_sharding_meta""" - return FusedAttnShardingMetaGenerator._get_dptp_sharding_meta(*argv, **kwargs) - - @staticmethod - def _get_tp_sharding_meta( - input_shapes: Tuple[Tuple[int, ...]], - output_shapes: Tuple[Tuple[int, ...]], - dp_dims: Tuple[Tuple[int, ...]], # pylint: disable=unused-argument - tp_dims: Tuple[Tuple[int, ...]], - dp_axis_name: str = 'data', # pylint: disable=unused-argument - tp_axis_name: str = 'model') -> ShardingMeta: - """get_tp_sharding_meta""" - dummy_dp_dims = [repeat(None), repeat(None)] - return FusedAttnShardingMetaGenerator._get_dptp_sharding_meta(input_shapes, output_shapes, - dummy_dp_dims, tp_dims, None, - tp_axis_name) - - @staticmethod - def _get_dptp_sharding_meta(input_shapes: Tuple[Tuple[int, ...]], - output_shapes: Tuple[Tuple[int, ...]], - dp_dims: Tuple[Tuple[int, ...]], - tp_dims: Tuple[Tuple[int, ...]], - dp_axis_name: str = 'data', - tp_axis_name: str = 'model') -> ShardingMeta: - """get_dp_tp_sharding_meta""" - - dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource) - tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource) - - input_dp_dims, output_dp_dims = dp_dims - input_tp_dims, output_tp_dims = tp_dims - - input_new_shapes = [] - in_axes = [] - - for input_shape, dp_dim, tp_dim in zip(input_shapes, input_dp_dims, input_tp_dims): - in_axis = {} - if dp_dim is not None and input_shape is not None: - in_axis[dp_dim] = dp_axis_name - assert input_shape[dp_dim] % dp_size == 0, \ - f"The dimension of batch in input_shape should be a multiple of " \ - f"data parallelism size, but got {input_shape[dp_dim]=} and {dp_size=}." - input_shape = (*input_shape[:dp_dim], dp_size, input_shape[dp_dim] // dp_size, - *input_shape[dp_dim + 1:]) - - # the input shape has been expanded for dp_dim, tp_dim should +1 if tp_dim >= dp_dim - if tp_dim is not None and tp_dim >= dp_dim: - tp_dim = tp_dim + 1 - - if tp_dim is not None and input_shape is not None: - in_axis[tp_dim] = tp_axis_name - assert input_shape[tp_dim] % tp_size == 0, \ - f"The dimension of tensor parallel in input_shape should be a multiple of " \ - f"tensor parallelism size, but got {input_shape[tp_dim]=} and {tp_size=}." - input_shape = (*input_shape[:tp_dim], tp_size, input_shape[tp_dim] // tp_size, - *input_shape[tp_dim + 1:]) - - in_axes.append(in_axis) - input_new_shapes.append(input_shape) - - output_new_shapes = output_shapes - out_axes = [] - for dp_dim, tp_dim in zip(output_dp_dims, output_tp_dims): - out_axis = {} - if dp_dim is not None: - out_axis[dp_dim] = dp_axis_name - if tp_dim is not None and tp_dim >= dp_dim: - tp_dim = tp_dim + 1 - if tp_dim is not None: - out_axis[tp_dim] = tp_axis_name - out_axes.append(out_axis) - - assert len(out_axes) == 1, "Only allow single output at this moment." - assert len(output_new_shapes) == 1, "Only allow single output at this moment." - out_axes = out_axes[0] - output_new_shapes = output_new_shapes[0] - - axis_resources = {} - if dp_axis_name is not None: - axis_resources[dp_axis_name] = dp_mesh_axis - if tp_axis_name is not None: - axis_resources[tp_axis_name] = tp_mesh_axis - - return ShardingMeta(tuple(in_axes), out_axes, axis_resources, input_new_shapes, - output_new_shapes) - - -class DotShardingMetaGenerator(ShardingMetaGenerator): - """ - DotShardingMetaGenerator - """ - - def get_dp_sharding_meta( - self, - a_shape: Tuple, - b_shape: Tuple, - batch_dim_of_a: int, - model_dim_of_a: int, # pylint: disable=unused-argument - model_dim_of_b: int, # pylint: disable=unused-argument - contracting_dims: Tuple[Sequence[int], Sequence[int]], - dp_axis_name: str = 'data', - tp_axis_name: str = 'model' # pylint: disable=unused-argument - ) -> ShardingMeta: - DotShardingMetaGenerator._is_supported(a_shape, b_shape, batch_dim_of_a, None, - contracting_dims) - - out_shape = DotShardingMetaGenerator._infer_output_shape(a_shape, b_shape, contracting_dims) - out_batch_dim = batch_dim_of_a - - dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource) - assert a_shape[batch_dim_of_a] % dp_size == 0, \ - f"The dimension of batch in a_shape should be a multiple of data parallelism size," \ - f" but got {a_shape[batch_dim_of_a]=} and {dp_size=}." - a_new_shape = (*a_shape[:batch_dim_of_a], dp_size, -1, *a_shape[batch_dim_of_a + 1:]) - return ShardingMeta(({ - batch_dim_of_a: dp_axis_name - }, {}), ({ - out_batch_dim: dp_axis_name - }), {dp_axis_name: dp_mesh_axis}, [a_new_shape, b_shape], [out_shape]) - - def get_tp_col_sharding_meta( - self, - a_shape: Tuple, - b_shape: Tuple, - batch_dim_of_a: int, - model_dim_of_a: int, # pylint: disable=unused-argument - model_dim_of_b: int, - contracting_dims: Tuple[Sequence[int], Sequence[int]], - dp_axis_name: str = 'data', # pylint: disable=unused-argument - tp_axis_name: str = 'model') -> ShardingMeta: - DotShardingMetaGenerator._is_supported(a_shape, b_shape, batch_dim_of_a, None, - contracting_dims) - - out_shape = DotShardingMetaGenerator._infer_output_shape(a_shape, b_shape, contracting_dims) - - out_model_idx = len(out_shape) - (len(b_shape) - model_dim_of_b) - - tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource) - assert b_shape[model_dim_of_b] % tp_size == 0, \ - f"The dimension of model parallelism in b_shape should be a multiple of " \ - f"tensor parallelism size,but got {b_shape[model_dim_of_b]=} and {tp_size=}." - b_new_shape = (*b_shape[:model_dim_of_b], tp_size, b_shape[model_dim_of_b] // tp_size, - *b_shape[model_dim_of_b + 1:]) - return ShardingMeta(({}, { - model_dim_of_b: tp_axis_name - }), ({ - out_model_idx: tp_axis_name - }), {tp_axis_name: tp_mesh_axis}, [a_shape, b_new_shape], [out_shape]) - - def get_tp_row_sharding_meta( - self, - a_shape: Tuple, - b_shape: Tuple, - batch_dim_of_a: int, - model_dim_of_a: int, - model_dim_of_b: int, - contracting_dims: Tuple[Sequence[int], Sequence[int]], - dp_axis_name: str = 'data', # pylint: disable=unused-argument - tp_axis_name: str = 'model') -> ShardingMeta: - DotShardingMetaGenerator._is_supported(a_shape, b_shape, batch_dim_of_a, model_dim_of_a, - contracting_dims) - - out_shape = DotShardingMetaGenerator._infer_output_shape(a_shape, b_shape, contracting_dims) - - tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource) - assert a_shape[model_dim_of_a] % tp_size == 0, \ - f"The dimension of model parallelism in a_shape should be a multiple of " \ - f"tensor parallelism size,but got {a_shape[model_dim_of_a]=} and {tp_size=}." - assert b_shape[model_dim_of_b] % tp_size == 0, \ - f"The dimension of model parallelism in b_shape should be a multiple of " \ - f"tensor parallelism size,but got {b_shape[model_dim_of_b]=} and {tp_size=}." - a_new_shape = (*a_shape[:model_dim_of_a], tp_size, a_shape[model_dim_of_a] // tp_size, - *a_shape[model_dim_of_a + 1:]) - b_new_shape = (*b_shape[:model_dim_of_b], tp_size, b_shape[model_dim_of_b] // tp_size, - *b_shape[model_dim_of_b + 1:]) - return ShardingMeta(({ - model_dim_of_a: tp_axis_name - }, { - model_dim_of_b: tp_axis_name - }), ({}), {tp_axis_name: tp_mesh_axis}, [a_new_shape, b_new_shape], [out_shape]) - - def get_dp_tp_col_sharding_meta( - self, - a_shape: Tuple, - b_shape: Tuple, - batch_dim_of_a: int, - model_dim_of_a: int, # pylint: disable=unused-argument - model_dim_of_b: int, - contracting_dims: Tuple[Sequence[int], Sequence[int]], - dp_axis_name: str = 'data', - tp_axis_name: str = 'model') -> ShardingMeta: - DotShardingMetaGenerator._is_supported(a_shape, b_shape, batch_dim_of_a, None, - contracting_dims) - - out_shape = DotShardingMetaGenerator._infer_output_shape(a_shape, b_shape, contracting_dims) - - out_model_idx = len(out_shape) + 1 - (len(b_shape) - model_dim_of_b) - - dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource) - tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource) - assert a_shape[batch_dim_of_a] % dp_size == 0, \ - f"The dimension of batch in a_shape should be a multiple of data parallelism size," \ - f" but got {a_shape[batch_dim_of_a]=} and {dp_size=}." - assert b_shape[model_dim_of_b] % tp_size == 0, \ - f"The dimension of model parallelism in b_shape should be a multiple of " \ - f"tensor parallelism size,but got {b_shape[model_dim_of_b]=} and {tp_size=}." - a_new_shape = (*a_shape[:batch_dim_of_a], dp_size, a_shape[batch_dim_of_a] // dp_size, - *a_shape[batch_dim_of_a + 1:]) - b_new_shape = (*b_shape[:model_dim_of_b], tp_size, b_shape[model_dim_of_b] // tp_size, - *b_shape[model_dim_of_b + 1:]) - return ShardingMeta(({ - batch_dim_of_a: dp_axis_name - }, { - model_dim_of_b: tp_axis_name - }), ({ - batch_dim_of_a: dp_axis_name, - out_model_idx: tp_axis_name - }), { - dp_axis_name: dp_mesh_axis, - tp_axis_name: tp_mesh_axis - }, [a_new_shape, b_new_shape], [out_shape]) - - def get_dp_tp_row_sharding_meta(self, - a_shape: Tuple, - b_shape: Tuple, - batch_dim_of_a: int, - model_dim_of_a: int, - model_dim_of_b: int, - contracting_dims: Tuple[Sequence[int], Sequence[int]], - dp_axis_name: str = 'data', - tp_axis_name: str = 'model') -> ShardingMeta: - DotShardingMetaGenerator._is_supported(a_shape, b_shape, batch_dim_of_a, model_dim_of_a, - contracting_dims) - - out_shape = DotShardingMetaGenerator._infer_output_shape(a_shape, b_shape, contracting_dims) - - dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource) - tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource) - assert a_shape[batch_dim_of_a] % dp_size == 0, \ - f"The dimension of batch in a_shape should be a multiple of data parallelism size," \ - f" but got {a_shape[batch_dim_of_a]=} and {dp_size=}." - assert a_shape[model_dim_of_a] % tp_size == 0, \ - f"The dimension of model parallelism in a_shape should be a multiple of " \ - f"tensor parallelism size,but got {a_shape[model_dim_of_a]=} and {tp_size=}." - assert b_shape[model_dim_of_b] % tp_size == 0, \ - f"The dimension of model parallelism in b_shape should be a multiple of " \ - f"tensor parallelism size,but {b_shape[model_dim_of_b]=} and {tp_size=}." - a_new_shape = (*a_shape[:batch_dim_of_a], dp_size, a_shape[batch_dim_of_a] // dp_size, - *a_shape[batch_dim_of_a + 1:model_dim_of_a], tp_size, - a_shape[model_dim_of_a] // tp_size, *a_shape[model_dim_of_a + 1:]) - b_new_shape = (*b_shape[:model_dim_of_b], tp_size, b_shape[model_dim_of_b] // tp_size, - *b_shape[model_dim_of_b + 1:]) - return ShardingMeta( - ( - { - batch_dim_of_a: - dp_axis_name, - # "model_dim_of_a+1" is the index to tp_size in a_new_shape - model_dim_of_a + 1: - tp_axis_name - }, - { - model_dim_of_b: tp_axis_name - }), - ({ - batch_dim_of_a: dp_axis_name - }), - { - dp_axis_name: dp_mesh_axis, - tp_axis_name: tp_mesh_axis - }, - [a_new_shape, b_new_shape], - [out_shape]) - - @staticmethod - def _is_supported( - a_shape: Tuple, # pylint: disable=unused-argument - b_shape: Tuple, # pylint: disable=unused-argument - batch_dim_of_a: int, - model_dim_of_a: int, - contracting_dims: Tuple[Sequence[int], Sequence[int]], - ): - assert batch_dim_of_a not in contracting_dims[0], \ - "batch_dim_of_a should be one of contracting_dims[0]" - assert batch_dim_of_a >= 0, \ - "Only support non-negative value of batch_dim_of_a." - if model_dim_of_a is not None: - assert model_dim_of_a >= 0, \ - "Only support non-negative value of model_dim_of_a" - assert model_dim_of_a > batch_dim_of_a, \ - "Only support the case that model_dim_of_a > batch_dim_of_a." - - @staticmethod - def _infer_output_shape( - a_shape: Tuple, - b_shape: Tuple, - contracting_dims: Tuple[Sequence[int], Sequence[int]], - ): - lhs_contracting_dims, rhs_contracting_dims = contracting_dims - return (*a_shape[:min(lhs_contracting_dims)], *b_shape[max(rhs_contracting_dims) + 1:]) - - -class ElementwiseShardingMetaGenerator(ShardingMetaGenerator): - """ - ElementwiseShardingMetaGenerator - """ - - def get_dp_sharding_meta( - self, - input_shape: Tuple, - other_shape: Tuple, - batch_dim: int, - dp_axis_name: str = 'data', - tp_axis_name: str = 'model' # pylint: disable=unused-argument - ) -> ShardingMeta: - """get_dp_sharding_meta""" - ElementwiseShardingMetaGenerator._is_supported(input_shape, other_shape, batch_dim) - - dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource) - - assert input_shape[batch_dim] % dp_size == 0, \ - f"The dimension of batch in input_shape should be a multiple of data parallelism " \ - f"size, but got {input_shape[batch_dim]=} and {dp_size=}." - input_new_shape = (*input_shape[:batch_dim], dp_size, -1, *input_shape[batch_dim + 1:]) - in_axes = [{batch_dim: dp_axis_name}] - input_new_shapes = [input_new_shape] - if other_shape is not None: - input_new_shapes.append(other_shape) - in_axes.append({}) - - return ShardingMeta(tuple(in_axes), ({ - batch_dim: dp_axis_name - }), {dp_axis_name: dp_mesh_axis}, input_new_shapes, [input_shape]) - - def get_tp_col_sharding_meta( - self, - input_shape: Tuple, - other_shape: Tuple, - batch_dim: int, # pylint: disable=unused-argument - dp_axis_name: str = 'data', # pylint: disable=unused-argument - tp_axis_name: str = 'model' # pylint: disable=unused-argument - ) -> ShardingMeta: - """get_tp_col_sharding_meta""" - ElementwiseShardingMetaGenerator._is_supported(input_shape, other_shape, 0) - in_axes = [{}] - input_new_shapes = [input_shape] - if other_shape is not None: - in_axes.append({}) - input_new_shapes.append(other_shape) - - return ShardingMeta(tuple(in_axes), ({}), {}, input_new_shapes, [input_shape]) - - def get_tp_row_sharding_meta( - self, - input_shape: Tuple, - other_shape: Tuple, - batch_dim: int, # pylint: disable=unused-argument - dp_axis_name: str = 'data', # pylint: disable=unused-argument - tp_axis_name: str = 'model') -> ShardingMeta: - """get_tp_row_sharding_meta""" - ElementwiseShardingMetaGenerator._is_supported(input_shape, other_shape, 0) - - tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource) - - assert input_shape[-1] % tp_size == 0, \ - f"The last dimension in input_shape should be a multiple of tensor parallelism size," \ - f" but got {input_shape[-1]=} and {tp_size=}." - input_new_shape = (*input_shape[:-1], tp_size, -1) - - in_axes = [{ - # "len(a_new_shape)-2" is the index to tp_size in a_new_shape - len(input_new_shape) - 2: - tp_axis_name - }] - input_new_shapes = [input_new_shape] - - if other_shape is not None: - assert other_shape[0] % tp_size == 0, \ - f"The first dimension in other_shape should be a multiple of tensor parallelism size," \ - f" but got {other_shape[0]=} and {tp_size=}." - other_new_shape = (tp_size, -1) - in_axes.append({0: tp_axis_name}) - input_new_shapes.append(other_new_shape) - - return ShardingMeta(tuple(in_axes), ({ - len(input_new_shape) - 2: tp_axis_name - }), {tp_axis_name: tp_mesh_axis}, input_new_shapes, [input_shape]) - - def get_dp_tp_col_sharding_meta(self, - input_shape: Tuple, - other_shape: Tuple, - batch_dim: int, - dp_axis_name: str = 'data', - tp_axis_name: str = 'model') -> ShardingMeta: - """get_dp_tp_col_sharding_meta""" - return self.get_dp_sharding_meta(input_shape, other_shape, batch_dim, dp_axis_name, - tp_axis_name) - - def get_dp_tp_row_sharding_meta(self, - input_shape: Tuple, - other_shape: Tuple, - batch_dim: int, - dp_axis_name: str = 'data', - tp_axis_name: str = 'model') -> ShardingMeta: - """get_dp_tp_row_sharding_meta""" - ElementwiseShardingMetaGenerator._is_supported(input_shape, other_shape, batch_dim) - - dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource) - tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource) - - assert input_shape[batch_dim] % dp_size == 0, \ - f"The dimension of batch in input_shape should be a multiple of data parallelism" \ - f"size, but got {input_shape[batch_dim]=} and {dp_size=}." - assert input_shape[-1] % tp_size == 0, \ - f"The last dimension in input_shape should be a multiple of tensor parallelism size," \ - f" but got {input_shape[-1]=} and {tp_size=}." - input_new_shape = (*input_shape[:batch_dim], dp_size, -1, *input_shape[batch_dim + 1:-1], - tp_size, input_shape[-1] // tp_size) - - in_axes = [{ - batch_dim: - dp_axis_name, - # "len(a_new_shape)-2" is the index to tp_size in a_new_shape - len(input_new_shape) - 2: - tp_axis_name - }] - input_new_shapes = [input_new_shape] - - other_new_shape = other_shape - if other_shape is not None: - assert other_shape[0] % tp_size == 0, \ - f"The first dimension in other_shape should be a multiple of tensor parallelism size," \ - f" but got {other_shape[0]=} and {tp_size=}." - other_new_shape = (tp_size, -1) - in_axes.append({0: tp_axis_name}) - input_new_shapes.append(other_new_shape) - - return ShardingMeta(tuple(in_axes), ({ - batch_dim: dp_axis_name, - len(input_new_shape) - 2: tp_axis_name - }), { - dp_axis_name: dp_mesh_axis, - tp_axis_name: tp_mesh_axis - }, input_new_shapes, [input_shape]) - - @staticmethod - def _is_supported(input_shape: Tuple, other_shape: Tuple, batch_dim: int): - if other_shape is not None: - assert len(other_shape) == 1, "Only support 1 dimension of other_shapes currently." - assert input_shape[-1] == other_shape[0], \ - f"input_shape[-1] should equal to oshape[0], " \ - f"but got {input_shape[-1]} and {other_shape[0]}." - - assert batch_dim < len(input_shape)-1, \ - "batch_dim cannot be the latest dim" - - -class SoftmaxShardingMetaGenerator(ShardingMetaGenerator): - """ - SoftmaxShardingMetaGenerator - """ - - def get_dp_sharding_meta( - self, - input_shape: Tuple, - dp_dim: int = 0, - tp_dim: int = 1, - dp_axis_name: str = 'data', - tp_axis_name: str = 'model' # pylint: disable=unused-argument - ) -> ShardingMeta: - """get_dp_sharding_meta""" - SoftmaxShardingMetaGenerator._is_supported(input_shape, dp_dim, tp_dim) - - dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource) - - assert input_shape[dp_dim] % dp_size == 0, \ - f"The dimension of batch in input_shape should be a multiple of data parallelism " \ - f"size, but got {input_shape[dp_dim]=} and {dp_size=}." - input_new_shape = (*input_shape[:dp_dim], dp_size, -1, *input_shape[dp_dim + 1:]) - in_axes = [{dp_dim: dp_axis_name}] - input_new_shapes = [input_new_shape] - - out_axes = in_axes[0] - - return ShardingMeta(tuple(in_axes), out_axes, {dp_axis_name: dp_mesh_axis}, - input_new_shapes, [input_shape]) - - def get_tp_col_sharding_meta(self, - input_shape: Tuple, - dp_dim: int = 0, - tp_dim: int = 1, - dp_axis_name: str = 'data', - tp_axis_name: str = 'model') -> ShardingMeta: - """get_tp_col_sharding_meta""" - return SoftmaxShardingMetaGenerator._get_tp_sharding_meta(input_shape, dp_dim, tp_dim, - dp_axis_name, tp_axis_name) - - def get_tp_row_sharding_meta(self, - input_shape: Tuple, - dp_dim: int = 0, - tp_dim: int = 1, - dp_axis_name: str = 'data', - tp_axis_name: str = 'model') -> ShardingMeta: - """get_tp_row_sharding_meta""" - return SoftmaxShardingMetaGenerator._get_tp_sharding_meta(input_shape, dp_dim, tp_dim, - dp_axis_name, tp_axis_name) - - def get_dp_tp_col_sharding_meta(self, - input_shape: Tuple, - dp_dim: int = 0, - tp_dim: int = 1, - dp_axis_name: str = 'data', - tp_axis_name: str = 'model') -> ShardingMeta: - """get_dp_tp_col_sharding_meta""" - return SoftmaxShardingMetaGenerator._get_dptp_sharding_meta(input_shape, dp_dim, tp_dim, - dp_axis_name, tp_axis_name) - - def get_dp_tp_row_sharding_meta(self, - input_shape: Tuple, - dp_dim: int = 0, - tp_dim: int = 1, - dp_axis_name: str = 'data', - tp_axis_name: str = 'model') -> ShardingMeta: - """get_dp_tp_row_sharding_meta""" - return SoftmaxShardingMetaGenerator._get_dptp_sharding_meta(input_shape, dp_dim, tp_dim, - dp_axis_name, tp_axis_name) - - @staticmethod - def _is_supported(input_shape: Tuple, dp_dim: int, tp_dim: int): - assert len(input_shape) == 4 - assert dp_dim == 0 - assert tp_dim == 1 - - @staticmethod - def _get_tp_sharding_meta( - input_shape: Tuple, - dp_dim: int = 0, - tp_dim: int = 1, - dp_axis_name: str = 'data', # pylint: disable=unused-argument - tp_axis_name: str = 'model' # pylint: disable=unused-argument - ) -> ShardingMeta: - """get_tp_sharding_meta""" - SoftmaxShardingMetaGenerator._is_supported(input_shape, dp_dim, tp_dim) - - tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource) - - assert input_shape[tp_dim] % tp_size == 0, \ - f"The dimension of tensor parallel in input_shape should be a multiple of data " \ - f"parallelism size, but got {input_shape[tp_dim]=} and {tp_size=}." - input_new_shape = (*input_shape[:tp_dim], tp_size, -1, *input_shape[tp_dim + 1:]) - in_axes = [{tp_dim: tp_axis_name}] - input_new_shapes = [input_new_shape] - - out_axes = in_axes[0] - - return ShardingMeta(tuple(in_axes), out_axes, {tp_axis_name: tp_mesh_axis}, - input_new_shapes, [input_shape]) - - @staticmethod - def _get_dptp_sharding_meta(input_shape: Tuple, - dp_dim: int = 0, - tp_dim: int = 1, - dp_axis_name: str = 'data', - tp_axis_name: str = 'model') -> ShardingMeta: - """get_dp_tp_sharding_meta""" - SoftmaxShardingMetaGenerator._is_supported(input_shape, dp_dim, tp_dim) - - dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource) - tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource) - - assert input_shape[dp_dim] % dp_size == 0, \ - f"The dimension of batch in input_shape should be a multiple of data parallelism " \ - f"size, but got {input_shape[dp_dim]=} and {dp_size=}." - assert input_shape[tp_dim] % tp_size == 0, \ - f"The dimension of tensor parallel in input_shape should be a multiple of data " \ - f"parallelism size, but got {input_shape[tp_dim]=} and {tp_size=}." - - input_new_shape = (*input_shape[:dp_dim], dp_size, input_shape[dp_dim] // dp_size, - *input_shape[dp_dim + 1:tp_dim], tp_size, input_shape[tp_dim] // tp_size, - *input_shape[tp_dim + 1:]) - - in_axes = [{dp_dim: dp_axis_name, tp_dim + 1: tp_axis_name}] - input_new_shapes = [input_new_shape] - - out_axes = in_axes[0] - - return ShardingMeta(tuple(in_axes), out_axes, { - dp_axis_name: dp_mesh_axis, - tp_axis_name: tp_mesh_axis - }, input_new_shapes, [input_shape]) - - -def get_fp8_meta_sharding_meta(stype: ShardingType, - num_of_meta: int, - dp_axis_name: str = 'data', - tp_axis_name: str = 'model') -> ShardingMeta: - """ - get_fp8_meta_sharding_meta - """ - return FP8MetaShardingMetaGenerator().get_sharding_meta(stype, num_of_meta, dp_axis_name, - tp_axis_name) - - -def get_dot_sharding_meta(stype: ShardingType, - a_shape: Tuple, - b_shape: Tuple, - batch_dim_of_a: int, - model_dim_of_a: int, - model_dim_of_b: int, - contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), - dp_axis_name: str = 'data', - tp_axis_name: str = 'model') -> ShardingMeta: - """ - get_dot_sharding_meta - """ - if stype in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW): - assert model_dim_of_b <= max(contracting_dims[1]), \ - f"The dimension of model parallelism in b_shape should be smaller than the max of" \ - f" contracting_dims[1], but got {model_dim_of_b=} and {contracting_dims[1]=}." - if stype in (ShardingType.TP_COL, ShardingType.DP_TP_COL): - assert model_dim_of_b > max(contracting_dims[1]), \ - f"The dimension of model parallelism in b_shape should be larger than the max of" \ - f" contracting_dims[1], but got {model_dim_of_b=} and {contracting_dims[1]=}." - return DotShardingMetaGenerator().get_sharding_meta(stype, a_shape, b_shape, batch_dim_of_a, - model_dim_of_a, model_dim_of_b, - contracting_dims, dp_axis_name, - tp_axis_name) - - -def get_elementwise_sharding_meta(stype: ShardingType, - input_shape: Tuple, - other_shape: Tuple, - batch_dim: int, - dp_axis_name: str = 'data', - tp_axis_name: str = 'model') -> ShardingMeta: - """ - get_elementwise_sharding_meta - """ - return ElementwiseShardingMetaGenerator().get_sharding_meta(stype, input_shape, other_shape, - batch_dim, dp_axis_name, - tp_axis_name) - - -def get_softmax_sharding_meta(stype: ShardingType, - input_shape: Tuple, - dp_dim: int = 0, - tp_dim: int = 1, - dp_axis_name: str = 'data', - tp_axis_name: str = 'model') -> ShardingMeta: - """ - get_softmax_sharding_meta - """ - return SoftmaxShardingMetaGenerator().get_sharding_meta(stype, input_shape, dp_dim, tp_dim, - dp_axis_name, tp_axis_name) - - -def get_fused_attn_sharding_meta(stype: ShardingType, - input_shapes: Tuple[Tuple[int, ...]], - output_shapes: Tuple[Tuple[int, ...]], - dp_dims: Tuple[Tuple[int, ...]], - tp_dims: Tuple[Tuple[int, ...]], - dp_axis_name: str = 'data', - tp_axis_name: str = 'model') -> ShardingMeta: - """ - get_self_fused_attn_sharding_meta - """ - return FusedAttnShardingMetaGenerator().get_sharding_meta(stype, input_shapes, output_shapes, - dp_dims, tp_dims, dp_axis_name, - tp_axis_name) - - -def extend_fsdp_sharding_meta(sharding_meta: ShardingMeta, - weight_fsdp_dim_map: Dict[int, int]) -> Tuple[ShardingMeta, str]: - """ - Extending the given ShardingMeta to be compatible with FSDP (ZeRO3) sharding pattern. - - .. note:: - The extending helper assumes the first shape in sharding_meta.input_shapes - corresponding to the input tensor. Please be sure that 0-idx is in - `weight_fsdp_dim_map`. - - Parameters - ---------- - sharding_meta : ShardingMeta - the sharding meta object to extend with FSDP. - weight_fsdp_dim_map: Dict[int, int] - The dict, which key is idx of sharding_meta.input_shapes and value is the dimension - to extend FSDP. default is None, means no other sharding_meta.input_shapes to extend. - - Returns - ------- - updated_sharding_meta : ShardingMeta - a sharding_meta with the FSDP extenstion. - fsdp_axis_name: str - The name of FSDP named axis for further xmap projection. - """ - assert 0 in weight_fsdp_dim_map, \ - "0-idx is required to be in 'weight_fsdp_dim_map' for the input." - - mst = infer_major_sharding_type() - if mst is MajorShardingType.SINGLE: - return sharding_meta, "" - - gsr = global_shard_resource() - dp_mesh_axis = gsr.dp_resource - fsdp_mesh_axis = gsr.fsdp_resource - - if fsdp_mesh_axis == dp_mesh_axis: - return sharding_meta, "" - if fsdp_mesh_axis is None: - return sharding_meta, "" - - fsdp_dim_size, _ = _get_mesh_info(fsdp_mesh_axis) - fsdp_axis_name = "fsdp" - - def get_idx_to_extend(sharded_indices, target_idx): - idx_to_extend = target_idx - for i in sharded_indices: - if i <= target_idx: - idx_to_extend += 1 - return idx_to_extend - - def extend_exist_sharding(idx, shape): - remain_size = shape[idx] - assert remain_size == -1 or remain_size % fsdp_dim_size == 0 - remain_size = remain_size // fsdp_dim_size - new_shape = tuple([*shape[:idx], fsdp_dim_size, remain_size, *shape[idx + 1:]]) - return new_shape - - new_input_shapes = [] - new_in_axes = [] - for i, shape in enumerate(sharding_meta.input_shapes): - idx_to_extend = -1 - if i == 0: # Assume first shape corresponds to input - input_dp_dim = weight_fsdp_dim_map[i] - # idx_to_extend = input_dp_dim + 1 if is_dp_enabled(mst) else input_dp_dim - idx_to_extend = get_idx_to_extend(list(sharding_meta.in_axes[i].keys()), input_dp_dim) - new_shape = extend_exist_sharding(idx_to_extend, shape) - - # assume one output only and have the same batch sharding like input - assert isinstance(sharding_meta.out_axes, dict) - new_out_axes = {} - for key in sharding_meta.out_axes: - if key < idx_to_extend: - new_out_axes[key] = sharding_meta.out_axes[key] - else: - new_out_axes[key + 1] = sharding_meta.out_axes[key] - new_out_axes[idx_to_extend] = fsdp_axis_name - sharding_meta.out_axes = new_out_axes - else: - new_shape = shape - if i in weight_fsdp_dim_map: - idx_to_extend = get_idx_to_extend(list(sharding_meta.in_axes[i].keys()), - weight_fsdp_dim_map[i]) - if weight_fsdp_dim_map[i] in sharding_meta.in_axes[i]: - new_shape = extend_exist_sharding(idx_to_extend, shape) - else: - assert shape[idx_to_extend] % fsdp_dim_size == 0 - remain_dim_size = shape[idx_to_extend] // fsdp_dim_size - new_shape = tuple([ - *shape[:idx_to_extend], fsdp_dim_size, remain_dim_size, - *shape[idx_to_extend + 1:] - ]) - if idx_to_extend >= 0: - new_ia = {} - for key in sharding_meta.in_axes[i]: - if key < idx_to_extend: - new_ia[key] = sharding_meta.in_axes[i][key] - else: - new_ia[key + 1] = sharding_meta.in_axes[i][key] - new_ia[idx_to_extend] = fsdp_axis_name - else: - new_ia = sharding_meta.in_axes[i] - - new_input_shapes.append(new_shape) - new_in_axes.append(new_ia) - - sharding_meta.input_shapes = tuple(new_input_shapes) - sharding_meta.in_axes = tuple(new_in_axes) - - sharding_meta.axis_resources[fsdp_axis_name] = fsdp_mesh_axis - return sharding_meta, fsdp_axis_name - - -def xmap_runner(func: Callable, in_axes: Tuple[Dict, ...], - out_axes: Union[Dict, Tuple[str, ...], Tuple[Union[Dict, Tuple], ...]], - axis_resources: Dict, inputs: Tuple): - """ - xmap_runner - """ - assert isinstance(inputs, tuple) - assert isinstance(in_axes, tuple) - - mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh - fake_in_axes = {} - fake_axis_resource = {} - - # Fake related setup is a workaround to "NotImplementedError: - # Collectives in manually partitioned computations are only supported - # when all mesh axes are partitioned manually (no partial automatic - # sharding). Make sure that you mention all mesh axes in axis_resources!" - fake_idx_counter = 0 - for mesh_axis_names in mesh.axis_names: - if mesh_axis_names not in axis_resources.values(): - fake_idx_counter += 1 - fake_axis_name = f"{mesh_axis_names}_fake_{fake_idx_counter}" - fake_in_axes[fake_idx_counter] = fake_axis_name - fake_axis_resource[fake_axis_name] = mesh_axis_names - - fake_input = jnp.zeros(tuple(64 for _ in range(len(fake_in_axes) + 1))) - - xmapped = xmap(lambda func_input, _: func(*func_input), - in_axes=(in_axes, fake_in_axes), - out_axes=out_axes, - axis_resources={ - **axis_resources, - **fake_axis_resource - }) - output = xmapped(inputs, fake_input) - return output diff --git a/transformer_engine/jax/softmax.py b/transformer_engine/jax/softmax.py index 9f03889f1a..cd283f0ea2 100644 --- a/transformer_engine/jax/softmax.py +++ b/transformer_engine/jax/softmax.py @@ -18,11 +18,6 @@ from .cpp_extensions import ScaledSoftmaxFwdPrimitive from .cpp_extensions import ScaledMaskedSoftmaxFwdPrimitive from .cpp_extensions import ScaledUpperTriangMaskedSoftmaxFwdPrimitive -from .sharding import get_softmax_sharding_meta, ShardingType, ShardingMeta -from .sharding import xmap_runner, extend_fsdp_sharding_meta - -jax.config.update('experimental_xmap_spmd_lowering', True) -jax.config.update('experimental_xmap_spmd_lowering_manual', True) class SoftmaxType(Enum): @@ -48,100 +43,47 @@ def is_softmax_kernel_available(softmax_type: SoftmaxType, batch: int, heads: in raise NotImplementedError -def softmax(inputs: jnp.ndarray, +def softmax(logits: jnp.ndarray, mask: Optional[jnp.ndarray] = None, scale_factor: Optional[float] = 1.0, - softmax_type: Optional[SoftmaxType] = SoftmaxType.SCALED, - sharding_type: ShardingType = ShardingType.SINGLE, - dp_dim_index: int = 0, - tp_dim_index: int = 1): + softmax_type: Optional[SoftmaxType] = SoftmaxType.SCALED): """ Softmax wrapper """ - assert dp_dim_index == 0, \ - "Only softmax support batch dim in the first place currently." - assert tp_dim_index == 1, \ - "Only softmax support head dim in the second place currently." - - assert mask is None or mask.shape[tp_dim_index] == 1 - - if sharding_type is ShardingType.SINGLE: - outputs = _softmax(inputs, mask, scale_factor, softmax_type) - else: - dp_axis_name = "batch" - tp_axis_name = "model" - - sharding_meta = get_softmax_sharding_meta(sharding_type, - inputs.shape, - dp_dim=dp_dim_index, - tp_dim=tp_dim_index, - dp_axis_name=dp_axis_name, - tp_axis_name=tp_axis_name) - - sharding_meta, _ = extend_fsdp_sharding_meta(sharding_meta, {0: dp_dim_index}) - - inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0]) # 0 for input - mask_ = mask - mask_in_axis = {} - if mask_ is not None: - - if sharding_type in (ShardingType.DP, ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW): - # If mask is head broadcastable (heads == 1), - # then it equals to DP sharding. - mask_sharding_meta = get_softmax_sharding_meta(ShardingType.DP, - mask_.shape, - dp_dim=dp_dim_index, - tp_dim=tp_dim_index, - dp_axis_name=dp_axis_name, - tp_axis_name=tp_axis_name) - else: - mask_sharding_meta = ShardingMeta([{}], {}, {}, [mask_.shape], mask_.shape) - - mask_sharding_meta, _ = extend_fsdp_sharding_meta(mask_sharding_meta, {0: dp_dim_index}) - mask_ = jnp.reshape(mask_, mask_sharding_meta.input_shapes[0]) - mask_in_axis = mask_sharding_meta.in_axes[0] - - partial_softmax = partial(_softmax, scale_factor=scale_factor, softmax_type=softmax_type) - - in_axes = (sharding_meta.in_axes[0], mask_in_axis) - outputs = xmap_runner(partial_softmax, in_axes, sharding_meta.out_axes, - sharding_meta.axis_resources, (inputs_, mask_)) - - outputs = jnp.reshape(outputs, sharding_meta.output_shapes[0]) - - return outputs + output = _softmax(logits, mask, scale_factor, softmax_type) + return output @partial(jax.custom_vjp, nondiff_argnums=(2, 3)) -def _softmax(inputs, mask, scale_factor, softmax_type): - output, _ = _softmax_fwd(inputs, mask, scale_factor, softmax_type) +def _softmax(logits, mask, scale_factor, softmax_type): + + output, _ = _softmax_fwd_rule(logits, mask, scale_factor, softmax_type) return output -def _softmax_fwd(inputs, mask, scale_factor, softmax_type): +def _softmax_fwd_rule(logits, mask, scale_factor, softmax_type): if softmax_type is SoftmaxType.SCALED_MASKED: assert mask is not None - outputs = scaled_masked_softmax_fwd(inputs, mask, scale_factor) + output = scaled_masked_softmax_fwd(logits, mask, scale_factor) elif softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED: - outputs = scaled_upper_triang_masked_softmax_fwd(inputs, scale_factor) + output = scaled_upper_triang_masked_softmax_fwd(logits, scale_factor) else: - outputs = scaled_softmax_fwd(inputs, scale_factor) + output = scaled_softmax_fwd(logits, scale_factor) - return outputs, (outputs, mask) + return output, (output,) -def _softmax_bwd(scale_factor, softmax_type, ctx, grad_outputs): - softmax_outputs, mask = ctx +def _softmax_bwd_rule(scale_factor, softmax_type, ctx, dz): + softmax_output, = ctx if softmax_type is SoftmaxType.SCALED_MASKED: - assert mask is not None - dgrad = scaled_masked_softmax_bwd(grad_outputs, softmax_outputs, scale_factor) + dgrad = scaled_masked_softmax_bwd(dz, softmax_output, scale_factor) elif softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED: - dgrad = scaled_upper_triang_masked_softmax_bwd(grad_outputs, softmax_outputs, scale_factor) + dgrad = scaled_upper_triang_masked_softmax_bwd(dz, softmax_output, scale_factor) else: - dgrad = scaled_softmax_bwd(grad_outputs, softmax_outputs, scale_factor) + dgrad = scaled_softmax_bwd(dz, softmax_output, scale_factor) return (dgrad, None) -_softmax.defvjp(_softmax_fwd, _softmax_bwd) +_softmax.defvjp(_softmax_fwd_rule, _softmax_bwd_rule)