From d427a9b8b057c98bd224e3b77fcd418f0e128b14 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 18 Jul 2023 10:27:03 -0700 Subject: [PATCH 01/12] add TE support --- paxml/contrib/gpu/scripts_gpu/configs.py | 22 +- paxml/contrib/gpu/scripts_gpu/te_helper.py | 324 +++++++++++++++++++++ paxml/main.py | 61 ++-- paxml/tasks_lib.py | 3 +- paxml/trainer_lib.py | 36 ++- 5 files changed, 396 insertions(+), 50 deletions(-) create mode 100644 paxml/contrib/gpu/scripts_gpu/te_helper.py diff --git a/paxml/contrib/gpu/scripts_gpu/configs.py b/paxml/contrib/gpu/scripts_gpu/configs.py index cb5ad5210..16d412587 100644 --- a/paxml/contrib/gpu/scripts_gpu/configs.py +++ b/paxml/contrib/gpu/scripts_gpu/configs.py @@ -28,6 +28,7 @@ from paxml.contrib.gpu.scripts_gpu.tasks import LambadaDataset from paxml.contrib.gpu.scripts_gpu.tasks import PileUnsupervisedDataset from paxml.tasks.lm.model_params import maybe_setup_moe_params +from paxml.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper from paxml.tasks.lm.params.c4 import TransformerLmSpmdAdam from paxml.tasks.lm.params.lm_cloud import SyntheticDataset from praxis import base_layer @@ -116,7 +117,7 @@ class GPT126MBase(TransformerLmSpmdAdam): MAX_SEQ_LEN = 2048 VOCAB_SIZE = 50304 - PACKED_INPUT = True + PACKED_INPUT = False PERCORE_BATCH_SIZE = 4 NUM_LAYERS = 12 @@ -171,10 +172,21 @@ def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]: fdl.get_callable(stacked_p), transformers.StackedTransformerRepeated ): stacked_p = stacked_p.block - transformer_layer_p = stacked_p.transformer_layer_params_tpl - transformer_layer_p.ln_tpl.reductions_in_fp32 = True - transformer_layer_p.tr_fflayer_tpl.ln_tpl.reductions_in_fp32 = True + task_p.model.lm_tpl.final_ln_tpl.reductions_in_fp32 = True + if not TransformerEngineHelper.is_enabled_te(): + transformer_layer_p = stacked_p.transformer_layer_params_tpl + transformer_layer_p.ln_tpl.reductions_in_fp32 = True + transformer_layer_p.tr_fflayer_tpl.ln_tpl.reductions_in_fp32 = True + else: + stacked_p = TransformerEngineHelper.get_stack_transformer( + stacked_p, jnp.dtype(self.FPROP_DTYPE)) + if issubclass(fdl.get_callable(model_p.lm_tpl.stacked_transformer_tpl), + transformers.StackedTransformerRepeated): + model_p.lm_tpl.stacked_transformer_tpl.block = stacked_p + else: + model_p.lm_tpl.stacked_transformer_tpl = stacked_p + model_p.params_init = WeightInit.Gaussian(self.INIT_STD) softmax_init = WeightInit.Gaussian(self.SOFTMAX_INIT_STD) @@ -239,7 +251,7 @@ class GPT175BBase(GPT126MBase): # Known as MLP_DIM in t5x HIDDEN_DIMS = MODEL_DIMS * 4 # Defaults to MODEL_DIMS // NUM_HEADS. 
- DIMS_PER_HEAD = None + DIMS_PER_HEAD = 128 # Known as NUM_EMBEDDINGS in t5x VOCAB_SIZE = 50257 USE_REPEATED_LAYER = True diff --git a/paxml/contrib/gpu/scripts_gpu/te_helper.py b/paxml/contrib/gpu/scripts_gpu/te_helper.py new file mode 100644 index 000000000..d44ca6736 --- /dev/null +++ b/paxml/contrib/gpu/scripts_gpu/te_helper.py @@ -0,0 +1,324 @@ +import os +from contextlib import contextmanager +from typing import Optional, Sequence + +import jax +import jax.numpy as jnp +from jax.ad_checkpoint import checkpoint_name +from praxis import base_layer +from praxis import pax_fiddle +from praxis import pytypes +from praxis.layers import transformers +from praxis.layers import stochastics + +try: + import transformer_engine.jax as te + import transformer_engine.jax.flax as te_flax + import transformer_engine.jax.praxis as te_praxis + from transformer_engine.common import recipe + _IS_TRANSFORMER_ENGINE_INSTALLED = True + DEFAULT_INIT_MUTABLE_LIST = base_layer.DEFAULT_INIT_MUTABLE_LIST + [te.fp8.FP8Helper.FP8_COLLECTION_NAME] + import praxis.layers.repeats as praxis_repeat + # This is to make Repeat module correctly generate collections we need. + praxis_repeat.SCAN_VARIABLE_AXES.update({base_layer.NON_PAX_VAR_COLLECTION[1]: 0, # 1-idx = params_axes + te.fp8.FP8Helper.FP8_COLLECTION_NAME:0}) + +except ModuleNotFoundError as e: + _IS_TRANSFORMER_ENGINE_INSTALLED = False + DEFAULT_INIT_MUTABLE_LIST = base_layer.DEFAULT_INIT_MUTABLE_LIST + + +LayerTpl = pax_fiddle.Config[base_layer.BaseLayer] +JTensor = pytypes.JTensor + +class StackedTransformer(transformers.StackedTransformer): + """A mirror of StackedTransformer layers in Praxis.""" + + def setup(self) -> None: + + assert self.num_layers > 0 + assert self.model_dims > 0 + assert self.hidden_dims > 0 + assert self.num_heads > 0 + assert 0.0 <= self.dropout_prob < 1.0 + assert 0.0 <= self.input_dropout_prob < 1.0 + + def _layer_params(i): + """Construct i-th layer params.""" + if isinstance(self.transformer_layer_params_tpl, Sequence): + factor = self.num_layers // len(self.transformer_layer_params_tpl) + ii = i // factor + p_i = self._clone_layer_params(self.transformer_layer_params_tpl[ii]) + else: + p_i = self._clone_layer_params(self.transformer_layer_params_tpl) + p_i.name = f'layer_{i}' + + p_i.logical_axes_rules = te_flax.extend_logical_axis_rules(tuple()) + p_i.layer_type = te_praxis.TransformerLayerType.DECODER if self.use_cross_attention \ + else te_praxis.TransformerLayerType.ENCODER + p_i.num_attention_heads = self.num_heads + p_i.hidden_size = self.model_dims + p_i.mlp_hidden_size = self.hidden_dims + assert self.dim_per_head == self.model_dims // self.num_heads + assert self.packed_input == False + assert len(self.moe_layers) == 0 + assert self.ngrammer_tpls is None + + if self.ngrammer_tpls is not None: + if self.ngrammer_tpls[i] is not None: + p_i.ngrammer_tpl = self.ngrammer_tpls[i] + return p_i + + if isinstance(self.transformer_layer_params_tpl, (list, tuple)): + if self.num_layers % len(self.transformer_layer_params_tpl): + raise ValueError('num_layers should be divisible by ' + 'transformer_layer_params_tpl') + + layer_params = [_layer_params(i) for i in range(self.num_layers)] + self.create_children('x_layers', layer_params) + + if self.input_dropout_prob > 0.0: + self.create_child( + 'input_dropout', + pax_fiddle.Config( + stochastics.Dropout, keep_prob=1.0 - self.input_dropout_prob + ), + ) + + def __call__(self, + inputs: JTensor, + paddings: JTensor, + segment_mask: Optional[JTensor] = None, + cross_inputs: 
Optional[JTensor] = None, + cross_paddings: Optional[JTensor] = None, + cross_segment_mask: Optional[JTensor] = None, + segment_pos: Optional[JTensor] = None) -> JTensor: + + if self.packed_input: + assert segment_mask is not None + + if self.use_cross_attention: + assert cross_inputs is not None + assert cross_paddings is not None + if self.packed_input: + assert cross_segment_mask is not None + + attention_mask, cross_attention_mask = transformers.compute_attention_masks_for_fprop( + inputs, + paddings, + self.mask_self_attention, + segment_mask, + cross_inputs, + cross_paddings, + cross_segment_mask, + fold_padding_with_segment_mask=self.fold_padding_with_segment_mask, + ) + + x_out = inputs + if self.input_dropout_prob > 0.0: + x_out = self.input_dropout(x_out) + + attention_mask = 1 - (attention_mask == 0) + attention_mask = attention_mask.astype(jnp.uint8) + + if cross_attention_mask is not None: + cross_attention_mask = 1 - (cross_attention_mask == 0) + cross_attention_mask = cross_attention_mask.astype(jnp.uint8) + + for i in range(self.num_layers): + x_in = x_out + x_out = self.x_layers[i]( + inputs=x_in, + attention_mask=attention_mask, + encoded=cross_inputs, + encoder_decoder_mask=cross_attention_mask) + x_out = checkpoint_name(x_out, 'transformer_layer_out') + return x_out + + +class TransformerEngineHelperBase: + + @staticmethod + def get_stack_transformer(stacked_transformer_p, dtype): + raise NotImplementedError + + @staticmethod + def update_fp8_metas_if_needed(mdl_vars, grads): + raise NotImplementedError + + @staticmethod + def include_fp8_for_grads_if_needed(variables): + raise NotImplementedError + + @staticmethod + def mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt): + raise NotImplementedError + + @staticmethod + @contextmanager + def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"): + raise NotImplementedError + + +class TENotInstalledHelper(TransformerEngineHelperBase): + + @staticmethod + def get_stack_transformer(stacked_transformer_p, dtype): + return stacked_transformer_p + + @staticmethod + def update_fp8_metas_if_needed(mdl_vars, grads): + return mdl_vars + + @staticmethod + def include_fp8_for_grads_if_needed(variables): + return variables + + @staticmethod + def mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt): + return grads + + @staticmethod + @contextmanager + def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"): + try: + yield + finally: + pass + + +class TEInstalledHelper(TransformerEngineHelperBase): + + @staticmethod + def get_stack_transformer(stacked_transformer_p, dtype): + + assert stacked_transformer_p.cls == transformers.StackedTransformer + + te_stacked_transformer_p = pax_fiddle.Config(StackedTransformer, + use_cross_attention=stacked_transformer_p.use_cross_attention, + mask_self_attention=stacked_transformer_p.mask_self_attention, + num_layers=stacked_transformer_p.num_layers, + model_dims=stacked_transformer_p.model_dims, + hidden_dims=stacked_transformer_p.hidden_dims, + num_heads=stacked_transformer_p.num_heads, + dim_per_head=stacked_transformer_p.dim_per_head, + dropout_prob=stacked_transformer_p.dropout_prob, + atten_dropout_prob=stacked_transformer_p.atten_dropout_prob, + residual_dropout_prob=stacked_transformer_p.residual_dropout_prob, + relu_dropout_prob=stacked_transformer_p.relu_dropout_prob, + residual_droppath_prob=stacked_transformer_p.residual_droppath_prob, + input_dropout_prob=stacked_transformer_p.input_dropout_prob, + 
gating_func=stacked_transformer_p.gating_func, + unadjusted_expert_capacity_factor=stacked_transformer_p.unadjusted_expert_capacity_factor, + packed_input=stacked_transformer_p.packed_input, + fold_padding_with_segment_mask=stacked_transformer_p.fold_padding_with_segment_mask, + moe_layer_tpl=stacked_transformer_p.moe_layer_tpl, + num_experts=stacked_transformer_p.num_experts, + num_groups=stacked_transformer_p.num_groups, + min_group_size=stacked_transformer_p.min_group_size, + moe_layers=stacked_transformer_p.moe_layers, + ngrammer_tpls=stacked_transformer_p.ngrammer_tpls + ) + + ori_transformer_engine_p = stacked_transformer_p.transformer_layer_params_tpl + + te_stacked_transformer_p.transformer_layer_params_tpl = pax_fiddle.Config(te_praxis.TransformerLayer, + name='transformer_layer', + params_init=stacked_transformer_p.params_init, + dtype=dtype, + hidden_size=stacked_transformer_p.model_dims, + mlp_hidden_size=stacked_transformer_p.hidden_dims, + num_attention_heads=stacked_transformer_p.num_heads, + layernorm_type='layernorm', + layernorm_epsilon=ori_transformer_engine_p.ln_tpl.epsilon, + zero_centered_gamma = True, + hidden_dropout=ori_transformer_engine_p.residual_dropout_prob, + attention_dropout=ori_transformer_engine_p.atten_dropout_prob, + mlp_activations=('gelu',), + use_bias=True, + layer_type=te_praxis.TransformerLayerType.ENCODER, + self_attn_mask_type='causal', + enable_relative_embedding=False, + drop_path=ori_transformer_engine_p.residual_droppath_prob, + scaled_query_init=False, + scale_attn_logits=True, + transpose_batch_sequence=False + ) + + return te_stacked_transformer_p + + @staticmethod + def update_fp8_metas_if_needed(mdl_vars, grads): + FP8_COLLECTION_NAME = te.fp8.FP8Helper.FP8_COLLECTION_NAME + if FP8_COLLECTION_NAME in grads: + mdl_vars[FP8_COLLECTION_NAME] = te.update_fp8_metas(grads)[FP8_COLLECTION_NAME] + return mdl_vars + + @staticmethod + def include_fp8_for_grads_if_needed(variables): + FP8_COLLECTION_NAME = te.fp8.FP8Helper.FP8_COLLECTION_NAME + if FP8_COLLECTION_NAME in variables: + variables[FP8_COLLECTION_NAME] = \ + jax.tree_util.tree_map(lambda x: False, variables[FP8_COLLECTION_NAME]) + return variables + + @staticmethod + def mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt): + FP8_COLLECTION_NAME = te.fp8.FP8Helper.FP8_COLLECTION_NAME + if FP8_COLLECTION_NAME in grads: + grads[FP8_COLLECTION_NAME] = vars_with_opt[FP8_COLLECTION_NAME].copy() + return grads + + @staticmethod + @contextmanager + def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"): + fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.HYBRID, + amax_history_len=1024, amax_compute_algo='max') + + enable_fp8 = bool(int((os.environ.get("ENABLE_FP8", False)))) + try: + with te.fp8_autocast(enabled=enable_fp8, + fp8_recipe=fp8_recipe, + sharding_resource=te.ShardingResource(dp_mesh_axis, tp_mesh_axis, fsdp_mesh_axis)): + yield + finally: + pass + + +class TransformerEngineHelper(TransformerEngineHelperBase): + + @staticmethod + def is_enabled_te(): + enable_te = bool(int((os.environ.get("ENABLE_TE", False)))) + return (_IS_TRANSFORMER_ENGINE_INSTALLED and enable_te) + + @staticmethod + def get_helper(): + if TransformerEngineHelper.is_enabled_te(): + return TEInstalledHelper + return TENotInstalledHelper + + @staticmethod + def get_stack_transformer(stacked_transformer_p, dtype): + return TransformerEngineHelper.get_helper().get_stack_transformer(stacked_transformer_p, dtype) + + @staticmethod + def 
update_fp8_metas_if_needed(mdl_vars, grads): + return TransformerEngineHelper.get_helper().update_fp8_metas_if_needed(mdl_vars, grads) + + @staticmethod + def include_fp8_for_grads_if_needed(variables): + return TransformerEngineHelper.get_helper().include_fp8_for_grads_if_needed(variables) + + @staticmethod + def mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt): + return TransformerEngineHelper.get_helper().mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt) + + @staticmethod + @contextmanager + def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"): + try: + with TransformerEngineHelper.get_helper().fp8_autocast(dp_mesh_axis, tp_mesh_axis, fsdp_mesh_axis): + yield + finally: + pass diff --git a/paxml/main.py b/paxml/main.py index f03e2d4a4..f087f277a 100644 --- a/paxml/main.py +++ b/paxml/main.py @@ -52,6 +52,7 @@ from paxml import trainer_lib from paxml import tuning_lib from paxml import ml_monitoring +from paxml.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper from praxis import pax_fiddle from praxis import py_utils @@ -519,35 +520,36 @@ def create_experiment_config(): ), ) - if FLAGS.exp is not None: - experiment_config = get_experiment(FLAGS.exp)() - elif absl_flags.fdl_flags_supplied(): - # Use the legacy Fiddle flags API to parse command line Fiddle flags. - cfg = absl_flags.create_buildable_from_flags( - module=None, allow_imports=True) - experiment_config = pax_fiddle.build(cfg) - logging.warning( - 'Legacy Fiddle flags API usage detected. Please use the new Fiddle' - ' command line flag `fdl` with various commands to specify the' - ' config and any overrides. Please see' - ' `fiddle/docs/flags_code_lab.md` for more' - ' documentation on Fiddle flags usage.' - ) - elif _FIDDLE_CONFIG.value is not None: - # This uses the new Fiddle flags API `DEFINE_fiddle_config()` to parse - # command line Fiddle flags. See - # `fiddle/docs/flags_code_lab.md` for details on the new - # Fiddle flags API. - logging.info( - 'Using pax_fiddle_config from the command line: %s', - _FIDDLE_CONFIG.value, - ) - experiment_config = pax_fiddle.build(_FIDDLE_CONFIG.value) - else: - raise app.UsageError( - 'No experiment provided. At least one of --exp, --fdl,' - ' --fdl_config, or --fdl_config_file is required.' - ) + with TransformerEngineHelper.fp8_autocast('replica', 'mdl', 'data'): + if FLAGS.exp is not None: + experiment_config = get_experiment(FLAGS.exp)() + elif absl_flags.fdl_flags_supplied(): + # Use the legacy Fiddle flags API to parse command line Fiddle flags. + cfg = absl_flags.create_buildable_from_flags( + module=None, allow_imports=True) + experiment_config = pax_fiddle.build(cfg) + logging.warning( + 'Legacy Fiddle flags API usage detected. Please use the new Fiddle' + ' command line flag `fdl` with various commands to specify the' + ' config and any overrides. Please see' + ' `fiddle/docs/flags_code_lab.md` for more' + ' documentation on Fiddle flags usage.' + ) + elif _FIDDLE_CONFIG.value is not None: + # This uses the new Fiddle flags API `DEFINE_fiddle_config()` to parse + # command line Fiddle flags. See + # `fiddle/docs/flags_code_lab.md` for details on the new + # Fiddle flags API. + logging.info( + 'Using pax_fiddle_config from the command line: %s', + _FIDDLE_CONFIG.value, + ) + experiment_config = pax_fiddle.build(_FIDDLE_CONFIG.value) + else: + raise app.UsageError( + 'No experiment provided. At least one of --exp, --fdl,' + ' --fdl_config, or --fdl_config_file is required.' 
+ ) experiment_config.validate() return experiment_config @@ -571,7 +573,6 @@ def _main(argv: Sequence[str]) -> None: startup_random_jitter_max_secs=FLAGS.startup_random_jitter_max_secs, ) - _TASK_HANDLE_RE = re.compile(r'(?:logs\.)?(\d+)\.(.*)\.([^.]+)\.\d+') if __name__ == '__main__': diff --git a/paxml/tasks_lib.py b/paxml/tasks_lib.py index 010998dd8..4b9028233 100644 --- a/paxml/tasks_lib.py +++ b/paxml/tasks_lib.py @@ -43,6 +43,7 @@ from paxml import io_utils from paxml import learners as learners_lib from paxml import train_states +from paxml.contrib.gpu.scripts_gpu.te_helper import DEFAULT_INIT_MUTABLE_LIST from praxis import asserts from praxis import base_hyperparams from praxis import base_input @@ -1786,7 +1787,7 @@ def _apply_init_checkpoint_rule( ) # Initialize with a dummy seed var_weight_hparams = ckpt_task.model.abstract_init_with_metadata( - inputs_shape_dtype) + inputs_shape_dtype, mutable=DEFAULT_INIT_MUTABLE_LIST) ckpt_train_state = ckpt_task.create_train_state_padded_shapes( var_weight_hparams) train_state_pspecs = ckpt_task.create_train_state_partition_specs( diff --git a/paxml/trainer_lib.py b/paxml/trainer_lib.py index 6ec25e8c6..034232849 100644 --- a/paxml/trainer_lib.py +++ b/paxml/trainer_lib.py @@ -35,6 +35,7 @@ from paxml import sgf from paxml import tasks_lib from paxml import train_states +from paxml.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper, DEFAULT_INIT_MUTABLE_LIST from praxis import asserts from praxis import base_hyperparams from praxis import base_input @@ -167,8 +168,7 @@ def create_train_state_metadata( A TrainStateMetadata instance. """ var_weight_hparams = jax_task.model.abstract_init_with_metadata( - train_shape_dtype, do_eval=do_eval - ) + train_shape_dtype, do_eval=do_eval, extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST) padded_global_shapes = jax_task.create_train_state_padded_shapes( var_weight_hparams, discard_opt_states=discard_opt_states ) @@ -217,7 +217,8 @@ def write_post_init_model_hparams_file( logging.info('post_init_model_params: %s', params_fpath) job_log_dir.mkdir(parents=True, exist_ok=True) hyper_params = model.abstract_init_with_mdl_config( - train_state_metadata.input_shape_dtype, do_eval=do_eval + train_state_metadata.input_shape_dtype, do_eval=do_eval, + extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST ) with params_fpath.open('w') as params_file: hyper_params_dump = base_hyperparams.nested_struct_to_text(hyper_params) @@ -379,7 +380,8 @@ def initialize_model_state( is_eval_for_init = is_eval if not var_weight_hparams: var_weight_hparams = model.abstract_init_with_metadata( - inputs_shape_dtype, do_eval=is_eval_for_init + inputs_shape_dtype, do_eval=is_eval_for_init, + extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST ) logging.info('init_var prng_seed: %s', init_key) logging.info('var_weight_hparams: %s', var_weight_hparams) @@ -396,7 +398,7 @@ def init_fn(init_key): inputs = jax.tree.map(jnp.zeros_like, inputs_shape_dtype) if model.hparams.fprop_dtype == jnp.bfloat16: inputs = jax.tree.map(_maybe_to_bfloat16, inputs) - return model.init(init_key, inputs) + return model.init(init_key, inputs, mutable=DEFAULT_INIT_MUTABLE_LIST) initial_vars = init_fn(init_key) logging.info('initial_vars: %s', jax.tree.map(jnp.shape, initial_vars)) @@ -809,7 +811,6 @@ def __call__( ) -> tuple[JTensor, sgf.GradAuxInfo]: """Produces losses and grad info by passing the inputs through a model.""" - def _get_default_loss_fn( jax_task: tasks_lib.SingleTask, context_p: base_layer.JaxContext.HParams, @@ -994,14 +995,16 @@ def 
get_excluded_var_masks( excluded_for_grad = tasks_lib.get_excluded_var_mask_for_grad( var_weight_hparams, learner ) - _log_bprop_include_exclude_list(var_weight_hparams, excluded_for_grad) + excluded_for_grad_but_fp8_meta = TransformerEngineHelper.include_fp8_for_grads_if_needed(excluded_for_grad.copy()) + + _log_bprop_include_exclude_list(var_weight_hparams, excluded_for_grad_but_fp8_meta) # Excluded for optimizer states. excluded_for_opt = tasks_lib.get_excluded_var_mask_for_opt( var_weight_hparams, learner, ) - return excluded_for_grad, excluded_for_opt + return excluded_for_grad, excluded_for_grad_but_fp8_meta, excluded_for_opt def _prepare_tree_data_for_summary(tree): @@ -1090,7 +1093,7 @@ def train_step_single_learner( if not var_weight_hparams: with base_layer.JaxContext.new_context(hparams=context_p): - var_weight_hparams = model.abstract_init_with_metadata(inputs) + var_weight_hparams = model.abstract_init_with_metadata(inputs, extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST) updated_model_vars = jax_task.maybe_adjust_train_state( # pytype: disable=wrong-arg-types # jax-ndarray step=states.step, mdl_vars=states.mdl_vars, @@ -1100,13 +1103,13 @@ def train_step_single_learner( _, subkey = jax.random.split(prng_key) - excluded_for_grad, excluded_for_opt = get_excluded_var_masks( + excluded_for_grad, excluded_for_grad_but_fp8_meta, excluded_for_opt = get_excluded_var_masks( var_weight_hparams, learner ) # Construct and call the grad function. if not grad_fn: - grad_fn = _get_default_grad_fn(excluded_for_grad, excluded_for_opt) + grad_fn = _get_default_grad_fn(excluded_for_grad_but_fp8_meta, excluded_for_opt) (weighted_loss, aux_info), grads = grad_fn( loss_fn=_get_default_loss_fn( jax_task=jax_task, @@ -1154,7 +1157,7 @@ def train_step_single_learner( # Make updated non-trainable vars visible to EMA. 
mdl_vars[NON_TRAINABLE] = fwd_updated_vars[NON_TRAINABLE] excluded_for_learner = jax.tree.map( - lambda eo, eg: eo and eg, excluded_for_opt, excluded_for_grad + lambda eo, eg: eo and eg, excluded_for_opt, excluded_for_grad_but_fp8_meta ) vars_with_opt = tasks_lib.filter_vars_for_grad_or_opt( mdl_vars, excluded_for_learner @@ -1162,6 +1165,10 @@ def train_step_single_learner( wps_with_opt = tasks_lib.filter_vars_for_grad_or_opt( var_weight_hparams, excluded_for_learner ) + + mdl_vars = TransformerEngineHelper.update_fp8_metas_if_needed(mdl_vars, grads) + grads = TransformerEngineHelper.mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt) + transformed_grads, new_opt_states = learner.update_states( grads, states.opt_states[0], vars_with_opt, wps_with_opt ) @@ -1197,6 +1204,7 @@ def train_step_single_learner( states.mdl_vars, mdl_vars, ) + new_states = states.new_state( mdl_vars=mdl_vars, opt_states=[new_opt_states], extra_state=() ) @@ -1300,7 +1308,7 @@ def eval_step_single_learner( var_weight_hparams = model.abstract_init_with_metadata( inputs, do_eval=not jax_task.hparams.train.always_use_train_for_model_init, - ) + extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST) if fprop_dtype == jnp.float32: pass @@ -1554,7 +1562,7 @@ def initialize_partitioned_model_states( model = jax_task.model if not var_weight_hparams: var_weight_hparams = model.abstract_init_with_metadata( - global_input_shapes, do_eval=is_eval + global_input_shapes, do_eval=is_eval, extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST ) train_state_partition_specs = ( From 3a37cdae1ab420890284647bda008615c9069ab2 Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Wed, 27 Sep 2023 10:46:53 +0800 Subject: [PATCH 02/12] Adding dropout support when enabling TE. --- paxml/contrib/gpu/scripts_gpu/te_helper.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/paxml/contrib/gpu/scripts_gpu/te_helper.py b/paxml/contrib/gpu/scripts_gpu/te_helper.py index d44ca6736..2b9dba4d8 100644 --- a/paxml/contrib/gpu/scripts_gpu/te_helper.py +++ b/paxml/contrib/gpu/scripts_gpu/te_helper.py @@ -59,6 +59,16 @@ def _layer_params(i): p_i.num_attention_heads = self.num_heads p_i.hidden_size = self.model_dims p_i.mlp_hidden_size = self.hidden_dims + + p_i.dropout_rng_name = base_layer.RANDOM + p_i.attention_dropout = self.atten_dropout_prob or self.dropout_prob + p_i.hidden_dropout = self.residual_dropout_prob or self.dropout_prob + p_i.intermediate_dropout = self.relu_dropout_prob or self.dropout_prob + if self.residual_droppath_prob > 0.0: + p_i.drop_path = ( + self.residual_droppath_prob * i / max(1, self.num_layers) + ) + assert self.dim_per_head == self.model_dims // self.num_heads assert self.packed_input == False assert len(self.moe_layers) == 0 From 9fc82c91ac28cc4e7fe6b6b071631c07cbbdef22 Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Tue, 24 Oct 2023 10:30:27 +0800 Subject: [PATCH 03/12] Set deterministic=True for inference. 
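
Thread the Praxis `do_eval` flag through to the TE transformer layer call so
that dropout inside Transformer Engine layers is switched off during
evaluation and decoding. A minimal sketch of the mechanism this relies on
(illustrative only; `TinyBlock` below is a stand-in Flax module, not part of
this patch):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class TinyBlock(nn.Module):
        @nn.compact
        def __call__(self, x, deterministic: bool):
            h = nn.Dense(4)(x)
            # Dropout becomes a no-op when deterministic=True, which is what
            # the TE layer should see whenever the Praxis model runs in eval.
            return nn.Dropout(rate=0.1)(h, deterministic=deterministic)

    x = jnp.ones((2, 4))
    mod = TinyBlock()
    variables = mod.init(jax.random.PRNGKey(0), x, deterministic=True)
    y_train = mod.apply(variables, x, deterministic=False,
                        rngs={'dropout': jax.random.PRNGKey(1)})
    y_eval = mod.apply(variables, x, deterministic=True)  # no dropout rng needed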
--- paxml/contrib/gpu/scripts_gpu/te_helper.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paxml/contrib/gpu/scripts_gpu/te_helper.py b/paxml/contrib/gpu/scripts_gpu/te_helper.py index 2b9dba4d8..ef20305e6 100644 --- a/paxml/contrib/gpu/scripts_gpu/te_helper.py +++ b/paxml/contrib/gpu/scripts_gpu/te_helper.py @@ -141,7 +141,8 @@ def __call__(self, inputs=x_in, attention_mask=attention_mask, encoded=cross_inputs, - encoder_decoder_mask=cross_attention_mask) + encoder_decoder_mask=cross_attention_mask, + deterministic=self.do_eval) x_out = checkpoint_name(x_out, 'transformer_layer_out') return x_out From ea6dd6c65b4dcf965f2c1c5e2b502f80a9142ed1 Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Thu, 2 Nov 2023 22:04:58 -0700 Subject: [PATCH 04/12] Fix the excluded list for excluded_for_learner Signed-off-by: Reese Wang --- paxml/trainer_lib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paxml/trainer_lib.py b/paxml/trainer_lib.py index 034232849..2e9bfd6f4 100644 --- a/paxml/trainer_lib.py +++ b/paxml/trainer_lib.py @@ -1157,7 +1157,7 @@ def train_step_single_learner( # Make updated non-trainable vars visible to EMA. mdl_vars[NON_TRAINABLE] = fwd_updated_vars[NON_TRAINABLE] excluded_for_learner = jax.tree.map( - lambda eo, eg: eo and eg, excluded_for_opt, excluded_for_grad_but_fp8_meta + lambda eo, eg: eo and eg, excluded_for_opt, excluded_for_grad ) vars_with_opt = tasks_lib.filter_vars_for_grad_or_opt( mdl_vars, excluded_for_learner From f740f17fbb13b4262238736c7d678dc8bcea1f1f Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Tue, 7 Nov 2023 11:21:53 +0800 Subject: [PATCH 05/12] Adapting to TE/JAX/Custom_partitioning. --- paxml/contrib/gpu/scripts_gpu/te_helper.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/paxml/contrib/gpu/scripts_gpu/te_helper.py b/paxml/contrib/gpu/scripts_gpu/te_helper.py index ef20305e6..fed160115 100644 --- a/paxml/contrib/gpu/scripts_gpu/te_helper.py +++ b/paxml/contrib/gpu/scripts_gpu/te_helper.py @@ -262,7 +262,7 @@ def get_stack_transformer(stacked_transformer_p, dtype): def update_fp8_metas_if_needed(mdl_vars, grads): FP8_COLLECTION_NAME = te.fp8.FP8Helper.FP8_COLLECTION_NAME if FP8_COLLECTION_NAME in grads: - mdl_vars[FP8_COLLECTION_NAME] = te.update_fp8_metas(grads)[FP8_COLLECTION_NAME] + mdl_vars[FP8_COLLECTION_NAME] = grads[FP8_COLLECTION_NAME] return mdl_vars @staticmethod @@ -290,7 +290,9 @@ def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="dat try: with te.fp8_autocast(enabled=enable_fp8, fp8_recipe=fp8_recipe, - sharding_resource=te.ShardingResource(dp_mesh_axis, tp_mesh_axis, fsdp_mesh_axis)): + mesh_resource=te.MeshResource(dp_resource=dp_mesh_axis, + tp_resource=tp_mesh_axis, + fsdp_resource=fsdp_mesh_axis)): yield finally: pass From 03ef6221821896155b78b8a193df2091b096252d Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Tue, 7 Nov 2023 15:14:25 +0800 Subject: [PATCH 06/12] Adding TE-compatiable PipelinedTransformer --- paxml/contrib/gpu/scripts_gpu/te_helper.py | 109 +++++++++++++++++++++ 1 file changed, 109 insertions(+) diff --git a/paxml/contrib/gpu/scripts_gpu/te_helper.py b/paxml/contrib/gpu/scripts_gpu/te_helper.py index fed160115..5914e54b9 100644 --- a/paxml/contrib/gpu/scripts_gpu/te_helper.py +++ b/paxml/contrib/gpu/scripts_gpu/te_helper.py @@ -31,6 +31,7 @@ LayerTpl = pax_fiddle.Config[base_layer.BaseLayer] JTensor = pytypes.JTensor + class StackedTransformer(transformers.StackedTransformer): """A mirror of StackedTransformer layers in 
Praxis.""" @@ -147,12 +148,92 @@ def __call__(self, return x_out +class PipelinedTransformer(transformers.PipelinedTransformer): + """A mirror of PipelinedTransformer in Praxis""" + + def __call__( + self, + inputs: JTensor, + paddings: JTensor, + segment_mask: JTensor | None = None, + cross_inputs: JTensor | None = None, + cross_paddings: JTensor | None = None, + cross_segment_mask: JTensor | None = None, + segment_pos: JTensor | None = None, + ) -> JTensor: + + rules = te_flax.extend_logical_axis_rules(tuple()) + batch_mapping = rules[0] + hidden_tp_mapping = rules[4] + # [Batch, Seqlen, Hidden] + bld_mapping = [batch_mapping, None, hidden_tp_mapping] + + if not self.stream_io: + # Annotate the inputs before the pipeline to prevent unexpected + # propagation from earlier layers. + inputs = base_layer.maybe_shard(inputs, bld_mapping, self.mesh_axis_names) + if bld_mapping is not None: + # Annotate other broadcast inputs. + paddings = base_layer.maybe_shard( + paddings, bld_mapping[:-1], self.mesh_axis_names + ) + + # For cross inputs, we only specify the batch dim sharding. + def _shard_batch_dim_only(x): + return base_layer.maybe_shard( + x, + [bld_mapping[0]] + [-1] * (x.ndim - 1), + self.mesh_axis_names, + unconstrained_dims=range(1, x.ndim), + ) + + if segment_mask is not None: + segment_mask = _shard_batch_dim_only(segment_mask) + if cross_inputs is not None: + cross_inputs = _shard_batch_dim_only(cross_inputs) + if cross_paddings is not None: + cross_paddings = _shard_batch_dim_only(cross_paddings) + if cross_segment_mask is not None: + cross_segment_mask = _shard_batch_dim_only(cross_segment_mask) + + if segment_pos is not None: + segment_pos = base_layer.maybe_shard( + segment_pos, bld_mapping[:-1], self.mesh_axis_names + ) + + outputs = self.pipeline( + inputs, + paddings, + segment_mask=segment_mask, + cross_inputs=cross_inputs, + cross_paddings=cross_paddings, + cross_segment_mask=cross_segment_mask, + segment_pos=segment_pos, + ) + + if not self.stream_io: + outputs = base_layer.maybe_shard( + outputs, bld_mapping, self.mesh_axis_names + ) + + outputs = base_layer.maybe_shard( + outputs, + self.activation_split_dims_mapping.final_out, + self.mesh_axis_names, + ) + return outputs + + class TransformerEngineHelperBase: @staticmethod def get_stack_transformer(stacked_transformer_p, dtype): raise NotImplementedError + @staticmethod + def get_pipeline_transformer(pipeline_transformer_p): + raise NotImplementedError + @staticmethod def update_fp8_metas_if_needed(mdl_vars, grads): raise NotImplementedError @@ -177,6 +258,10 @@ class TENotInstalledHelper(TransformerEngineHelperBase): def get_stack_transformer(stacked_transformer_p, dtype): return stacked_transformer_p + @staticmethod + def get_pipeline_transformer(pipeline_transformer_p): + return pipeline_transformer_p + @staticmethod def update_fp8_metas_if_needed(mdl_vars, grads): return mdl_vars @@ -258,6 +343,26 @@ def get_stack_transformer(stacked_transformer_p, dtype): return te_stacked_transformer_p + @staticmethod + def get_pipeline_transformer(pipeline_transformer_p): + + assert pipeline_transformer_p.cls == transformers.PipelinedTransformer + + te_pipeline_transformer_p = pax_fiddle.Config(PipelinedTransformer, + pipeline_stage=pipeline_transformer_p.pipeline_stage, + circular_repeat=pipeline_transformer_p.circular_repeat, + num_pipeline_stages=pipeline_transformer_p.num_pipeline_stages, + num_pipeline_microbatches=pipeline_transformer_p.num_pipeline_microbatches, + 
pipeline_microbatch_size=pipeline_transformer_p.pipeline_microbatch_size, + stream_io=pipeline_transformer_p.stream_io, + pipeline_broadcast_inputs=pipeline_transformer_p.pipeline_broadcast_inputs, + checkpoint_policy=pipeline_transformer_p.checkpoint_policy, + enable_async_circular_transfer=pipeline_transformer_p.enable_async_circular_transfer, + bf16_accum_in_fp32=pipeline_transformer_p.bf16_accum_in_fp32 + ) + + return te_pipeline_transformer_p + @staticmethod def update_fp8_metas_if_needed(mdl_vars, grads): FP8_COLLECTION_NAME = te.fp8.FP8Helper.FP8_COLLECTION_NAME @@ -315,6 +420,10 @@ def get_helper(): def get_stack_transformer(stacked_transformer_p, dtype): return TransformerEngineHelper.get_helper().get_stack_transformer(stacked_transformer_p, dtype) + @staticmethod + def get_pipeline_transformer(pipeline_transformer_p): + return TransformerEngineHelper.get_helper().get_pipeline_transformer(pipeline_transformer_p) + @staticmethod def update_fp8_metas_if_needed(mdl_vars, grads): return TransformerEngineHelper.get_helper().update_fp8_metas_if_needed(mdl_vars, grads) From 2e2081ec97e73d9c212dcfb83713a87eea374396 Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Wed, 8 Nov 2023 10:06:49 +0800 Subject: [PATCH 07/12] Apply OWG to TE's FP8 meta --- paxml/contrib/gpu/scripts_gpu/te_helper.py | 59 ---------------------- paxml/trainer_lib.py | 12 ++--- 2 files changed, 4 insertions(+), 67 deletions(-) diff --git a/paxml/contrib/gpu/scripts_gpu/te_helper.py b/paxml/contrib/gpu/scripts_gpu/te_helper.py index 5914e54b9..fd482dfd1 100644 --- a/paxml/contrib/gpu/scripts_gpu/te_helper.py +++ b/paxml/contrib/gpu/scripts_gpu/te_helper.py @@ -2,7 +2,6 @@ from contextlib import contextmanager from typing import Optional, Sequence -import jax import jax.numpy as jnp from jax.ad_checkpoint import checkpoint_name from praxis import base_layer @@ -234,18 +233,6 @@ def get_stack_transformer(stacked_transformer_p, dtype): def get_pipeline_transformer(pipeline_transformer_p): raise NotImplementedError - @staticmethod - def update_fp8_metas_if_needed(mdl_vars, grads): - raise NotImplementedError - - @staticmethod - def include_fp8_for_grads_if_needed(variables): - raise NotImplementedError - - @staticmethod - def mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt): - raise NotImplementedError - @staticmethod @contextmanager def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"): @@ -262,18 +249,6 @@ def get_stack_transformer(stacked_transformer_p, dtype): def get_pipeline_transformer(pipeline_transformer_p): return pipeline_transformer_p - @staticmethod - def update_fp8_metas_if_needed(mdl_vars, grads): - return mdl_vars - - @staticmethod - def include_fp8_for_grads_if_needed(variables): - return variables - - @staticmethod - def mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt): - return grads - @staticmethod @contextmanager def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"): @@ -363,28 +338,6 @@ def get_pipeline_transformer(pipeline_transformer_p): return te_pipeline_transformer_p - @staticmethod - def update_fp8_metas_if_needed(mdl_vars, grads): - FP8_COLLECTION_NAME = te.fp8.FP8Helper.FP8_COLLECTION_NAME - if FP8_COLLECTION_NAME in grads: - mdl_vars[FP8_COLLECTION_NAME] = grads[FP8_COLLECTION_NAME] - return mdl_vars - - @staticmethod - def include_fp8_for_grads_if_needed(variables): - FP8_COLLECTION_NAME = te.fp8.FP8Helper.FP8_COLLECTION_NAME - if FP8_COLLECTION_NAME in variables: - variables[FP8_COLLECTION_NAME] = \ - 
jax.tree_util.tree_map(lambda x: False, variables[FP8_COLLECTION_NAME]) - return variables - - @staticmethod - def mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt): - FP8_COLLECTION_NAME = te.fp8.FP8Helper.FP8_COLLECTION_NAME - if FP8_COLLECTION_NAME in grads: - grads[FP8_COLLECTION_NAME] = vars_with_opt[FP8_COLLECTION_NAME].copy() - return grads - @staticmethod @contextmanager def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"): @@ -424,18 +377,6 @@ def get_stack_transformer(stacked_transformer_p, dtype): def get_pipeline_transformer(pipeline_transformer_p): return TransformerEngineHelper.get_helper().get_pipeline_transformer(pipeline_transformer_p) - @staticmethod - def update_fp8_metas_if_needed(mdl_vars, grads): - return TransformerEngineHelper.get_helper().update_fp8_metas_if_needed(mdl_vars, grads) - - @staticmethod - def include_fp8_for_grads_if_needed(variables): - return TransformerEngineHelper.get_helper().include_fp8_for_grads_if_needed(variables) - - @staticmethod - def mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt): - return TransformerEngineHelper.get_helper().mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt) - @staticmethod @contextmanager def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"): diff --git a/paxml/trainer_lib.py b/paxml/trainer_lib.py index 2e9bfd6f4..270fb3dc6 100644 --- a/paxml/trainer_lib.py +++ b/paxml/trainer_lib.py @@ -995,16 +995,15 @@ def get_excluded_var_masks( excluded_for_grad = tasks_lib.get_excluded_var_mask_for_grad( var_weight_hparams, learner ) - excluded_for_grad_but_fp8_meta = TransformerEngineHelper.include_fp8_for_grads_if_needed(excluded_for_grad.copy()) - _log_bprop_include_exclude_list(var_weight_hparams, excluded_for_grad_but_fp8_meta) + _log_bprop_include_exclude_list(var_weight_hparams, excluded_for_grad) # Excluded for optimizer states. excluded_for_opt = tasks_lib.get_excluded_var_mask_for_opt( var_weight_hparams, learner, ) - return excluded_for_grad, excluded_for_grad_but_fp8_meta, excluded_for_opt + return excluded_for_grad, excluded_for_opt def _prepare_tree_data_for_summary(tree): @@ -1103,13 +1102,13 @@ def train_step_single_learner( _, subkey = jax.random.split(prng_key) - excluded_for_grad, excluded_for_grad_but_fp8_meta, excluded_for_opt = get_excluded_var_masks( + excluded_for_grad, excluded_for_opt = get_excluded_var_masks( var_weight_hparams, learner ) # Construct and call the grad function. 
if not grad_fn: - grad_fn = _get_default_grad_fn(excluded_for_grad_but_fp8_meta, excluded_for_opt) + grad_fn = _get_default_grad_fn(excluded_for_grad, excluded_for_opt) (weighted_loss, aux_info), grads = grad_fn( loss_fn=_get_default_loss_fn( jax_task=jax_task, @@ -1166,9 +1165,6 @@ def train_step_single_learner( var_weight_hparams, excluded_for_learner ) - mdl_vars = TransformerEngineHelper.update_fp8_metas_if_needed(mdl_vars, grads) - grads = TransformerEngineHelper.mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt) - transformed_grads, new_opt_states = learner.update_states( grads, states.opt_states[0], vars_with_opt, wps_with_opt ) From 9df77b4399f983e71d276ddf487abe588a13191b Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Wed, 15 Nov 2023 14:43:17 +0800 Subject: [PATCH 08/12] Remove Praxis related setup (Moving to Praxis TE/Patch) --- paxml/contrib/gpu/scripts_gpu/configs.py | 9 - paxml/contrib/gpu/scripts_gpu/te_helper.py | 315 --------------------- 2 files changed, 324 deletions(-) diff --git a/paxml/contrib/gpu/scripts_gpu/configs.py b/paxml/contrib/gpu/scripts_gpu/configs.py index 16d412587..7d50a525a 100644 --- a/paxml/contrib/gpu/scripts_gpu/configs.py +++ b/paxml/contrib/gpu/scripts_gpu/configs.py @@ -178,15 +178,6 @@ def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]: transformer_layer_p = stacked_p.transformer_layer_params_tpl transformer_layer_p.ln_tpl.reductions_in_fp32 = True transformer_layer_p.tr_fflayer_tpl.ln_tpl.reductions_in_fp32 = True - else: - stacked_p = TransformerEngineHelper.get_stack_transformer( - stacked_p, jnp.dtype(self.FPROP_DTYPE)) - if issubclass(fdl.get_callable(model_p.lm_tpl.stacked_transformer_tpl), - transformers.StackedTransformerRepeated): - model_p.lm_tpl.stacked_transformer_tpl.block = stacked_p - else: - model_p.lm_tpl.stacked_transformer_tpl = stacked_p - model_p.params_init = WeightInit.Gaussian(self.INIT_STD) softmax_init = WeightInit.Gaussian(self.SOFTMAX_INIT_STD) diff --git a/paxml/contrib/gpu/scripts_gpu/te_helper.py b/paxml/contrib/gpu/scripts_gpu/te_helper.py index fd482dfd1..b2712585e 100644 --- a/paxml/contrib/gpu/scripts_gpu/te_helper.py +++ b/paxml/contrib/gpu/scripts_gpu/te_helper.py @@ -1,238 +1,17 @@ import os from contextlib import contextmanager -from typing import Optional, Sequence - -import jax.numpy as jnp -from jax.ad_checkpoint import checkpoint_name -from praxis import base_layer -from praxis import pax_fiddle -from praxis import pytypes -from praxis.layers import transformers -from praxis.layers import stochastics try: import transformer_engine.jax as te - import transformer_engine.jax.flax as te_flax - import transformer_engine.jax.praxis as te_praxis from transformer_engine.common import recipe _IS_TRANSFORMER_ENGINE_INSTALLED = True - DEFAULT_INIT_MUTABLE_LIST = base_layer.DEFAULT_INIT_MUTABLE_LIST + [te.fp8.FP8Helper.FP8_COLLECTION_NAME] - import praxis.layers.repeats as praxis_repeat - # This is to make Repeat module correctly generate collections we need. 
- praxis_repeat.SCAN_VARIABLE_AXES.update({base_layer.NON_PAX_VAR_COLLECTION[1]: 0, # 1-idx = params_axes - te.fp8.FP8Helper.FP8_COLLECTION_NAME:0}) except ModuleNotFoundError as e: _IS_TRANSFORMER_ENGINE_INSTALLED = False - DEFAULT_INIT_MUTABLE_LIST = base_layer.DEFAULT_INIT_MUTABLE_LIST - - -LayerTpl = pax_fiddle.Config[base_layer.BaseLayer] -JTensor = pytypes.JTensor - - -class StackedTransformer(transformers.StackedTransformer): - """A mirror of StackedTransformer layers in Praxis.""" - - def setup(self) -> None: - - assert self.num_layers > 0 - assert self.model_dims > 0 - assert self.hidden_dims > 0 - assert self.num_heads > 0 - assert 0.0 <= self.dropout_prob < 1.0 - assert 0.0 <= self.input_dropout_prob < 1.0 - - def _layer_params(i): - """Construct i-th layer params.""" - if isinstance(self.transformer_layer_params_tpl, Sequence): - factor = self.num_layers // len(self.transformer_layer_params_tpl) - ii = i // factor - p_i = self._clone_layer_params(self.transformer_layer_params_tpl[ii]) - else: - p_i = self._clone_layer_params(self.transformer_layer_params_tpl) - p_i.name = f'layer_{i}' - - p_i.logical_axes_rules = te_flax.extend_logical_axis_rules(tuple()) - p_i.layer_type = te_praxis.TransformerLayerType.DECODER if self.use_cross_attention \ - else te_praxis.TransformerLayerType.ENCODER - p_i.num_attention_heads = self.num_heads - p_i.hidden_size = self.model_dims - p_i.mlp_hidden_size = self.hidden_dims - - p_i.dropout_rng_name = base_layer.RANDOM - p_i.attention_dropout = self.atten_dropout_prob or self.dropout_prob - p_i.hidden_dropout = self.residual_dropout_prob or self.dropout_prob - p_i.intermediate_dropout = self.relu_dropout_prob or self.dropout_prob - if self.residual_droppath_prob > 0.0: - p_i.drop_path = ( - self.residual_droppath_prob * i / max(1, self.num_layers) - ) - - assert self.dim_per_head == self.model_dims // self.num_heads - assert self.packed_input == False - assert len(self.moe_layers) == 0 - assert self.ngrammer_tpls is None - - if self.ngrammer_tpls is not None: - if self.ngrammer_tpls[i] is not None: - p_i.ngrammer_tpl = self.ngrammer_tpls[i] - return p_i - - if isinstance(self.transformer_layer_params_tpl, (list, tuple)): - if self.num_layers % len(self.transformer_layer_params_tpl): - raise ValueError('num_layers should be divisible by ' - 'transformer_layer_params_tpl') - - layer_params = [_layer_params(i) for i in range(self.num_layers)] - self.create_children('x_layers', layer_params) - - if self.input_dropout_prob > 0.0: - self.create_child( - 'input_dropout', - pax_fiddle.Config( - stochastics.Dropout, keep_prob=1.0 - self.input_dropout_prob - ), - ) - - def __call__(self, - inputs: JTensor, - paddings: JTensor, - segment_mask: Optional[JTensor] = None, - cross_inputs: Optional[JTensor] = None, - cross_paddings: Optional[JTensor] = None, - cross_segment_mask: Optional[JTensor] = None, - segment_pos: Optional[JTensor] = None) -> JTensor: - - if self.packed_input: - assert segment_mask is not None - - if self.use_cross_attention: - assert cross_inputs is not None - assert cross_paddings is not None - if self.packed_input: - assert cross_segment_mask is not None - - attention_mask, cross_attention_mask = transformers.compute_attention_masks_for_fprop( - inputs, - paddings, - self.mask_self_attention, - segment_mask, - cross_inputs, - cross_paddings, - cross_segment_mask, - fold_padding_with_segment_mask=self.fold_padding_with_segment_mask, - ) - - x_out = inputs - if self.input_dropout_prob > 0.0: - x_out = self.input_dropout(x_out) - - 
attention_mask = 1 - (attention_mask == 0) - attention_mask = attention_mask.astype(jnp.uint8) - - if cross_attention_mask is not None: - cross_attention_mask = 1 - (cross_attention_mask == 0) - cross_attention_mask = cross_attention_mask.astype(jnp.uint8) - - for i in range(self.num_layers): - x_in = x_out - x_out = self.x_layers[i]( - inputs=x_in, - attention_mask=attention_mask, - encoded=cross_inputs, - encoder_decoder_mask=cross_attention_mask, - deterministic=self.do_eval) - x_out = checkpoint_name(x_out, 'transformer_layer_out') - return x_out - - -class PipelinedTransformer(transformers.PipelinedTransformer): - """A mirror of PipelinedTransformer in Praxis""" - - def __call__( - self, - inputs: JTensor, - paddings: JTensor, - segment_mask: JTensor | None = None, - cross_inputs: JTensor | None = None, - cross_paddings: JTensor | None = None, - cross_segment_mask: JTensor | None = None, - segment_pos: JTensor | None = None, - ) -> JTensor: - - rules = te_flax.extend_logical_axis_rules(tuple()) - batch_mapping = rules[0] - hidden_tp_mapping = rules[4] - # [Batch, Seqlen, Hidden] - bld_mapping = [batch_mapping, None, hidden_tp_mapping] - - if not self.stream_io: - # Annotate the inputs before the pipeline to prevent unexpected - # propagation from earlier layers. - inputs = base_layer.maybe_shard(inputs, bld_mapping, self.mesh_axis_names) - if bld_mapping is not None: - # Annotate other broadcast inputs. - paddings = base_layer.maybe_shard( - paddings, bld_mapping[:-1], self.mesh_axis_names - ) - - # For cross inputs, we only specify the batch dim sharding. - def _shard_batch_dim_only(x): - return base_layer.maybe_shard( - x, - [bld_mapping[0]] + [-1] * (x.ndim - 1), - self.mesh_axis_names, - unconstrained_dims=range(1, x.ndim), - ) - - if segment_mask is not None: - segment_mask = _shard_batch_dim_only(segment_mask) - if cross_inputs is not None: - cross_inputs = _shard_batch_dim_only(cross_inputs) - if cross_paddings is not None: - cross_paddings = _shard_batch_dim_only(cross_paddings) - if cross_segment_mask is not None: - cross_segment_mask = _shard_batch_dim_only(cross_segment_mask) - - if segment_pos is not None: - segment_pos = base_layer.maybe_shard( - segment_pos, bld_mapping[:-1], self.mesh_axis_names - ) - - outputs = self.pipeline( - inputs, - paddings, - segment_mask=segment_mask, - cross_inputs=cross_inputs, - cross_paddings=cross_paddings, - cross_segment_mask=cross_segment_mask, - segment_pos=segment_pos, - ) - - if not self.stream_io: - outputs = base_layer.maybe_shard( - outputs, bld_mapping, self.mesh_axis_names - ) - - outputs = base_layer.maybe_shard( - outputs, - self.activation_split_dims_mapping.final_out, - self.mesh_axis_names, - ) - return outputs class TransformerEngineHelperBase: - @staticmethod - def get_stack_transformer(stacked_transformer_p, dtype): - raise NotImplementedError - - @staticmethod - def get_pipeline_transformer(pipeline_transformer_p): - raise NotImplementedError - @staticmethod @contextmanager def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"): @@ -241,14 +20,6 @@ def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="dat class TENotInstalledHelper(TransformerEngineHelperBase): - @staticmethod - def get_stack_transformer(stacked_transformer_p, dtype): - return stacked_transformer_p - - @staticmethod - def get_pipeline_transformer(pipeline_transformer_p): - return pipeline_transformer_p - @staticmethod @contextmanager def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", 
fsdp_mesh_axis="data"): @@ -260,84 +31,6 @@ def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="dat class TEInstalledHelper(TransformerEngineHelperBase): - @staticmethod - def get_stack_transformer(stacked_transformer_p, dtype): - - assert stacked_transformer_p.cls == transformers.StackedTransformer - - te_stacked_transformer_p = pax_fiddle.Config(StackedTransformer, - use_cross_attention=stacked_transformer_p.use_cross_attention, - mask_self_attention=stacked_transformer_p.mask_self_attention, - num_layers=stacked_transformer_p.num_layers, - model_dims=stacked_transformer_p.model_dims, - hidden_dims=stacked_transformer_p.hidden_dims, - num_heads=stacked_transformer_p.num_heads, - dim_per_head=stacked_transformer_p.dim_per_head, - dropout_prob=stacked_transformer_p.dropout_prob, - atten_dropout_prob=stacked_transformer_p.atten_dropout_prob, - residual_dropout_prob=stacked_transformer_p.residual_dropout_prob, - relu_dropout_prob=stacked_transformer_p.relu_dropout_prob, - residual_droppath_prob=stacked_transformer_p.residual_droppath_prob, - input_dropout_prob=stacked_transformer_p.input_dropout_prob, - gating_func=stacked_transformer_p.gating_func, - unadjusted_expert_capacity_factor=stacked_transformer_p.unadjusted_expert_capacity_factor, - packed_input=stacked_transformer_p.packed_input, - fold_padding_with_segment_mask=stacked_transformer_p.fold_padding_with_segment_mask, - moe_layer_tpl=stacked_transformer_p.moe_layer_tpl, - num_experts=stacked_transformer_p.num_experts, - num_groups=stacked_transformer_p.num_groups, - min_group_size=stacked_transformer_p.min_group_size, - moe_layers=stacked_transformer_p.moe_layers, - ngrammer_tpls=stacked_transformer_p.ngrammer_tpls - ) - - ori_transformer_engine_p = stacked_transformer_p.transformer_layer_params_tpl - - te_stacked_transformer_p.transformer_layer_params_tpl = pax_fiddle.Config(te_praxis.TransformerLayer, - name='transformer_layer', - params_init=stacked_transformer_p.params_init, - dtype=dtype, - hidden_size=stacked_transformer_p.model_dims, - mlp_hidden_size=stacked_transformer_p.hidden_dims, - num_attention_heads=stacked_transformer_p.num_heads, - layernorm_type='layernorm', - layernorm_epsilon=ori_transformer_engine_p.ln_tpl.epsilon, - zero_centered_gamma = True, - hidden_dropout=ori_transformer_engine_p.residual_dropout_prob, - attention_dropout=ori_transformer_engine_p.atten_dropout_prob, - mlp_activations=('gelu',), - use_bias=True, - layer_type=te_praxis.TransformerLayerType.ENCODER, - self_attn_mask_type='causal', - enable_relative_embedding=False, - drop_path=ori_transformer_engine_p.residual_droppath_prob, - scaled_query_init=False, - scale_attn_logits=True, - transpose_batch_sequence=False - ) - - return te_stacked_transformer_p - - @staticmethod - def get_pipeline_transformer(pipeline_transformer_p): - - assert pipeline_transformer_p.cls == transformers.PipelinedTransformer - - te_pipeline_transformer_p = pax_fiddle.Config(PipelinedTransformer, - pipeline_stage=pipeline_transformer_p.pipeline_stage, - circular_repeat=pipeline_transformer_p.circular_repeat, - num_pipeline_stages=pipeline_transformer_p.num_pipeline_stages, - num_pipeline_microbatches=pipeline_transformer_p.num_pipeline_microbatches, - pipeline_microbatch_size=pipeline_transformer_p.pipeline_microbatch_size, - stream_io=pipeline_transformer_p.stream_io, - pipeline_broadcast_inputs=pipeline_transformer_p.pipeline_broadcast_inputs, - checkpoint_policy=pipeline_transformer_p.checkpoint_policy, - 
enable_async_circular_transfer=pipeline_transformer_p.enable_async_circular_transfer, - bf16_accum_in_fp32=pipeline_transformer_p.bf16_accum_in_fp32 - ) - - return te_pipeline_transformer_p - @staticmethod @contextmanager def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"): @@ -369,14 +62,6 @@ def get_helper(): return TEInstalledHelper return TENotInstalledHelper - @staticmethod - def get_stack_transformer(stacked_transformer_p, dtype): - return TransformerEngineHelper.get_helper().get_stack_transformer(stacked_transformer_p, dtype) - - @staticmethod - def get_pipeline_transformer(pipeline_transformer_p): - return TransformerEngineHelper.get_helper().get_pipeline_transformer(pipeline_transformer_p) - @staticmethod @contextmanager def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"): From 3297966c11624d3f692f3450f4eb1acd28b3964e Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Wed, 15 Nov 2023 14:51:14 +0800 Subject: [PATCH 09/12] Fix missing DEFAULT_INIT_MUTABLE_LIST --- paxml/contrib/gpu/scripts_gpu/te_helper.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/paxml/contrib/gpu/scripts_gpu/te_helper.py b/paxml/contrib/gpu/scripts_gpu/te_helper.py index b2712585e..cbac7cf66 100644 --- a/paxml/contrib/gpu/scripts_gpu/te_helper.py +++ b/paxml/contrib/gpu/scripts_gpu/te_helper.py @@ -1,13 +1,17 @@ import os from contextlib import contextmanager +from praxis import base_layer + try: import transformer_engine.jax as te from transformer_engine.common import recipe _IS_TRANSFORMER_ENGINE_INSTALLED = True + DEFAULT_INIT_MUTABLE_LIST = base_layer.DEFAULT_INIT_MUTABLE_LIST + [te.fp8.FP8Helper.FP8_COLLECTION_NAME] except ModuleNotFoundError as e: _IS_TRANSFORMER_ENGINE_INSTALLED = False + DEFAULT_INIT_MUTABLE_LIST = base_layer.DEFAULT_INIT_MUTABLE_LIST class TransformerEngineHelperBase: From 65b00553378aead4b083e7154540b6fd0a3a8f4e Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Mon, 12 Feb 2024 10:22:15 -0800 Subject: [PATCH 10/12] Revert mutable kwarg in abstract_init_with_metadata in init checkpoint rule --- paxml/tasks_lib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paxml/tasks_lib.py b/paxml/tasks_lib.py index 4b9028233..e210d6372 100644 --- a/paxml/tasks_lib.py +++ b/paxml/tasks_lib.py @@ -1787,7 +1787,7 @@ def _apply_init_checkpoint_rule( ) # Initialize with a dummy seed var_weight_hparams = ckpt_task.model.abstract_init_with_metadata( - inputs_shape_dtype, mutable=DEFAULT_INIT_MUTABLE_LIST) + inputs_shape_dtype, extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST) ckpt_train_state = ckpt_task.create_train_state_padded_shapes( var_weight_hparams) train_state_pspecs = ckpt_task.create_train_state_partition_specs( From 4516dbb067a28677ceae31411f8003d021b9e8e3 Mon Sep 17 00:00:00 2001 From: Ming Huang Date: Fri, 12 Apr 2024 20:14:01 +0000 Subject: [PATCH 11/12] Support FM32 to OWG parameters. 
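
Variables whose weight hparams carry the overwrite-with-gradient (OWG)
attribute (i.e. TE's FP8 scale/amax metadata, see PATCH 07/12) are cast to the
`fm32` element type from `flax.linen.fp8_ops` before fprop, so their updates
are handled by Flax's FP8 dtype machinery during autodiff instead of being
treated as ordinary float32 weights. A hedged sketch of the two-tree
`tree_map` pattern used by `_maybe_to_fm32_vars` (the names `values`,
`hparams`, and `owg` below are made up for illustration):

    import jax
    import jax.numpy as jnp

    values  = {'w': jnp.ones((2, 2)), 'scale': jnp.ones(())}
    hparams = {'w': {'owg': False},   'scale': {'owg': True}}

    def maybe_convert(v, hp):
        # Only leaves flagged as OWG get the special element type; plain
        # float32 stands in here for lax.convert_element_type(v, fm32).
        return v.astype(jnp.float32) if hp['owg'] else v

    # Treat anything that is not a container as a leaf, mirroring the patch.
    is_leaf = lambda x: not isinstance(x, (tuple, dict, list))
    converted = jax.tree_util.tree_map(
        maybe_convert, values, hparams, is_leaf=is_leaf)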
Signed-off-by: Ming Huang --- paxml/trainer_lib.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/paxml/trainer_lib.py b/paxml/trainer_lib.py index 270fb3dc6..a9afc9e60 100644 --- a/paxml/trainer_lib.py +++ b/paxml/trainer_lib.py @@ -26,6 +26,7 @@ from etils import epath import fiddle as fdl from flax import struct as flax_struct +from flax.linen.fp8_ops import fm32 import jax from jax import numpy as jnp from jax.experimental import pjit @@ -35,7 +36,7 @@ from paxml import sgf from paxml import tasks_lib from paxml import train_states -from paxml.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper, DEFAULT_INIT_MUTABLE_LIST +from paxml.contrib.gpu.scripts_gpu.te_helper import DEFAULT_INIT_MUTABLE_LIST from praxis import asserts from praxis import base_hyperparams from praxis import base_input @@ -804,6 +805,21 @@ def _default_apply_fn( ) +def _maybe_to_fm32_vars(mdl_vars, var_weight_hparams): + asserts.assert_same_structure(mdl_vars, var_weight_hparams) + + def _maybe_fm32_var_fn(var, var_param): + if base_layer.var_overwrite_with_gradient(var_param): + return jax.lax.convert_element_type(var, fm32) + else: + return var + + is_leaf = lambda x: not isinstance(x, (tuple, dict, list)) + return jax.tree_util.tree_map( + _maybe_fm32_var_fn, mdl_vars, var_weight_hparams, is_leaf=is_leaf + ) + + class LossFnProtocol(Protocol): def __call__( @@ -834,6 +850,8 @@ def _loss_fn( else: assert NotImplementedError(f'fprop_dtype {fprop_dtype} not supported.') + mdl_vars = _maybe_to_fm32_vars(mdl_vars, var_weight_hparams) + with base_layer.JaxContext.new_context(hparams=context_p): k1, k2, k3 = jax.random.split(prng_key, 3) (metrics, per_example_output), updated_vars = apply_fn( From 9e625dc25f441ca3bc408000d94b5ec9c8ded1d4 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 29 Oct 2024 09:33:04 -0700 Subject: [PATCH 12/12] fix OOM with TE Signed-off-by: ashors1 --- paxml/main.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/paxml/main.py b/paxml/main.py index f087f277a..c0173dfda 100644 --- a/paxml/main.py +++ b/paxml/main.py @@ -551,7 +551,7 @@ def create_experiment_config(): ' --fdl_config, or --fdl_config_file is required.' ) - experiment_config.validate() + experiment_config.validate() return experiment_config @@ -567,11 +567,13 @@ def _main(argv: Sequence[str]) -> None: with ml_monitoring.ml_event_logger(ml_monitoring.MlEvent.INITIALIZE_SETUP): experiment_config = create_experiment_config() - run( - experiment_config=experiment_config, - enable_checkpoint_saving=FLAGS.enable_checkpoint_saving, - startup_random_jitter_max_secs=FLAGS.startup_random_jitter_max_secs, - ) + with TransformerEngineHelper.fp8_autocast('replica', 'mdl', 'data'): + run( + experiment_config=experiment_config, + enable_checkpoint_saving=FLAGS.enable_checkpoint_saving, + startup_random_jitter_max_secs=FLAGS.startup_random_jitter_max_secs, + ) + _TASK_HANDLE_RE = re.compile(r'(?:logs\.)?(\d+)\.(.*)\.([^.]+)\.\d+')
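
Usage sketch for the series (every identifier below comes from the patches
above; shown only to illustrate how the TE path is switched on):

    import os

    # Both switches are read in te_helper.py; Transformer Engine must also be
    # importable for the TE code path to activate.
    os.environ['ENABLE_TE'] = '1'   # route init/training through the TE helper
    os.environ['ENABLE_FP8'] = '1'  # additionally enable the DelayedScaling FP8 recipe

    from paxml.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper

    # True only when transformer_engine is installed AND ENABLE_TE=1.
    print(TransformerEngineHelper.is_enabled_te())

    # main.py wraps both experiment construction and run() in this context:
    with TransformerEngineHelper.fp8_autocast('replica', 'mdl', 'data'):
        pass  # build the experiment config and launch training here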