diff --git a/paxml/contrib/gpu/scripts_gpu/configs.py b/paxml/contrib/gpu/scripts_gpu/configs.py
index cb5ad5210..7d50a525a 100644
--- a/paxml/contrib/gpu/scripts_gpu/configs.py
+++ b/paxml/contrib/gpu/scripts_gpu/configs.py
@@ -28,6 +28,7 @@
 from paxml.contrib.gpu.scripts_gpu.tasks import LambadaDataset
 from paxml.contrib.gpu.scripts_gpu.tasks import PileUnsupervisedDataset
 from paxml.tasks.lm.model_params import maybe_setup_moe_params
+from paxml.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper
 from paxml.tasks.lm.params.c4 import TransformerLmSpmdAdam
 from paxml.tasks.lm.params.lm_cloud import SyntheticDataset
 from praxis import base_layer
@@ -116,7 +117,7 @@ class GPT126MBase(TransformerLmSpmdAdam):
   MAX_SEQ_LEN = 2048
   VOCAB_SIZE = 50304
-  PACKED_INPUT = True
+  PACKED_INPUT = False
   PERCORE_BATCH_SIZE = 4
   NUM_LAYERS = 12
@@ -171,10 +172,12 @@ def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]:
         fdl.get_callable(stacked_p), transformers.StackedTransformerRepeated
     ):
       stacked_p = stacked_p.block
-    transformer_layer_p = stacked_p.transformer_layer_params_tpl
-    transformer_layer_p.ln_tpl.reductions_in_fp32 = True
-    transformer_layer_p.tr_fflayer_tpl.ln_tpl.reductions_in_fp32 = True
+    task_p.model.lm_tpl.final_ln_tpl.reductions_in_fp32 = True
+    if not TransformerEngineHelper.is_enabled_te():
+      transformer_layer_p = stacked_p.transformer_layer_params_tpl
+      transformer_layer_p.ln_tpl.reductions_in_fp32 = True
+      transformer_layer_p.tr_fflayer_tpl.ln_tpl.reductions_in_fp32 = True
 
     model_p.params_init = WeightInit.Gaussian(self.INIT_STD)
     softmax_init = WeightInit.Gaussian(self.SOFTMAX_INIT_STD)
@@ -239,7 +242,7 @@ class GPT175BBase(GPT126MBase):
   # Known as MLP_DIM in t5x
   HIDDEN_DIMS = MODEL_DIMS * 4
   # Defaults to MODEL_DIMS // NUM_HEADS.
-  DIMS_PER_HEAD = None
+  DIMS_PER_HEAD = 128
   # Known as NUM_EMBEDDINGS in t5x
   VOCAB_SIZE = 50257
   USE_REPEATED_LAYER = True
diff --git a/paxml/contrib/gpu/scripts_gpu/te_helper.py b/paxml/contrib/gpu/scripts_gpu/te_helper.py
new file mode 100644
index 000000000..cbac7cf66
--- /dev/null
+++ b/paxml/contrib/gpu/scripts_gpu/te_helper.py
@@ -0,0 +1,76 @@
+import os
+from contextlib import contextmanager
+
+from praxis import base_layer
+
+try:
+    import transformer_engine.jax as te
+    from transformer_engine.common import recipe
+    _IS_TRANSFORMER_ENGINE_INSTALLED = True
+    DEFAULT_INIT_MUTABLE_LIST = base_layer.DEFAULT_INIT_MUTABLE_LIST + [te.fp8.FP8Helper.FP8_COLLECTION_NAME]
+
+except ModuleNotFoundError as e:
+    _IS_TRANSFORMER_ENGINE_INSTALLED = False
+    DEFAULT_INIT_MUTABLE_LIST = base_layer.DEFAULT_INIT_MUTABLE_LIST
+
+
+class TransformerEngineHelperBase:
+
+    @staticmethod
+    @contextmanager
+    def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"):
+        raise NotImplementedError
+
+
+class TENotInstalledHelper(TransformerEngineHelperBase):
+
+    @staticmethod
+    @contextmanager
+    def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"):
+        try:
+            yield
+        finally:
+            pass
+
+
+class TEInstalledHelper(TransformerEngineHelperBase):
+
+    @staticmethod
+    @contextmanager
+    def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"):
+        fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.HYBRID,
+                                           amax_history_len=1024, amax_compute_algo='max')
+
+        enable_fp8 = bool(int((os.environ.get("ENABLE_FP8", False))))
+        try:
+            with te.fp8_autocast(enabled=enable_fp8,
+                                 fp8_recipe=fp8_recipe,
+                                 mesh_resource=te.MeshResource(dp_resource=dp_mesh_axis,
+                                                               tp_resource=tp_mesh_axis,
+                                                               fsdp_resource=fsdp_mesh_axis)):
+                yield
+        finally:
+            pass
+
+
+class TransformerEngineHelper(TransformerEngineHelperBase):
+
+    @staticmethod
+    def is_enabled_te():
+        enable_te = bool(int((os.environ.get("ENABLE_TE", False))))
+        return (_IS_TRANSFORMER_ENGINE_INSTALLED and enable_te)
+
+    @staticmethod
+    def get_helper():
+        if TransformerEngineHelper.is_enabled_te():
+            return TEInstalledHelper
+        return TENotInstalledHelper
+
+    @staticmethod
+    @contextmanager
+    def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"):
+        try:
+            with TransformerEngineHelper.get_helper().fp8_autocast(dp_mesh_axis, tp_mesh_axis, fsdp_mesh_axis):
+                yield
+        finally:
+            pass
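The new helper is a shim: when Transformer Engine is not installed or ENABLE_TE is unset it degrades to a pass-through context manager, otherwise it wraps the body in te.fp8_autocast with a HYBRID DelayedScaling recipe. A minimal usage sketch, assuming Paxml and this module are importable; the ENABLE_TE/ENABLE_FP8 variable names and the default mesh-axis names come from the code above, the rest is illustrative:

    import os

    # The helper parses these with bool(int(...)), so the values must be "0" or "1".
    os.environ["ENABLE_TE"] = "1"   # use Transformer Engine code paths if it is installed
    os.environ["ENABLE_FP8"] = "1"  # additionally run the wrapped body under FP8 autocast

    from paxml.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper

    # With TE absent or ENABLE_TE=0 this context is a no-op, so the same call
    # site works on any installation. The arguments name the mesh axes used for
    # data, tensor and fully-sharded-data parallelism.
    with TransformerEngineHelper.fp8_autocast(
        dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"):
        ...  # build the experiment config and run training here, as main.py does below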
diff --git a/paxml/main.py b/paxml/main.py
index f03e2d4a4..c0173dfda 100644
--- a/paxml/main.py
+++ b/paxml/main.py
@@ -52,6 +52,7 @@
 from paxml import trainer_lib
 from paxml import tuning_lib
 from paxml import ml_monitoring
+from paxml.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper
 from praxis import pax_fiddle
 from praxis import py_utils
 
@@ -519,37 +520,38 @@ def create_experiment_config():
       ),
   )
 
-  if FLAGS.exp is not None:
-    experiment_config = get_experiment(FLAGS.exp)()
-  elif absl_flags.fdl_flags_supplied():
-    # Use the legacy Fiddle flags API to parse command line Fiddle flags.
-    cfg = absl_flags.create_buildable_from_flags(
-        module=None, allow_imports=True)
-    experiment_config = pax_fiddle.build(cfg)
-    logging.warning(
-        'Legacy Fiddle flags API usage detected. Please use the new Fiddle'
-        ' command line flag `fdl` with various commands to specify the'
-        ' config and any overrides. Please see'
-        ' `fiddle/docs/flags_code_lab.md` for more'
-        ' documentation on Fiddle flags usage.'
-    )
-  elif _FIDDLE_CONFIG.value is not None:
-    # This uses the new Fiddle flags API `DEFINE_fiddle_config()` to parse
-    # command line Fiddle flags. See
-    # `fiddle/docs/flags_code_lab.md` for details on the new
-    # Fiddle flags API.
-    logging.info(
-        'Using pax_fiddle_config from the command line: %s',
-        _FIDDLE_CONFIG.value,
-    )
-    experiment_config = pax_fiddle.build(_FIDDLE_CONFIG.value)
-  else:
-    raise app.UsageError(
-        'No experiment provided. At least one of --exp, --fdl,'
-        ' --fdl_config, or --fdl_config_file is required.'
-    )
+  with TransformerEngineHelper.fp8_autocast('replica', 'mdl', 'data'):
+    if FLAGS.exp is not None:
+      experiment_config = get_experiment(FLAGS.exp)()
+    elif absl_flags.fdl_flags_supplied():
+      # Use the legacy Fiddle flags API to parse command line Fiddle flags.
+      cfg = absl_flags.create_buildable_from_flags(
+          module=None, allow_imports=True)
+      experiment_config = pax_fiddle.build(cfg)
+      logging.warning(
+          'Legacy Fiddle flags API usage detected. Please use the new Fiddle'
+          ' command line flag `fdl` with various commands to specify the'
+          ' config and any overrides. Please see'
+          ' `fiddle/docs/flags_code_lab.md` for more'
+          ' documentation on Fiddle flags usage.'
+      )
+    elif _FIDDLE_CONFIG.value is not None:
+      # This uses the new Fiddle flags API `DEFINE_fiddle_config()` to parse
+      # command line Fiddle flags. See
+      # `fiddle/docs/flags_code_lab.md` for details on the new
+      # Fiddle flags API.
+      logging.info(
+          'Using pax_fiddle_config from the command line: %s',
+          _FIDDLE_CONFIG.value,
+      )
+      experiment_config = pax_fiddle.build(_FIDDLE_CONFIG.value)
+    else:
+      raise app.UsageError(
+          'No experiment provided. At least one of --exp, --fdl,'
+          ' --fdl_config, or --fdl_config_file is required.'
+      )
 
-  experiment_config.validate()
+    experiment_config.validate()
   return experiment_config
@@ -565,11 +567,12 @@ def _main(argv: Sequence[str]) -> None:
   with ml_monitoring.ml_event_logger(ml_monitoring.MlEvent.INITIALIZE_SETUP):
     experiment_config = create_experiment_config()
-  run(
-      experiment_config=experiment_config,
-      enable_checkpoint_saving=FLAGS.enable_checkpoint_saving,
-      startup_random_jitter_max_secs=FLAGS.startup_random_jitter_max_secs,
-  )
+  with TransformerEngineHelper.fp8_autocast('replica', 'mdl', 'data'):
+    run(
+        experiment_config=experiment_config,
+        enable_checkpoint_saving=FLAGS.enable_checkpoint_saving,
+        startup_random_jitter_max_secs=FLAGS.startup_random_jitter_max_secs,
+    )
 
 
 _TASK_HANDLE_RE = re.compile(r'(?:logs\.)?(\d+)\.(.*)\.([^.]+)\.\d+')
diff --git a/paxml/tasks_lib.py b/paxml/tasks_lib.py
index 010998dd8..e210d6372 100644
--- a/paxml/tasks_lib.py
+++ b/paxml/tasks_lib.py
@@ -43,6 +43,7 @@
 from paxml import io_utils
 from paxml import learners as learners_lib
 from paxml import train_states
+from paxml.contrib.gpu.scripts_gpu.te_helper import DEFAULT_INIT_MUTABLE_LIST
 from praxis import asserts
 from praxis import base_hyperparams
 from praxis import base_input
@@ -1786,7 +1787,7 @@ def _apply_init_checkpoint_rule(
   )
   # Initialize with a dummy seed
   var_weight_hparams = ckpt_task.model.abstract_init_with_metadata(
-      inputs_shape_dtype)
+      inputs_shape_dtype, extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST)
   ckpt_train_state = ckpt_task.create_train_state_padded_shapes(
       var_weight_hparams)
   train_state_pspecs = ckpt_task.create_train_state_partition_specs(
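Both tasks_lib.py above and trainer_lib.py below thread DEFAULT_INIT_MUTABLE_LIST into the abstract_init_* and model.init calls because Transformer Engine keeps its FP8 scaling state in an extra Flax collection, and a collection that is not declared mutable at init time cannot be written. A toy Flax sketch of that constraint, with the made-up collection name 'fp8_meta' standing in for te.fp8.FP8Helper.FP8_COLLECTION_NAME:

    import jax
    import jax.numpy as jnp
    from flax import linen as nn

    class ToyLayer(nn.Module):
      """A layer that keeps auxiliary state outside the 'params' collection."""

      @nn.compact
      def __call__(self, x):
        w = self.param('w', nn.initializers.lecun_normal(), (x.shape[-1], 8))
        # Auxiliary, non-parameter state, analogous to TE's FP8 scale/amax buffers.
        scale = self.variable('fp8_meta', 'scale', lambda: jnp.ones(()))
        return (x @ w) * scale.value

    # When init is driven with an explicit mutable list, as Paxml does here, the
    # extra collection must be on it; otherwise creating the variable fails.
    variables = ToyLayer().init(
        jax.random.PRNGKey(0), jnp.ones((2, 4)),
        mutable=['params', 'fp8_meta'],
    )
    print(jax.tree_util.tree_map(jnp.shape, variables))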
diff --git a/paxml/trainer_lib.py b/paxml/trainer_lib.py
index 6ec25e8c6..a9afc9e60 100644
--- a/paxml/trainer_lib.py
+++ b/paxml/trainer_lib.py
@@ -26,6 +26,7 @@
 from etils import epath
 import fiddle as fdl
 from flax import struct as flax_struct
+from flax.linen.fp8_ops import fm32
 import jax
 from jax import numpy as jnp
 from jax.experimental import pjit
@@ -35,6 +36,7 @@
 from paxml import sgf
 from paxml import tasks_lib
 from paxml import train_states
+from paxml.contrib.gpu.scripts_gpu.te_helper import DEFAULT_INIT_MUTABLE_LIST
 from praxis import asserts
 from praxis import base_hyperparams
 from praxis import base_input
@@ -167,8 +169,7 @@ def create_train_state_metadata(
     A TrainStateMetadata instance.
   """
   var_weight_hparams = jax_task.model.abstract_init_with_metadata(
-      train_shape_dtype, do_eval=do_eval
-  )
+      train_shape_dtype, do_eval=do_eval, extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST)
   padded_global_shapes = jax_task.create_train_state_padded_shapes(
       var_weight_hparams, discard_opt_states=discard_opt_states
   )
@@ -217,7 +218,8 @@ def write_post_init_model_hparams_file(
     logging.info('post_init_model_params: %s', params_fpath)
     job_log_dir.mkdir(parents=True, exist_ok=True)
     hyper_params = model.abstract_init_with_mdl_config(
-        train_state_metadata.input_shape_dtype, do_eval=do_eval
+        train_state_metadata.input_shape_dtype, do_eval=do_eval,
+        extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST
     )
     with params_fpath.open('w') as params_file:
       hyper_params_dump = base_hyperparams.nested_struct_to_text(hyper_params)
@@ -379,7 +381,8 @@ def initialize_model_state(
     is_eval_for_init = is_eval
   if not var_weight_hparams:
     var_weight_hparams = model.abstract_init_with_metadata(
-        inputs_shape_dtype, do_eval=is_eval_for_init
+        inputs_shape_dtype, do_eval=is_eval_for_init,
+        extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST
     )
   logging.info('init_var prng_seed: %s', init_key)
   logging.info('var_weight_hparams: %s', var_weight_hparams)
@@ -396,7 +399,7 @@ def init_fn(init_key):
     inputs = jax.tree.map(jnp.zeros_like, inputs_shape_dtype)
     if model.hparams.fprop_dtype == jnp.bfloat16:
       inputs = jax.tree.map(_maybe_to_bfloat16, inputs)
-    return model.init(init_key, inputs)
+    return model.init(init_key, inputs, mutable=DEFAULT_INIT_MUTABLE_LIST)
 
   initial_vars = init_fn(init_key)
   logging.info('initial_vars: %s', jax.tree.map(jnp.shape, initial_vars))
@@ -802,6 +805,21 @@ def _default_apply_fn(
   )
+
+def _maybe_to_fm32_vars(mdl_vars, var_weight_hparams):
+  asserts.assert_same_structure(mdl_vars, var_weight_hparams)
+
+  def _maybe_fm32_var_fn(var, var_param):
+    if base_layer.var_overwrite_with_gradient(var_param):
+      return jax.lax.convert_element_type(var, fm32)
+    else:
+      return var
+
+  is_leaf = lambda x: not isinstance(x, (tuple, dict, list))
+  return jax.tree_util.tree_map(
+      _maybe_fm32_var_fn, mdl_vars, var_weight_hparams, is_leaf=is_leaf
+  )
+
+
 class LossFnProtocol(Protocol):
   def __call__(
@@ -809,7 +827,6 @@ def __call__(
   ) -> tuple[JTensor, sgf.GradAuxInfo]:
     """Produces losses and grad info by passing the inputs through a model."""
 
-
 def _get_default_loss_fn(
     jax_task: tasks_lib.SingleTask,
     context_p: base_layer.JaxContext.HParams,
@@ -833,6 +850,8 @@ def _loss_fn(
     else:
       assert NotImplementedError(f'fprop_dtype {fprop_dtype} not supported.')
 
+    mdl_vars = _maybe_to_fm32_vars(mdl_vars, var_weight_hparams)
+
     with base_layer.JaxContext.new_context(hparams=context_p):
       k1, k2, k3 = jax.random.split(prng_key, 3)
      (metrics, per_example_output), updated_vars = apply_fn(
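The hunk above routes mdl_vars through _maybe_to_fm32_vars before the forward pass, so that variables flagged as overwrite-with-gradient (the FP8 scale/amax statistics) are viewed as the fm32 dtype from flax.linen.fp8_ops and their "gradients" can carry the updated statistics back out. A standalone sketch of the same selective-cast pattern, using a made-up metadata class and a bfloat16 cast in place of the fm32 conversion:

    import dataclasses

    import jax
    import jax.numpy as jnp

    @dataclasses.dataclass(frozen=True)
    class VarMeta:
      """Stand-in for Praxis WeightHParams metadata."""
      overwrite_with_gradient: bool = False

    def cast_flagged(values, metas, cast_fn):
      """Casts only the entries whose metadata carries the flag."""
      def _one(value, meta):
        return cast_fn(value) if meta.overwrite_with_gradient else value
      # Stop traversal at non-container nodes, mirroring the is_leaf above.
      is_leaf = lambda x: not isinstance(x, (tuple, dict, list))
      return jax.tree_util.tree_map(_one, values, metas, is_leaf=is_leaf)

    values = {'w': jnp.ones((2, 2)), 'fp8_scale': jnp.ones(())}
    metas = {'w': VarMeta(), 'fp8_scale': VarMeta(overwrite_with_gradient=True)}
    out = cast_flagged(values, metas, lambda v: v.astype(jnp.bfloat16))
    print({k: v.dtype for k, v in out.items()})  # w stays float32, fp8_scale becomes bfloat16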
@@ -994,6 +1013,7 @@ def get_excluded_var_masks(
   excluded_for_grad = tasks_lib.get_excluded_var_mask_for_grad(
       var_weight_hparams, learner
   )
+  _log_bprop_include_exclude_list(var_weight_hparams, excluded_for_grad)
 
   # Excluded for optimizer states.
@@ -1090,7 +1110,7 @@ def train_step_single_learner(
   if not var_weight_hparams:
     with base_layer.JaxContext.new_context(hparams=context_p):
-      var_weight_hparams = model.abstract_init_with_metadata(inputs)
+      var_weight_hparams = model.abstract_init_with_metadata(inputs, extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST)
   updated_model_vars = jax_task.maybe_adjust_train_state(  # pytype: disable=wrong-arg-types  # jax-ndarray
       step=states.step,
       mdl_vars=states.mdl_vars,
@@ -1162,6 +1182,7 @@
   wps_with_opt = tasks_lib.filter_vars_for_grad_or_opt(
       var_weight_hparams, excluded_for_learner
   )
+
   transformed_grads, new_opt_states = learner.update_states(
       grads, states.opt_states[0], vars_with_opt, wps_with_opt
   )
@@ -1197,6 +1218,7 @@
       states.mdl_vars,
       mdl_vars,
   )
+
   new_states = states.new_state(
       mdl_vars=mdl_vars, opt_states=[new_opt_states], extra_state=()
   )
@@ -1300,7 +1322,7 @@ def eval_step_single_learner(
     var_weight_hparams = model.abstract_init_with_metadata(
         inputs,
         do_eval=not jax_task.hparams.train.always_use_train_for_model_init,
-    )
+        extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST)
 
   if fprop_dtype == jnp.float32:
     pass
@@ -1554,7 +1576,7 @@ def initialize_partitioned_model_states(
   model = jax_task.model
   if not var_weight_hparams:
     var_weight_hparams = model.abstract_init_with_metadata(
-        global_input_shapes, do_eval=is_eval
+        global_input_shapes, do_eval=is_eval, extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST
     )
   train_state_partition_specs = (
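Taken together, the patch makes Transformer Engine and FP8 opt-in at run time instead of a config-level fork. A hypothetical launch sketch; the experiment name and log directory are placeholders, --exp is the flag read in main.py above, and --job_log_dir is assumed from Paxml's standard flags:

    import os
    import subprocess

    # ENABLE_TE switches on the Transformer Engine code paths; ENABLE_FP8
    # additionally wraps config construction and the run() call in fp8_autocast.
    env = dict(os.environ, ENABLE_TE="1", ENABLE_FP8="1")

    subprocess.run(
        [
            "python", "-m", "paxml.main",
            "--exp", "paxml.contrib.gpu.scripts_gpu.configs.GPT126MBase",  # placeholder experiment name
            "--job_log_dir", "/tmp/gpt126m_fp8",                           # placeholder output directory
        ],
        env=env,
        check=True,
    )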