[draft] equivalent mixtral mlperf data pipeline #1157

Open · wants to merge 3 commits into main
2 changes: 1 addition & 1 deletion MaxText/checkpointing.py
@@ -27,7 +27,6 @@
import numpy as np
import orbax.checkpoint as ocp
import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager
import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager

# pylint: disable=too-many-positional-arguments

@@ -118,6 +117,7 @@ def create_orbax_emergency_replicator_checkpoint_manager(
save_interval_steps: int,
global_mesh: jax.sharding.Mesh,
):
"""Returns an emergency replicator checkpoint manager."""
import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager
flags.FLAGS.experimental_orbax_use_distributed_process_id = True
max_logging.log("Creating emergency replicator checkpoint manager...")
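Note: this change defers the experimental replicator import into the function that needs it, so importing MaxText/checkpointing.py no longer pulls in the emergency replicator module unless a replicator checkpoint manager is actually created. A minimal sketch of the same lazy-import pattern, assuming orbax-checkpoint with the experimental emergency replicator module is installed (names other than the module path are illustrative):

```python
# Sketch only: function-local (lazy) import, mirroring the pattern in this diff.
def create_replicator_manager_sketch():
  """Docstring stays first so Python still treats it as the function's docstring."""
  # Import deferred until the replicator manager is actually requested,
  # keeping top-level import time and dependencies of checkpointing.py smaller.
  import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as rcm

  return rcm  # the real code goes on to build options and the manager here
```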
36 changes: 36 additions & 0 deletions MaxText/configs/models/mixtral-8x22b-mlperf.yml
@@ -0,0 +1,36 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for mixtral-8x22b-v0.1 (MLPerf data pipeline variant)

base_emb_dim: 6144
base_num_query_heads: 48
base_num_kv_heads: 8
base_mlp_dim: 16384
base_num_decoder_layers: 56
head_dim: 128
mlp_activations: ["silu","linear"]
# mistralai/Mixtral-8x22B-v0.1: vocab_size 32000
# mistralai/Mixtral-8x22B-Instruct-v0.1: vocab_size 32768
vocab_size: 32000
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-5
num_experts: 8
num_experts_per_tok: 2
rope_max_timescale: 1_000_000
decoder_block: "mistral"
dataset_path: "gs://mlperf-llm-public2"
dataset_name: "c4/en:3.0.4"
eval_dataset_name: "c4/en:3.0.4"
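For orientation, the MoE sizes above can be sanity-checked with a rough parameter estimate. The sketch below assumes the usual gated-MLP layout implied by mlp_activations: ["silu","linear"] (two input projections plus one output projection per expert) and ignores attention, embedding, and router weights; the figures are illustrative only:

```python
# Rough back-of-the-envelope parameter count from the config values above (sketch only).
emb_dim, mlp_dim = 6144, 16384      # base_emb_dim, base_mlp_dim
num_layers, num_experts = 56, 8     # base_num_decoder_layers, num_experts

# Gated MLP: silu branch, linear branch, and output projection -> 3 matrices per expert.
params_per_expert = 3 * emb_dim * mlp_dim
moe_params = params_per_expert * num_experts * num_layers
print(f"~{moe_params / 1e9:.0f}B MoE MLP parameters")  # ~135B of Mixtral 8x22B's roughly 141B total
```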
54 changes: 44 additions & 10 deletions MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py
@@ -269,11 +269,23 @@ def preprocess_train_dataset(

def preprocess_eval_dataset(
eval_ds: tf.data.Dataset,
sp_tokenizer,
eval_global_batch_size_to_load: int,
max_target_length: int,
num_examples: Optional[int] = None,
is_tokenized_dataset: bool = True,
) -> tf.data.Dataset:
"""Preprocess the evaluation dataset."""
# group text up to max_target_length if the dataset is not pre-tokenized/pre-processed
if not is_tokenized_dataset:
eval_ds = eval_ds.map(
lambda x: tokenizer.TokenizeOp(tokenizer=sp_tokenizer, features=x, data_keys=("targets",)), num_parallel_calls=AUTOTUNE
)
# hardcode batch_size=24567, i.e. the number of examples in the validation_24567exp split,
# so that grouping text does not insert padding tokens at batch boundaries
eval_ds = reduce_concat_tokens(eval_ds, feature_key="targets", batch_size=24567)
eval_ds = split_tokens_to_targets_length(eval_ds, max_target_length)

eval_ds = sequence_packing.pack_dataset(eval_ds, max_target_length)

eval_ds = eval_ds.map(format_fn, num_parallel_calls=AUTOTUNE)
@@ -327,21 +339,43 @@ def make_c4_mlperf_eval_iterator(
process_indices,
):
"""Make eval iterator of customized C4 dataset for mlperf gpt3 training."""
eval_ds = get_dataset(
dataset_name=config.eval_dataset_name,
split="validation_tokenized_5662seqs",
dataloading_host_index=process_indices.index(jax.process_index()),
dataloading_host_count=len(process_indices),
enable_data_shuffling=False,
)
# note validation_tokenized_5662seqs split is pre tokenized, reduce_concated and split to target_length
# mainly to avoid eval sequences change depending on the number of hosts
eval_ds = rekey(eval_ds, {"inputs": None, "targets": "ids"})
if config.eval_dataset_name == "c4/en:3.0.5":
is_tokenized_dataset = True
elif config.eval_dataset_name == "c4/en:3.0.4":
is_tokenized_dataset = False
else:
raise ValueError(f"{config.eval_dataset_name=} should be one of ('c4/en:3.0.4', 'c4/en:3.0.5')")
if is_tokenized_dataset:
eval_ds = get_dataset(
dataset_name=config.eval_dataset_name,
split="validation_tokenized_5662seqs",
dataloading_host_index=process_indices.index(jax.process_index()),
dataloading_host_count=len(process_indices),
enable_data_shuffling=False,
)
# note: the validation_tokenized_5662seqs split is pre-tokenized, reduce_concat'ed, and split to target_length,
# mainly so that the eval sequences do not change with the number of hosts
eval_ds = rekey(eval_ds, {"inputs": None, "targets": "ids"})
else:
eval_ds = get_dataset(
dataset_name=config.eval_dataset_name,
split="validation_24567exp",
dataloading_host_index=process_indices.index(jax.process_index()),
dataloading_host_count=len(process_indices),
enable_data_shuffling=False,
)

eval_ds = rekey(eval_ds, {"inputs": None, "targets": "text"})

sp_tokenizer = get_tokenizer(config.tokenizer_path, config.add_bos, config.add_eos)


eval_ds = preprocess_eval_dataset(
eval_ds,
sp_tokenizer,
eval_global_batch_size_to_load=config.global_batch_size_to_load_eval,
max_target_length=config.max_target_length,
is_tokenized_dataset=is_tokenized_dataset,
)

eval_multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(eval_ds, global_mesh)
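The new c4/en:3.0.4 path reproduces what the pre-tokenized c4/en:3.0.5 split bakes in: tokenize the raw validation_24567exp text, concatenate all 24,567 examples in a single reduce_concat_tokens batch (so no padding is introduced at batch boundaries), and re-split the stream into max_target_length chunks. A toy, TF-free sketch of the concat-then-split idea; all names and values below are illustrative, not the real helpers:

```python
# Sketch only: concatenate token streams, then re-split into fixed-length targets.
# This mirrors reduce_concat_tokens + split_tokens_to_targets_length at toy scale.
def concat_then_split(token_lists, max_target_length):
  stream = [tok for toks in token_lists for tok in toks]  # one long token stream
  return [stream[i : i + max_target_length] for i in range(0, len(stream), max_target_length)]

examples = [[5, 8, 13], [21, 34], [55, 89, 144, 233]]  # toy "tokenized" examples
print(concat_then_split(examples, max_target_length=4))
# [[5, 8, 13, 21], [34, 55, 89, 144], [233]]
```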
1 change: 1 addition & 0 deletions MaxText/pyconfig.py
@@ -202,6 +202,7 @@ def validate_model_name(s: str) -> bool:
"mistral-7b",
"mixtral-8x7b",
"mixtral-8x22b",
"mixtral-8x22b-mlperf",
"gemma-7b",
"gemma-2b",
"gemma2-2b",
37 changes: 37 additions & 0 deletions MaxText/train.py
@@ -843,6 +843,43 @@ def train_loop(config, state=None):
static_argnums=static_argnums_eval,
donate_argnums=donate_argnums_eval,
)
cumulative_eval_metrics = {
"scalar": {
"eval/total_loss": 0.0,
"eval/total_weights": 0.0,
"eval/avg_loss": 0.0,
"eval/moe_lb_loss": 0.0,
}
}
eval_dpo_reward_accuracy = 0.0
eval_step_count = 0
# pylint: disable=not-callable
for eval_batch in eval_data_iterator:
if config.eval_steps > 0 and eval_step_count >= config.eval_steps:
break
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
eval_metrics = p_eval_step(state, eval_batch, init_rng)
cumulative_eval_metrics["scalar"]["eval/total_loss"] += float(eval_metrics["scalar"]["evaluation/total_loss"])
cumulative_eval_metrics["scalar"]["eval/total_weights"] += float(eval_metrics["scalar"]["evaluation/total_weights"])
cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] += float(eval_metrics["scalar"]["evaluation/moe_lb_loss"])
eval_dpo_reward_accuracy += float(eval_metrics["scalar"].get("evaluation/dpo_reward_accuracy", 0.0)) # for dpo only
max_logging.log(f"Completed eval step {eval_step_count}")
eval_step_count += 1
eval_loss = (
cumulative_eval_metrics["scalar"]["eval/total_loss"]
/ (cumulative_eval_metrics["scalar"]["eval/total_weights"] + EPS)
+ cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] / eval_step_count
)
cumulative_eval_metrics["scalar"]["eval/avg_loss"] = eval_loss
if config.use_dpo:
cumulative_eval_metrics["scalar"]["eval/dpo_reward_accuracy"] = eval_dpo_reward_accuracy / eval_step_count
max_logging.log(
f"average loss before training: {eval_step_count=}, {eval_loss=},"
f" total_weights={cumulative_eval_metrics['scalar']['eval/total_weights']}"
)

else:
p_eval_step = None

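The eval block added above accumulates per-step totals and then reports avg_loss = total_loss / (total_weights + EPS) + moe_lb_loss / eval_step_count. A small self-contained sketch of that aggregation; the EPS value and metric figures are made up for illustration, not taken from a real run:

```python
# Sketch only: how the added eval loop aggregates per-step metrics into an average loss.
EPS = 1e-8  # illustrative; the real constant is defined elsewhere in train.py

def aggregate_eval_loss(step_metrics):
  total_loss = sum(m["total_loss"] for m in step_metrics)
  total_weights = sum(m["total_weights"] for m in step_metrics)
  moe_lb_loss = sum(m["moe_lb_loss"] for m in step_metrics)
  steps = len(step_metrics)
  return total_loss / (total_weights + EPS) + moe_lb_loss / steps

fake_steps = [
    {"total_loss": 4100.0, "total_weights": 2048.0, "moe_lb_loss": 0.02},
    {"total_loss": 3900.0, "total_weights": 2048.0, "moe_lb_loss": 0.03},
]
print(aggregate_eval_loss(fake_steps))  # ~1.953 + 0.025
```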