13 changes: 8 additions & 5 deletions paxml/contrib/gpu/scripts_gpu/configs.py
@@ -28,6 +28,7 @@
from paxml.contrib.gpu.scripts_gpu.tasks import LambadaDataset
from paxml.contrib.gpu.scripts_gpu.tasks import PileUnsupervisedDataset
from paxml.tasks.lm.model_params import maybe_setup_moe_params
from paxml.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper
from paxml.tasks.lm.params.c4 import TransformerLmSpmdAdam
from paxml.tasks.lm.params.lm_cloud import SyntheticDataset
from praxis import base_layer
@@ -116,7 +117,7 @@ class GPT126MBase(TransformerLmSpmdAdam):

MAX_SEQ_LEN = 2048
VOCAB_SIZE = 50304
PACKED_INPUT = True
PACKED_INPUT = False
PERCORE_BATCH_SIZE = 4

NUM_LAYERS = 12
@@ -171,10 +172,12 @@ def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]:
fdl.get_callable(stacked_p), transformers.StackedTransformerRepeated
):
stacked_p = stacked_p.block
transformer_layer_p = stacked_p.transformer_layer_params_tpl
transformer_layer_p.ln_tpl.reductions_in_fp32 = True
transformer_layer_p.tr_fflayer_tpl.ln_tpl.reductions_in_fp32 = True

task_p.model.lm_tpl.final_ln_tpl.reductions_in_fp32 = True
if not TransformerEngineHelper.is_enabled_te():
transformer_layer_p = stacked_p.transformer_layer_params_tpl
transformer_layer_p.ln_tpl.reductions_in_fp32 = True
transformer_layer_p.tr_fflayer_tpl.ln_tpl.reductions_in_fp32 = True

model_p.params_init = WeightInit.Gaussian(self.INIT_STD)
softmax_init = WeightInit.Gaussian(self.SOFTMAX_INIT_STD)
@@ -239,7 +242,7 @@ class GPT175BBase(GPT126MBase):
# Known as MLP_DIM in t5x
HIDDEN_DIMS = MODEL_DIMS * 4
# Defaults to MODEL_DIMS // NUM_HEADS.
DIMS_PER_HEAD = None
DIMS_PER_HEAD = 128
# Known as NUM_EMBEDDINGS in t5x
VOCAB_SIZE = 50257
USE_REPEATED_LAYER = True
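Note on the guard introduced in `task()` above: the `reductions_in_fp32` overrides on the Praxis layernorm templates are now applied only when Transformer Engine is not active, since TE supplies its own layernorm inside the transformer stack. A small, hedged check of the gating logic, relying only on the helper added in `te_helper.py` below (assumes paxml is importable; printed outcomes are illustrative):

```python
# ENABLE_TE is parsed as an integer string ("1" enables, "0"/unset disables)
# and is ANDed with whether transformer_engine imported successfully.
import os
from paxml.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper

os.environ["ENABLE_TE"] = "0"
assert not TransformerEngineHelper.is_enabled_te()  # always False with "0"

os.environ["ENABLE_TE"] = "1"
print(TransformerEngineHelper.is_enabled_te())
# True only when transformer_engine.jax is installed; otherwise still False,
# and the fp32 layernorm reductions above stay in effect.
```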
76 changes: 76 additions & 0 deletions paxml/contrib/gpu/scripts_gpu/te_helper.py
@@ -0,0 +1,76 @@
import os
from contextlib import contextmanager

from praxis import base_layer

try:
import transformer_engine.jax as te
from transformer_engine.common import recipe
_IS_TRANSFORMER_ENGINE_INSTALLED = True
DEFAULT_INIT_MUTABLE_LIST = base_layer.DEFAULT_INIT_MUTABLE_LIST + [te.fp8.FP8Helper.FP8_COLLECTION_NAME]

except ModuleNotFoundError as e:
_IS_TRANSFORMER_ENGINE_INSTALLED = False
DEFAULT_INIT_MUTABLE_LIST = base_layer.DEFAULT_INIT_MUTABLE_LIST


class TransformerEngineHelperBase:

@staticmethod
@contextmanager
def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"):
raise NotImplementedError


class TENotInstalledHelper(TransformerEngineHelperBase):

@staticmethod
@contextmanager
def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"):
try:
yield
finally:
pass


class TEInstalledHelper(TransformerEngineHelperBase):

@staticmethod
@contextmanager
def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"):
fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.HYBRID,
amax_history_len=1024, amax_compute_algo='max')

enable_fp8 = bool(int((os.environ.get("ENABLE_FP8", False))))
try:
with te.fp8_autocast(enabled=enable_fp8,
fp8_recipe=fp8_recipe,
mesh_resource=te.MeshResource(dp_resource=dp_mesh_axis,
tp_resource=tp_mesh_axis,
fsdp_resource=fsdp_mesh_axis)):
yield
finally:
pass


class TransformerEngineHelper(TransformerEngineHelperBase):

@staticmethod
def is_enabled_te():
enable_te = bool(int((os.environ.get("ENABLE_TE", False))))
return (_IS_TRANSFORMER_ENGINE_INSTALLED and enable_te)

@staticmethod
def get_helper():
if TransformerEngineHelper.is_enabled_te():
return TEInstalledHelper
return TENotInstalledHelper

@staticmethod
@contextmanager
def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"):
try:
with TransformerEngineHelper.get_helper().fp8_autocast(dp_mesh_axis, tp_mesh_axis, fsdp_mesh_axis):
yield
finally:
pass
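Usage sketch for the helper as a whole (this mirrors how `paxml/main.py` consumes it below): `TransformerEngineHelper.fp8_autocast` is a no-op unless TE is installed and `ENABLE_TE=1`, and even then FP8 itself is additionally gated by `ENABLE_FP8`; the mesh axis names passed in must match the experiment's mesh. The `DelayedScaling` recipe pins `margin=0`, a 1024-step amax history with `max` reduction, and the HYBRID format, which in TE's terms means E4M3 on the forward pass and E5M2 on gradients; that last description is background on TE, not something asserted by this diff.

```python
# Hedged end-to-end sketch, not part of the PR.
import os
from paxml.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper

os.environ["ENABLE_TE"] = "1"   # gate checked by is_enabled_te()
os.environ["ENABLE_FP8"] = "1"  # gate checked inside TEInstalledHelper.fp8_autocast

with TransformerEngineHelper.fp8_autocast(
    dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"):
  # Build the experiment config / run training here. With TE installed and
  # ENABLE_FP8=1 this is a real te.fp8_autocast scope; otherwise it is the
  # no-op fallback from TENotInstalledHelper.
  pass
```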
73 changes: 38 additions & 35 deletions paxml/main.py
@@ -52,6 +52,7 @@
from paxml import trainer_lib
from paxml import tuning_lib
from paxml import ml_monitoring
from paxml.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper
from praxis import pax_fiddle
from praxis import py_utils

@@ -519,37 +520,38 @@ def create_experiment_config():
),
)

if FLAGS.exp is not None:
experiment_config = get_experiment(FLAGS.exp)()
elif absl_flags.fdl_flags_supplied():
# Use the legacy Fiddle flags API to parse command line Fiddle flags.
cfg = absl_flags.create_buildable_from_flags(
module=None, allow_imports=True)
experiment_config = pax_fiddle.build(cfg)
logging.warning(
'Legacy Fiddle flags API usage detected. Please use the new Fiddle'
' command line flag `fdl` with various commands to specify the'
' config and any overrides. Please see'
' `fiddle/docs/flags_code_lab.md` for more'
' documentation on Fiddle flags usage.'
)
elif _FIDDLE_CONFIG.value is not None:
# This uses the new Fiddle flags API `DEFINE_fiddle_config()` to parse
# command line Fiddle flags. See
# `fiddle/docs/flags_code_lab.md` for details on the new
# Fiddle flags API.
logging.info(
'Using pax_fiddle_config from the command line: %s',
_FIDDLE_CONFIG.value,
)
experiment_config = pax_fiddle.build(_FIDDLE_CONFIG.value)
else:
raise app.UsageError(
'No experiment provided. At least one of --exp, --fdl,'
' --fdl_config, or --fdl_config_file is required.'
)
with TransformerEngineHelper.fp8_autocast('replica', 'mdl', 'data'):
if FLAGS.exp is not None:
experiment_config = get_experiment(FLAGS.exp)()
elif absl_flags.fdl_flags_supplied():
# Use the legacy Fiddle flags API to parse command line Fiddle flags.
cfg = absl_flags.create_buildable_from_flags(
module=None, allow_imports=True)
experiment_config = pax_fiddle.build(cfg)
logging.warning(
'Legacy Fiddle flags API usage detected. Please use the new Fiddle'
' command line flag `fdl` with various commands to specify the'
' config and any overrides. Please see'
' `fiddle/docs/flags_code_lab.md` for more'
' documentation on Fiddle flags usage.'
)
elif _FIDDLE_CONFIG.value is not None:
# This uses the new Fiddle flags API `DEFINE_fiddle_config()` to parse
# command line Fiddle flags. See
# `fiddle/docs/flags_code_lab.md` for details on the new
# Fiddle flags API.
logging.info(
'Using pax_fiddle_config from the command line: %s',
_FIDDLE_CONFIG.value,
)
experiment_config = pax_fiddle.build(_FIDDLE_CONFIG.value)
else:
raise app.UsageError(
'No experiment provided. At least one of --exp, --fdl,'
' --fdl_config, or --fdl_config_file is required.'
)

experiment_config.validate()
experiment_config.validate()
return experiment_config


@@ -565,11 +567,12 @@ def _main(argv: Sequence[str]) -> None:
with ml_monitoring.ml_event_logger(ml_monitoring.MlEvent.INITIALIZE_SETUP):
experiment_config = create_experiment_config()

run(
experiment_config=experiment_config,
enable_checkpoint_saving=FLAGS.enable_checkpoint_saving,
startup_random_jitter_max_secs=FLAGS.startup_random_jitter_max_secs,
)
with TransformerEngineHelper.fp8_autocast('replica', 'mdl', 'data'):
run(
experiment_config=experiment_config,
enable_checkpoint_saving=FLAGS.enable_checkpoint_saving,
startup_random_jitter_max_secs=FLAGS.startup_random_jitter_max_secs,
)


_TASK_HANDLE_RE = re.compile(r'(?:logs\.)?(\d+)\.(.*)\.([^.]+)\.\d+')
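For review context, the net effect in `paxml/main.py` is that both experiment-config construction (the `--exp`/`--fdl` branches) and the `run()` call execute inside an `fp8_autocast` scope with the same hard-coded axis names. A condensed, illustrative sketch; `build_config` and `run` stand in for the real code paths, and 'replica'/'mdl'/'data' are assumed to be the GPU configs' mesh axis names:

```python
# Illustrative control-flow sketch only, not the PR's code.
from paxml.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper


def main_sketch(build_config, run):
  # Config construction happens under one scope...
  with TransformerEngineHelper.fp8_autocast('replica', 'mdl', 'data'):
    experiment_config = build_config()
  # ...and the training/eval run under another scope with the same axis names.
  with TransformerEngineHelper.fp8_autocast('replica', 'mdl', 'data'):
    run(experiment_config)
```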
3 changes: 2 additions & 1 deletion paxml/tasks_lib.py
@@ -43,6 +43,7 @@
from paxml import io_utils
from paxml import learners as learners_lib
from paxml import train_states
from paxml.contrib.gpu.scripts_gpu.te_helper import DEFAULT_INIT_MUTABLE_LIST
from praxis import asserts
from praxis import base_hyperparams
from praxis import base_input
@@ -1786,7 +1787,7 @@ def _apply_init_checkpoint_rule(
)
# Initialize with a dummy seed
var_weight_hparams = ckpt_task.model.abstract_init_with_metadata(
inputs_shape_dtype)
inputs_shape_dtype, extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST)
ckpt_train_state = ckpt_task.create_train_state_padded_shapes(
var_weight_hparams)
train_state_pspecs = ckpt_task.create_train_state_partition_specs(
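The new import stays safe when Transformer Engine is absent: `DEFAULT_INIT_MUTABLE_LIST` falls back to the plain praxis list in that case, so `_apply_init_checkpoint_rule` builds exactly the same abstract metadata as before. A quick, hedged way to inspect what the import resolves to:

```python
# Sketch; assumes paxml and praxis are importable.
from praxis import base_layer
from paxml.contrib.gpu.scripts_gpu.te_helper import DEFAULT_INIT_MUTABLE_LIST

extra = [c for c in DEFAULT_INIT_MUTABLE_LIST
         if c not in base_layer.DEFAULT_INIT_MUTABLE_LIST]
print(extra)
# With transformer_engine installed this holds TE's FP8 collection name;
# without it, the list is identical to the praxis default and `extra` is [].
```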
40 changes: 31 additions & 9 deletions paxml/trainer_lib.py
@@ -26,6 +26,7 @@
from etils import epath
import fiddle as fdl
from flax import struct as flax_struct
from flax.linen.fp8_ops import fm32
import jax
from jax import numpy as jnp
from jax.experimental import pjit
@@ -35,6 +36,7 @@
from paxml import sgf
from paxml import tasks_lib
from paxml import train_states
from paxml.contrib.gpu.scripts_gpu.te_helper import DEFAULT_INIT_MUTABLE_LIST
from praxis import asserts
from praxis import base_hyperparams
from praxis import base_input
@@ -167,8 +169,7 @@ def create_train_state_metadata(
A TrainStateMetadata instance.
"""
var_weight_hparams = jax_task.model.abstract_init_with_metadata(
train_shape_dtype, do_eval=do_eval
)
train_shape_dtype, do_eval=do_eval, extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST)
padded_global_shapes = jax_task.create_train_state_padded_shapes(
var_weight_hparams, discard_opt_states=discard_opt_states
)
@@ -217,7 +218,8 @@ def write_post_init_model_hparams_file(
logging.info('post_init_model_params: %s', params_fpath)
job_log_dir.mkdir(parents=True, exist_ok=True)
hyper_params = model.abstract_init_with_mdl_config(
train_state_metadata.input_shape_dtype, do_eval=do_eval
train_state_metadata.input_shape_dtype, do_eval=do_eval,
extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST
)
with params_fpath.open('w') as params_file:
hyper_params_dump = base_hyperparams.nested_struct_to_text(hyper_params)
@@ -379,7 +381,8 @@ def initialize_model_state(
is_eval_for_init = is_eval
if not var_weight_hparams:
var_weight_hparams = model.abstract_init_with_metadata(
inputs_shape_dtype, do_eval=is_eval_for_init
inputs_shape_dtype, do_eval=is_eval_for_init,
extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST
)
logging.info('init_var prng_seed: %s', init_key)
logging.info('var_weight_hparams: %s', var_weight_hparams)
@@ -396,7 +399,7 @@ def init_fn(init_key):
inputs = jax.tree.map(jnp.zeros_like, inputs_shape_dtype)
if model.hparams.fprop_dtype == jnp.bfloat16:
inputs = jax.tree.map(_maybe_to_bfloat16, inputs)
return model.init(init_key, inputs)
return model.init(init_key, inputs, mutable=DEFAULT_INIT_MUTABLE_LIST)

initial_vars = init_fn(init_key)
logging.info('initial_vars: %s', jax.tree.map(jnp.shape, initial_vars))
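Why the `mutable=` argument matters in `init_fn`: TE keeps its FP8 scale/amax state in an extra Flax collection, and that collection must be listed as mutable at init time for the variables to be created alongside `'params'`. A plain-Flax analogy (collection and module names here are made up, not TE's actual ones):

```python
# Plain-Flax sketch of initializing with an extra mutable collection.
import jax
import jax.numpy as jnp
from flax import linen as nn


class ScaledDense(nn.Module):
  features: int

  @nn.compact
  def __call__(self, x):
    # A non-'params' collection, standing in for TE's FP8 metadata collection.
    scale = self.variable('fp8_meta', 'scale', lambda: jnp.ones(()))
    return nn.Dense(self.features)(x) * scale.value


variables = ScaledDense(features=4).init(
    jax.random.PRNGKey(0),
    jnp.ones((1, 8)),
    mutable=['params', 'fp8_meta'],  # analogous role to DEFAULT_INIT_MUTABLE_LIST
)
print(sorted(variables.keys()))  # ['fp8_meta', 'params']
```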
@@ -802,14 +805,28 @@ def _default_apply_fn(
)


def _maybe_to_fm32_vars(mdl_vars, var_weight_hparams):
asserts.assert_same_structure(mdl_vars, var_weight_hparams)

def _maybe_fm32_var_fn(var, var_param):
if base_layer.var_overwrite_with_gradient(var_param):
return jax.lax.convert_element_type(var, fm32)
else:
return var

is_leaf = lambda x: not isinstance(x, (tuple, dict, list))
return jax.tree_util.tree_map(
_maybe_fm32_var_fn, mdl_vars, var_weight_hparams, is_leaf=is_leaf
)
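Reviewer context for `_maybe_to_fm32_vars`: `fm32` comes from `flax.linen.fp8_ops`, and the conversion is applied only to variables flagged as overwrite-with-gradient (the FP8 scale/amax metadata), so they carry their special gradient semantics through the loss function. The traversal itself is an ordinary two-tree `tree_map` with a custom `is_leaf`; a toy illustration of that pattern (the flag name and the float32 cast below are stand-ins, not the real `WeightHParams` API):

```python
# Toy tree_map-with-is_leaf sketch; 'overwrite_with_gradient' stands in for
# base_layer.var_overwrite_with_gradient(var_param), and the float32 cast
# stands in for jax.lax.convert_element_type(var, fm32).
import jax
import jax.numpy as jnp

mdl_vars = {'w': jnp.ones((2, 2)), 'fp8_scale': jnp.ones(())}
hparams = {'w': {'overwrite_with_gradient': False},
           'fp8_scale': {'overwrite_with_gradient': True}}

is_leaf = lambda x: not isinstance(x, (tuple, dict, list))

def maybe_convert(var, hp):
  # hp is the whole per-variable sub-dict, because mdl_vars' structure is a
  # prefix of hparams' structure.
  return var.astype(jnp.float32) if hp['overwrite_with_gradient'] else var

out = jax.tree_util.tree_map(maybe_convert, mdl_vars, hparams, is_leaf=is_leaf)
```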


class LossFnProtocol(Protocol):

def __call__(
self, mdl_vars: NestedJTensor, inputs: NestedMap, prng_key: PRNGKey
) -> tuple[JTensor, sgf.GradAuxInfo]:
"""Produces losses and grad info by passing the inputs through a model."""


def _get_default_loss_fn(
jax_task: tasks_lib.SingleTask,
context_p: base_layer.JaxContext.HParams,
Expand All @@ -833,6 +850,8 @@ def _loss_fn(
else:
assert NotImplementedError(f'fprop_dtype {fprop_dtype} not supported.')

mdl_vars = _maybe_to_fm32_vars(mdl_vars, var_weight_hparams)

with base_layer.JaxContext.new_context(hparams=context_p):
k1, k2, k3 = jax.random.split(prng_key, 3)
(metrics, per_example_output), updated_vars = apply_fn(
@@ -994,6 +1013,7 @@ def get_excluded_var_masks(
excluded_for_grad = tasks_lib.get_excluded_var_mask_for_grad(
var_weight_hparams, learner
)

_log_bprop_include_exclude_list(var_weight_hparams, excluded_for_grad)

# Excluded for optimizer states.
@@ -1090,7 +1110,7 @@ def train_step_single_learner(

if not var_weight_hparams:
with base_layer.JaxContext.new_context(hparams=context_p):
var_weight_hparams = model.abstract_init_with_metadata(inputs)
var_weight_hparams = model.abstract_init_with_metadata(inputs, extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST)
updated_model_vars = jax_task.maybe_adjust_train_state( # pytype: disable=wrong-arg-types # jax-ndarray
step=states.step,
mdl_vars=states.mdl_vars,
@@ -1162,6 +1182,7 @@ def train_step_single_learner(
wps_with_opt = tasks_lib.filter_vars_for_grad_or_opt(
var_weight_hparams, excluded_for_learner
)

transformed_grads, new_opt_states = learner.update_states(
grads, states.opt_states[0], vars_with_opt, wps_with_opt
)
@@ -1197,6 +1218,7 @@ def train_step_single_learner(
states.mdl_vars,
mdl_vars,
)

new_states = states.new_state(
mdl_vars=mdl_vars, opt_states=[new_opt_states], extra_state=()
)
@@ -1300,7 +1322,7 @@ def eval_step_single_learner(
var_weight_hparams = model.abstract_init_with_metadata(
inputs,
do_eval=not jax_task.hparams.train.always_use_train_for_model_init,
)
extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST)

if fprop_dtype == jnp.float32:
pass
@@ -1554,7 +1576,7 @@ def initialize_partitioned_model_states(
model = jax_task.model
if not var_weight_hparams:
var_weight_hparams = model.abstract_init_with_metadata(
global_input_shapes, do_eval=is_eval
global_input_shapes, do_eval=is_eval, extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST
)

train_state_partition_specs = (
Expand Down