[JAX] Migrating from Xmap to Custom Partitioning for All Custom Calls (#472)

* Refactor sharding.py for the further custom_partitioning migration

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* Migrating both FWD and BWD of LayerNorm/RMSNorm from xmap to custom_partitioning (see the custom_partitioning sketch after this commit list).

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* Migrating both FWD and BWD of all kinds of softmax from xmap to custom_partitioning.

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* Fix the wrong order of parameters to LN/RMSN bwd in ln_mlp_fp8.

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* Add a workaround (WAR) to LN/RMSN_fp8 before migrating to custom_partitioning (CP).

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* Fix the wrong order of parameters of bwd of LN/RMSN_fp8.

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* Modify following review feedback

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* Force the hidden dim in Norm ops to be unsharded and add a warning message.

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* Reuse fwd_rule in VJP functions

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* Migrating both FWD and BWD of self-fused-attn from xmap to custom_partitioning.

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* Migrating both FWD and BWD of cross-fused-attn from xmap to custom_partitioning.

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* add gelu and dgelu.

Signed-off-by: Ming-Xu Huang <[email protected]>

* Reuse fwd_rule in VJP functions for attentions

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* Apply native FP8 Dtypes to fp8.py

Signed-off-by: Ming-Xu Huang <[email protected]>

* Migrating cast_and_transpose from xmap to custom_partitioning

Signed-off-by: Ming-Xu Huang <[email protected]>

* Migrating transpose from xmap to custom_partitioning

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* Apply XLA pattern match to perform FP8 GEMM.

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* migrate layernorm_fp8 to custom_partitioning.

Signed-off-by: Ming-Xu Huang <[email protected]>

* Unify code style

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* Extend Transpose support to FP8

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* Implementing layernorm_fp8_dot based on migrated custom calls.

Signed-off-by: Ming-Xu Huang <[email protected]>

* Rename variables and publish NVTE_FP8_COLLECTION_NAME

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* Replace Q/DQ custom calls with native XLA implementations

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* migrate gelu_fp8 to custom_partitioning.

Signed-off-by: Ming-Xu Huang <[email protected]>

* Minor fix

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* Support custom calls with multiple dims

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* Support general dot indices in _fp8_dot_impl

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* Implementing layernorm_geglu_fp8_mlp

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* Remove GEMM custom calls

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* Remove xmap related code

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* Fix typo and add query-function to FP8MetaPackage

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* Fix some bugs of custom calls

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* Fix CT's bugs

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* Update UTs/examples to adapt to the API changes.

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* Unify kernel initialization in MLP.

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* Modify following code review feedback

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* Update README and add deprecation warning to *ShardingType

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>

* Canonicalize the dtype

Signed-off-by: Ming Huang <[email protected]>

* Adding assertion for unsupported batch dims.

Signed-off-by: Ming Huang <[email protected]>

* Adding doc/examples to _multidim_transpose

Signed-off-by: Ming Huang <[email protected]>

* Set FP8 meta as WeightHParamsCollection.OVERWRITE_WITH_GRADIENT in Praxis modules.

Signed-off-by: Ming Huang <[email protected]>

* Set FP8 meta as WeightHParamsCollection.OVERWRITE_WITH_GRADIENT in Praxis modules.

Signed-off-by: Ming Huang <[email protected]>

* Apply dtype-based rtol/atol to UTs

Signed-off-by: Ming Huang <[email protected]>

* Deprecate QKV_INTERLEAVED enum

Signed-off-by: Ming Huang <[email protected]>

* Skip test_distributed_custom_ops.py

Signed-off-by: Ming Huang <[email protected]>

* Fix the wrong sharding of bias in SelfAttn

Signed-off-by: Ming Huang <[email protected]>

* Workaround (WAR) to fix the wrong cu_seqlen of MHA when DP/FSDP is enabled

Signed-off-by: Ming Huang <[email protected]>

* Adding distributed ops unit-tests

Signed-off-by: Ming Huang <[email protected]>

* Adding license to test_distributed_*

Signed-off-by: Ming Huang <[email protected]>

* Modify following review feedback

Signed-off-by: Ming Huang <[email protected]>

* Use total bytes involved in collective ops as the criterion.

Signed-off-by: Ming Huang <[email protected]>

---------

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>
Co-authored-by: Donglin Yang <[email protected]>
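
The commits above migrate every custom call from xmap to jax.experimental.custom_partitioning. For orientation, a minimal sketch of that pattern follows, using a toy RMSNorm forward; the op body, sharding rule, and helper names are illustrative assumptions, not the actual Transformer Engine implementation.

    # Illustrative sketch of the custom_partitioning pattern adopted in this PR.
    # The math and shardings are placeholders, not TE's real kernels.
    import jax
    import jax.numpy as jnp
    from jax.experimental.custom_partitioning import custom_partitioning
    from jax.sharding import NamedSharding, PartitionSpec

    @custom_partitioning
    def rmsnorm_fwd(x, gamma):
        # Reference implementation; the real op dispatches to a TE kernel.
        var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
        return x * jax.lax.rsqrt(var + 1e-6) * gamma

    def _infer_sharding(mesh, arg_infos, result_info):
        # Keep the caller's sharding for the leading dims of x, but force the
        # hidden (last) dim to be unsharded, as this PR does for the norm ops.
        x = arg_infos[0]
        spec = list(x.sharding.spec) + [None] * (len(x.shape) - len(x.sharding.spec))
        return NamedSharding(mesh, PartitionSpec(*spec[:-1], None))

    def _partition(mesh, arg_infos, result_info):
        out_sharding = _infer_sharding(mesh, arg_infos, result_info)
        arg_shardings = (out_sharding, NamedSharding(mesh, PartitionSpec(None)))
        def lower_fn(x, gamma):
            # Runs on each shard; no collectives are needed because the hidden
            # dim (the reduction axis) is replicated.
            var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
            return x * jax.lax.rsqrt(var + 1e-6) * gamma
        return mesh, lower_fn, out_sharding, arg_shardings

    rmsnorm_fwd.def_partition(infer_sharding_from_operands=_infer_sharding,
                              partition=_partition)

In the actual migration the lower_fn body would call the registered XLA custom-call primitive; the sketch inlines plain jnp math instead so it runs anywhere.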
mingxu1067 and Young768 authored Nov 14, 2023
1 parent 7976bd0 commit 71e51ea
Showing 30 changed files with 4,584 additions and 5,629 deletions.
2 changes: 0 additions & 2 deletions README.rst
@@ -126,8 +126,6 @@ Flax
for _ in range(10):
loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp)
# Update FP8 metas
other_variables = te.update_fp8_metas(other_grads)
.. overview-end-marker-do-not-remove
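With those two deleted lines gone, the README's loop reduces to the sketch below (reconstructed from the surrounding diff; the comment summarizes this PR's stated direction rather than quoting README text):

    for _ in range(10):
        loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp)
        # No explicit te.update_fp8_metas step any more: per the commit list, FP8
        # metas are maintained through the returned collections/gradients instead.
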
25 changes: 10 additions & 15 deletions examples/jax/encoder/test_model_parallel_encoder.py
@@ -58,20 +58,18 @@ def __call__(self, x, mask, disable_dropout=False):
x = te_flax.DenseGeneral(features=256,
kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
bias_axes=(NAMED_TP_AXIS,),
sharding_type=te.ShardingType.DP_TP_COL,
dtype=jnp.bfloat16)(x)

x = te_flax.DenseGeneral(features=256,
kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
bias_axes=(NAMED_BROADCAST_AXIS,),
sharding_type=te.ShardingType.DP_TP_ROW,
dtype=jnp.bfloat16)(x)

x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
return x


def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8):
def train_step(state, inputs, masks, labels, var_collect, rngs):
"""Computes gradients, loss and accuracy for a single batch."""

def loss_fn(var_collect, disable_dropout=False):
@@ -87,13 +85,11 @@ def loss_fn(var_collect, disable_dropout=False):

var_collect, grads = flax.core.pop(grads, PARAMS_KEY)
state = state.apply_gradients(grads=grads)
if use_fp8:
var_collect = te.update_fp8_metas(var_collect)

return state, loss, accuracy, var_collect


def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8, train_fn):
def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
"""Train for a single epoch."""
train_ds_size = len(train_ds['sentence'])
steps_per_epoch = train_ds_size // batch_size
@@ -108,7 +104,7 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8, train_f
batch_masks = train_ds['mask'][perm, ...]
batch_labels = train_ds['label'][perm, ...]
state, loss, accuracy, var_collect = train_fn(state, batch_inputs, batch_masks,
batch_labels, var_collect, rngs, use_fp8)
batch_labels, var_collect, rngs)
epoch_loss.append(loss)
epoch_accuracy.append(accuracy)

@@ -206,9 +202,8 @@ def get_datasets(max_seq_len):
def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
assert "Float8" in str(
jax.make_jaxpr(train_step, static_argnums=6)(state, inputs, masks, labels, var_collect,
rngs, True))
assert "fp8_" in str(
jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))


def get_params_pspec(sharding_rules, abs_var_collect):
@@ -269,7 +264,8 @@ def train_and_evaluate(args):
label_shape = [args.batch_size]

with te.fp8_autocast(args.use_fp8,
sharding_resource=te.ShardingResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS)):
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None,
None)):
encoder = Net(num_embed)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
@@ -297,7 +293,7 @@ def train_and_evaluate(args):

in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None)
out_shardings = (state_pspec, None, None, None)
pjit_train_step = pjit(train_step, in_shardings, out_shardings, static_argnums=(6,))
pjit_train_step = pjit(train_step, in_shardings, out_shardings)

in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None)
out_shardings = (None, None)
@@ -310,7 +306,7 @@ def train_and_evaluate(args):
if args.dry_run:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
rngs = {DROPOUT_KEY: dropout_rng}
pjit_train_step(state, inputs, masks, labels, var_collect, rngs, args.use_fp8)
pjit_train_step(state, inputs, masks, labels, var_collect, rngs)
print("PASSED")
return None

@@ -320,8 +316,7 @@ def train_and_evaluate(args):
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}

state, train_loss, train_accuracy, var_collect = train_epoch(
state, train_ds, args.batch_size, rngs, var_collect, args.use_fp8,
pjit_train_step)
state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step)

test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size,
var_collect, pjit_eval_step)
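The recurring API change in this example is the autocast context: te.ShardingResource is replaced by te.MeshResource. A condensed before/after drawn from the hunks above; reading MeshResource's four positional fields as DP, TP, FSDP, and PP axis resources is an assumption suggested by the axes used here, not something the diff states:

    # Before (deleted lines):
    #   with te.fp8_autocast(args.use_fp8,
    #                        sharding_resource=te.ShardingResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS)):
    #       ...
    # After (added lines):
    with te.fp8_autocast(args.use_fp8,
                         mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None)):
        ...
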
29 changes: 12 additions & 17 deletions examples/jax/encoder/test_multigpu_encoder.py
@@ -52,17 +52,15 @@ def __call__(self, x, mask, disable_dropout=False):

x = x.reshape(x.shape[0], -1)

x = te_flax.DenseGeneral(features=256, sharding_type=te.ShardingType.DP,
dtype=jnp.bfloat16)(x)
x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)

x = te_flax.DenseGeneral(features=256, sharding_type=te.ShardingType.DP,
dtype=jnp.bfloat16)(x)
x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)

x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
return x


def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8):
def train_step(state, inputs, masks, labels, var_collect, rngs):
"""Computes gradients, loss and accuracy for a single batch."""

def loss_fn(var_collect, disable_dropout=False):
@@ -78,13 +76,11 @@ def loss_fn(var_collect, disable_dropout=False):

var_collect, grads = flax.core.pop(grads, PARAMS_KEY)
state = state.apply_gradients(grads=grads)
if use_fp8:
var_collect = te.update_fp8_metas(var_collect)

return state, loss, accuracy, var_collect


def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8, train_fn):
def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
"""Train for a single epoch."""
train_ds_size = len(train_ds['sentence'])
steps_per_epoch = train_ds_size // batch_size
@@ -99,7 +95,7 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8, train_f
batch_masks = train_ds['mask'][perm, ...]
batch_labels = train_ds['label'][perm, ...]
state, loss, accuracy, var_collect = train_fn(state, batch_inputs, batch_masks,
batch_labels, var_collect, rngs, use_fp8)
batch_labels, var_collect, rngs)
epoch_loss.append(loss)
epoch_accuracy.append(accuracy)

@@ -197,9 +193,8 @@ def get_datasets(max_seq_len):
def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
assert "Float8" in str(
jax.make_jaxpr(train_step, static_argnums=6)(state, inputs, masks, labels, var_collect,
rngs, True))
assert "fp8_" in str(
jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))


def get_params_pspec(sharding_rules, abs_var_collect):
@@ -252,7 +247,8 @@ def train_and_evaluate(args):
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size]

with te.fp8_autocast(args.use_fp8, sharding_resource=te.ShardingResource(DEVICE_DP_AXIS)):
with te.fp8_autocast(args.use_fp8,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None)):
encoder = Net(num_embed)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
Expand All @@ -279,7 +275,7 @@ def train_and_evaluate(args):

in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None)
out_shardings = (state_pspec, None, None, None)
pjit_train_step = pjit(train_step, in_shardings, out_shardings, static_argnums=(6,))
pjit_train_step = pjit(train_step, in_shardings, out_shardings)

in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None)
out_shardings = (None, None)
@@ -292,7 +288,7 @@ def train_and_evaluate(args):
if args.dry_run:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
rngs = {DROPOUT_KEY: dropout_rng}
pjit_train_step(state, inputs, masks, labels, var_collect, rngs, args.use_fp8)
pjit_train_step(state, inputs, masks, labels, var_collect, rngs)
print("PASSED")
return None

@@ -302,8 +298,7 @@ def train_and_evaluate(args):
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}

state, train_loss, train_accuracy, var_collect = train_epoch(
state, train_ds, args.batch_size, rngs, var_collect, args.use_fp8,
pjit_train_step)
state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step)

test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size,
var_collect, pjit_eval_step)
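A second pattern visible in these hunks: use_fp8 was only consumed by the now-removed te.update_fp8_metas branch, so it drops out of train_step entirely and pjit no longer needs a static argument. Condensed before/after:

    # Before (deleted lines):
    #   def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8): ...
    #   pjit_train_step = pjit(train_step, in_shardings, out_shardings, static_argnums=(6,))
    # After (added lines):
    def train_step(state, inputs, masks, labels, var_collect, rngs): ...
    pjit_train_step = pjit(train_step, in_shardings, out_shardings)
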
28 changes: 12 additions & 16 deletions examples/jax/encoder/test_multiprocessing_encoder.py
@@ -61,13 +61,11 @@ def __call__(self, x, mask, disable_dropout=False):
x = te_flax.DenseGeneral(features=256,
kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
bias_axes=(NAMED_TP_AXIS,),
sharding_type=te.ShardingType.DP_TP_COL,
dtype=jnp.bfloat16)(x)

x = te_flax.DenseGeneral(features=256,
kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
bias_axes=(NAMED_BROADCAST_AXIS,),
sharding_type=te.ShardingType.DP_TP_ROW,
dtype=jnp.bfloat16)(x)

x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
@@ -106,7 +104,7 @@ def shard_array_wrapper(dataset, batch_size, mesh, pspec, enable_partition=False
return global_input_shape, named_sharding, inputs


def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8):
def train_step(state, inputs, masks, labels, var_collect, rngs):
"""Computes gradients, loss and accuracy for a single batch."""

def loss_fn(var_collect, disable_dropout=False):
@@ -122,14 +120,12 @@ def loss_fn(var_collect, disable_dropout=False):

var_collect, grads = flax.core.pop(grads, PARAMS_KEY)
state = state.apply_gradients(grads=grads)
if use_fp8:
var_collect = te.update_fp8_metas(var_collect)

return state, loss, accuracy, var_collect


def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8, train_fn, mesh,
inputs_pspec, masks_pspec, labels_pspec):
def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn, mesh, inputs_pspec,
masks_pspec, labels_pspec):
"""Train for a single epoch."""

total_batch_size = len(train_ds['sentence'])
@@ -164,7 +160,7 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8, train_f
label_named_sharding, [batch_label])

state, loss, accuracy, var_collect = train_fn(state, shard_input, shard_mask, shard_label,
var_collect, rngs, use_fp8)
var_collect, rngs)
epoch_loss.append(loss)
epoch_accuracy.append(accuracy)

@@ -280,9 +276,8 @@ def get_datasets(max_seq_len):
def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
assert "Float8" in str(
jax.make_jaxpr(train_step, static_argnums=6)(state, inputs, masks, labels, var_collect,
rngs, True))
assert "fp8_" in str(
jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))


def get_params_pspec(sharding_rules, abs_var_collect):
@@ -350,7 +345,8 @@ def train_and_evaluate(args):
label_shape = [args.batch_size]

with te.fp8_autocast(args.use_fp8,
sharding_resource=te.ShardingResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS)):
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None,
None)):
encoder = Net(num_embed)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
@@ -378,7 +374,7 @@ def train_and_evaluate(args):

in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None)
out_shardings = (state_pspec, None, None, None)
pjit_train_step = pjit(train_step, in_shardings, out_shardings, static_argnums=(6,))
pjit_train_step = pjit(train_step, in_shardings, out_shardings)

in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None)
out_shardings = (None, None)
@@ -391,7 +387,7 @@ def train_and_evaluate(args):
if args.dry_run:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
rngs = {DROPOUT_KEY: dropout_rng}
pjit_train_step(state, inputs, masks, labels, var_collect, rngs, args.use_fp8)
pjit_train_step(state, inputs, masks, labels, var_collect, rngs)
print("PASSED")
else:
for epoch in range(1, args.epochs + 1):
@@ -400,8 +396,8 @@ def train_and_evaluate(args):
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}

state, train_loss, train_accuracy, var_collect = train_epoch(
state, train_ds, args.batch_size, rngs, var_collect, args.use_fp8,
pjit_train_step, shard_mesh, inputs_pspec, masks_pspec, labels_pspec)
state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step,
shard_mesh, inputs_pspec, masks_pspec, labels_pspec)

test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size,
var_collect, pjit_eval_step, shard_mesh,
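The module-level change repeated across all the encoder examples: te_flax.DenseGeneral no longer takes a sharding_type argument. Sharding now follows from the logical kernel_axes/bias_axes names together with the MeshResource passed to fp8_autocast; that reading is inferred from these diffs rather than stated in them. Condensed before/after:

    # Before (deleted lines):
    #   x = te_flax.DenseGeneral(features=256,
    #                            kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
    #                            bias_axes=(NAMED_TP_AXIS,),
    #                            sharding_type=te.ShardingType.DP_TP_COL,
    #                            dtype=jnp.bfloat16)(x)
    # After (added lines): same logical axes, no explicit sharding_type.
    x = te_flax.DenseGeneral(features=256,
                             kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
                             bias_axes=(NAMED_TP_AXIS,),
                             dtype=jnp.bfloat16)(x)
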
17 changes: 7 additions & 10 deletions examples/jax/encoder/test_single_gpu_encoder.py
@@ -56,7 +56,7 @@ def __call__(self, x, mask, disable_dropout=False):


@partial(jax.jit, static_argnums=6)
def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8):
def train_step(state, inputs, masks, labels, var_collect, rngs):
"""Computes gradients, loss and accuracy for a single batch."""

def loss_fn(var_collect, disable_dropout=False):
@@ -72,13 +72,11 @@ def loss_fn(var_collect, disable_dropout=False):

var_collect, grads = flax.core.pop(grads, PARAMS_KEY)
state = state.apply_gradients(grads=grads)
if use_fp8:
var_collect = te.update_fp8_metas(var_collect)

return state, loss, accuracy, var_collect


def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8):
def train_epoch(state, train_ds, batch_size, rngs, var_collect):
"""Train for a single epoch."""
train_ds_size = len(train_ds['sentence'])
steps_per_epoch = train_ds_size // batch_size
@@ -93,7 +91,7 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8):
batch_masks = train_ds['mask'][perm, ...]
batch_labels = train_ds['label'][perm, ...]
state, loss, accuracy, var_collect = train_step(state, batch_inputs, batch_masks,
batch_labels, var_collect, rngs, use_fp8)
batch_labels, var_collect, rngs)
epoch_loss.append(loss)
epoch_accuracy.append(accuracy)

@@ -192,9 +190,8 @@ def get_datasets(max_seq_len):
def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
assert "Float8" in str(
jax.make_jaxpr(train_step, static_argnums=6)(state, inputs, masks, labels, var_collect,
rngs, True))
assert "fp8_" in str(
jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))


def train_and_evaluate(args):
Expand Down Expand Up @@ -228,7 +225,7 @@ def train_and_evaluate(args):
if args.dry_run:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
rngs = {DROPOUT_KEY: dropout_rng}
train_step(state, inputs, masks, labels, var_collect, rngs, args.use_fp8)
train_step(state, inputs, masks, labels, var_collect, rngs)
print("PASSED")
return None

Expand All @@ -238,7 +235,7 @@ def train_and_evaluate(args):
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}

state, train_loss, train_accuracy, var_collect = train_epoch(
state, train_ds, args.batch_size, rngs, var_collect, args.use_fp8)
state, train_ds, args.batch_size, rngs, var_collect)

test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect)

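Across all four examples the check_fp8 helper now greps the traced jaxpr for "fp8_" instead of "Float8", and no longer needs static_argnums to trace train_step. A hypothetical standalone version of that check, using only the public jax.make_jaxpr API:

    import jax

    def model_uses_fp8(step_fn, *example_args):
        """Return True if tracing step_fn produces any op carrying the 'fp8_' marker.

        Mirrors the updated check_fp8 assertions above; the helper name and the
        exact token being matched are illustrative, taken from this diff.
        """
        return "fp8_" in str(jax.make_jaxpr(step_fn)(*example_args))
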