Enable c4/en, c4/multilingual-en datasets and convergence configs for llama convergence tests #1080

Draft · wants to merge 19 commits into base: main
1 change: 1 addition & 0 deletions MaxText/configs/base.yml
@@ -313,6 +313,7 @@ dataset_type: tfds
 dataset_path: "" # your path given as argument in download_dataset.sh, e.g. "gs://my-maxtext-dataset/"
 dataset_name: 'c4/en:3.0.1'
 eval_dataset_name: 'c4/en:3.0.1'
+train_split: 'train'
 eval_split: 'validation'
 # for HuggingFace input pipeline (dataset_type=hf)
 hf_path: ''
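For context, the new train_split / eval_split keys flow through the input pipeline as TFDS split names. A rough, illustrative sketch of the equivalent direct TFDS call for the multilingual-English splits configured later in benchmarks/convergence/c4_exp.py (the real pipeline goes through get_datasets(), not a bare tfds.load(); dataset name, splits, and bucket are taken from c4_mutil_hp below):

import tensorflow_datasets as tfds

# Illustrative only; not part of this PR.
train_ds = tfds.load("c4/multilingual:3.1.0", split="en", data_dir="gs://mlperf-llm-public2")
eval_ds = tfds.load("c4/multilingual:3.1.0", split="en-validation", data_dir="gs://mlperf-llm-public2")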
2 changes: 1 addition & 1 deletion MaxText/input_pipeline/_tfds_data_processing.py
@@ -143,7 +143,7 @@ def make_tfds_train_iterator(
   """load dataset, preprocess and return iterators"""
   train_ds = get_datasets(
       dataset_name=config.dataset_name,
-      data_split="train",
+      data_split=config.train_split,
       shuffle_files=config.enable_data_shuffling,
       shuffle_seed=config.data_shuffle_seed,
       dataloading_host_index=process_indices_train.index(jax.process_index()),
3 changes: 1 addition & 2 deletions MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py
@@ -255,8 +255,7 @@ def preprocess_train_dataset(
   train_ds = train_ds.map(
       lambda x: tokenizer.TokenizeOp(tokenizer=sp_tokenizer, features=x, data_keys=("targets",)), num_parallel_calls=AUTOTUNE
   )
-
-  train_ds = reduce_concat_tokens(train_ds, feature_key="targets", batch_size=4096)
+  train_ds = reduce_concat_tokens(train_ds, feature_key="targets", batch_size=max_target_length * 2)
   train_ds = split_tokens_to_targets_length(train_ds, max_target_length)
   train_ds = train_ds.shuffle(shuffle_buffer_size, seed=data_shuffle_seed)
   train_ds = sequence_packing.pack_dataset(train_ds, max_target_length)
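Illustrative arithmetic for this change (not part of the diff): the token-concatenation batch now scales with sequence length instead of staying fixed.

# Before: reduce_concat_tokens(..., batch_size=4096), regardless of sequence length.
# After, with the 8192-token configs added in this PR:
#   batch_size = max_target_length * 2 = 8192 * 2 = 16384 examples concatenated per reduce step.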
2 changes: 1 addition & 1 deletion MaxText/train.py
@@ -97,7 +97,7 @@ def get_first_step(state):
 def load_next_batch(train_iter, example_batch, config):
   """Loads the next batch. Can keep reusing the same batch for performance reasons"""

-  if config.reuse_example_batch and example_batch is not None:
+  if config.reuse_example_batch > 0 and example_batch is not None:
     return example_batch
   else:
     return next(train_iter)
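This comparison matters for the new convergence configs: setupDataset() in benchmarks/convergence/c4_exp.py sets reuse_example_batch to -1, which is truthy in Python, so the old check would have kept re-feeding the same cached batch. A minimal illustration:

# Old check:  `if config.reuse_example_batch and ...`     -> bool(-1) is True, cached batch reused
# New check:  `if config.reuse_example_batch > 0 and ...`  -> -1 > 0 is False, a fresh batch is drawn each step
assert bool(-1) is True
assert (-1 > 0) is False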
3 changes: 3 additions & 0 deletions benchmarks/benchmark_runner.py
@@ -88,6 +88,9 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser):
           'llama2_70b_4096_real_data',
           'llama3_70b_8192',
           'llama3_1_405b_8192_fsdp_dcn',
+          'llama3_1_405b_8192_fsdp_dcn_c4',
+          'llama3_1_8b_8192_c4',
+          'llama3_70b_8192',
           'mixtral_8x7b_dropped',
           'mixtral_8x7b_dropped_int8',
           'mixtral_8x7b_dropless',
270 changes: 270 additions & 0 deletions benchmarks/convergence/c4_exp.py
@@ -0,0 +1,270 @@
import dataclasses
import math

from maxtext_trillium_model_configs import MaxTextModel, DatasetHParams, ConvHParams
import xla_flags_library

c4_mlperf_hp = DatasetHParams(
    name="c4mlperf",
    dataset_path="gs://mlperf-exp-us-east1-cp0",
    dataset_name="c4/en:3.0.7",
    dataset_type="c4_mlperf",
    train_split="train",
    eval_split="c4/en:3.0.5",
    eval_steps=4 * 512,
    add_bos=False,
    add_eos=False,
    tokenizer_path="gs://mlperf-llm-public2/vocab/c4_en_301_5Mexp2_spm.model")

c4_en_hp = DatasetHParams(
    name="c4en",
    dataset_path="gs://maxtext-dataset",
    dataset_name="c4/en:3.0.1",
    dataset_type="tfds",
    train_split="train",
    eval_split="validation",
    eval_steps=36 * 512,
    add_bos=False,
    add_eos=False,
    tokenizer_path="gs://mlperf-llm-public2/vocab/c4_en_301_5Mexp2_spm.model")

c4_mutil_hp = DatasetHParams(
    name="c4multien",
    dataset_path="gs://mlperf-llm-public2",
    dataset_name="c4/multilingual:3.1.0",
    dataset_type="tfds",
    train_split="en",
    eval_split="en-validation",
    eval_steps=206 * 2048,  # 852 * 512
    add_bos=False,
    add_eos=False,
    tokenizer_path="gs://mlperf-llm-public2/vocab/c4_en_301_5Mexp2_spm.model")

llama3_1_8b_8192_c4en = MaxTextModel(
    model_name="llama3_1_8b_8192_c4en",
    model_type="llama3.1-8b",
    tuning_params={
        "per_device_batch_size": 2,
        "ici_fsdp_parallelism": -1,
        "remat_policy": "qkv_proj_offloaded",
        "max_target_length": 8192,
        "attention": "flash",
        "gcs_metrics": True,
        "use_iota_embed": True,
        "dataset_path": c4_en_hp.dataset_path,
        "dataset_name": c4_en_hp.dataset_name,
        "dataset_type": c4_en_hp.dataset_type,
        "tokenizer_path": c4_en_hp.tokenizer_path,
        "train_split": c4_en_hp.train_split,
        "eval_split": c4_en_hp.eval_split,
        "add_bos": c4_en_hp.add_bos,
        "add_eos": c4_en_hp.add_eos,
        "enable_checkpointing": True,
        "profiler": "xplane",
        "sa_block_q": 1024,
        "sa_block_q_dkv": 2048,
        "sa_block_q_dq": 2048,
        "steps": 1000,
        "eval_interval": 100,
        "eval_steps": c4_en_hp.eval_steps,
        "data_shuffle_seed": 1238,
        "checkpoint_period": 1000,
    },
    xla_flags=(
        xla_flags_library.DENSE_VMEM_LIMIT_FLAG
        + xla_flags_library.CF_FOR_ALL_GATHER
    ),
)

llama3_1_8b_8192_c4multien = MaxTextModel(
    model_name="llama3_1_8b_8192_c4multien",
    model_type="llama3.1-8b",
    tuning_params={
        "per_device_batch_size": 2,
        "ici_fsdp_parallelism": -1,
        "remat_policy": "qkv_proj_offloaded",
        "max_target_length": 8192,
        "attention": "flash",
        "gcs_metrics": True,
        "use_iota_embed": True,
        "dataset_path": c4_mutil_hp.dataset_path,
        "dataset_name": c4_mutil_hp.dataset_name,
        "dataset_type": c4_mutil_hp.dataset_type,
        "eval_dataset_name": c4_mutil_hp.dataset_name,
        "tokenizer_path": c4_mutil_hp.tokenizer_path,
        "train_split": c4_mutil_hp.train_split,
        "eval_split": c4_mutil_hp.eval_split,
        "add_bos": c4_mutil_hp.add_bos,
        "add_eos": c4_mutil_hp.add_eos,
        "enable_checkpointing": True,
        "profiler": "xplane",
        "sa_block_q": 1024,
        "sa_block_q_dkv": 2048,
        "sa_block_q_dq": 2048,
        "steps": 1000,
        "eval_interval": 100,
        "eval_steps": c4_mutil_hp.eval_steps,
        "data_shuffle_seed": 1238,
        "checkpoint_period": 1000,
    },
    xla_flags=(
        xla_flags_library.DENSE_VMEM_LIMIT_FLAG
        + xla_flags_library.CF_FOR_ALL_GATHER
    ),
)

llama3_1_8b_8192_c4_mlperf = MaxTextModel(
    model_name="llama3_1_8b_8192_c4_mlperf",
    model_type="llama3.1-8b",
    tuning_params={
        "per_device_batch_size": 2,
        "ici_fsdp_parallelism": -1,
        "remat_policy": "qkv_proj_offloaded",
        "max_target_length": 8192,
        "attention": "flash",
        "gcs_metrics": True,
        "use_iota_embed": True,
        "dataset_path": "gs://mlperf-exp-us-east1-cp0",
        "dataset_name": "c4/en:3.0.7",
        "dataset_type": "c4_mlperf",
        "tokenizer_path": (
            "gs://mlperf-llm-public2/vocab/c4_en_301_5Mexp2_spm.model"
        ),
        "eval_dataset_name": "c4/en:3.0.5",
        "add_bos": False,
        "add_eos": False,
        "enable_checkpointing": True,
        "checkpoint_period": 2000,
        "profiler": "xplane",
        "sa_block_q": 1024,
        "sa_block_q_dkv": 2048,
        "sa_block_q_dq": 2048,
        "learning_rate": 3e-4,
        "warmup_steps_fraction": 0.1,
        "steps": 1000,
        "eval_interval": 100,
        "data_shuffle_seed": 1238,
    },
    xla_flags=(
        xla_flags_library.DENSE_VMEM_LIMIT_FLAG
        + xla_flags_library.CF_FOR_ALL_GATHER
    ),
)

llama3_70b_8192_c4multien = MaxTextModel(
    model_name="llama3-70b-8192",
    model_type="llama3-70b",
    tuning_params={
        "per_device_batch_size": 3,
        "ici_fsdp_parallelism": -1,
        "remat_policy": "full",
        "max_target_length": 8192,
        "attention": "flash",
        "gcs_metrics": True,
        "use_iota_embed": True,
        "dataset_path": c4_mutil_hp.dataset_path,
        "dataset_name": c4_mutil_hp.dataset_name,
        "dataset_type": c4_mutil_hp.dataset_type,
        "eval_dataset_name": c4_mutil_hp.dataset_name,
        "tokenizer_path": c4_mutil_hp.tokenizer_path,
        "train_split": c4_mutil_hp.train_split,
        "eval_split": c4_mutil_hp.eval_split,
        "add_bos": c4_mutil_hp.add_bos,
        "add_eos": c4_mutil_hp.add_eos,
        "enable_checkpointing": True,
        "profiler": "xplane",
        "sa_block_q": 1024,
        "sa_block_q_dkv": 2048,
        "sa_block_q_dq": 2048,
        "steps": 1000,
        "eval_interval": 100,
        "eval_steps": c4_mutil_hp.eval_steps,
        "data_shuffle_seed": 1238,
        "checkpoint_period": 1000,
    },
    xla_flags=(
        xla_flags_library.DENSE_VMEM_LIMIT_FLAG
        + xla_flags_library.CF_FOR_ALL_GATHER
    ),
)

llama3_1_405b_8192_fsdp_dcn_c4 = MaxTextModel(
    model_name="llama3-1-405b-8192-fsdp-dcn",
    model_type="llama3.1-405b",
    tuning_params={
        "per_device_batch_size": 1,
        "ici_fsdp_parallelism": 64,
        "ici_tensor_parallelism": 4,
        "dcn_fsdp_parallelism": 2,
        "allow_split_physical_axes": True,
        "custom_mesh": "hybrid_ring_64x4",
        "remat_policy": "custom",
        "decoder_layer_input": "offload",
        "query_proj": "offload",
        "key_proj": "offload",
        "value_proj": "offload",
        "out_proj": "offload",
        "max_target_length": 8192,
        "attention": "flash",
        "gcs_metrics": True,
        "use_iota_embed": True,
        "dataset_path": "gs://mlperf-exp-us-east1-cp0",
        "dataset_name": "c4/en:3.0.7",
        "dataset_type": "c4_mlperf",
        "tokenizer_path": (
            "gs://mlperf-llm-public2/vocab/c4_en_301_5Mexp2_spm.model"
        ),
        "enable_checkpointing": True,
        "checkpoint_period": 2000,
        "profiler": "xplane",
        "sa_block_q": 1024,
        "sa_block_q_dkv": 2048,
        "sa_block_q_dq": 2048,
        "learning_rate": 1.25e-5,
        "warmup_steps_fraction": 0.5,
    },
    xla_flags=(
        xla_flags_library.DENSE_VMEM_LIMIT_FLAG
        + xla_flags_library.CF_FOR_ALL_GATHER
        + xla_flags_library.HOST_OFFLOAD_FLAGS
    ),
)

def setupDataset(model: MaxTextModel, params: DatasetHParams):
  """Points the given model config at the dataset described by params (mutates model in place)."""
  model.model_name = model.model_name + "_" + params.name
  model.tuning_params["reuse_example_batch"] = -1
  model.tuning_params["dataset_path"] = params.dataset_path
  model.tuning_params["dataset_name"] = params.dataset_name
  model.tuning_params["dataset_type"] = params.dataset_type
  model.tuning_params["eval_dataset_name"] = params.dataset_name
  model.tuning_params["tokenizer_path"] = params.tokenizer_path
  model.tuning_params["train_split"] = params.train_split
  model.tuning_params["eval_split"] = params.eval_split
  model.tuning_params["add_bos"] = params.add_bos
  model.tuning_params["add_eos"] = params.add_eos
  model.tuning_params["eval_steps"] = params.eval_steps
  model.tuning_params["data_shuffle_seed"] = 1238


def setupC4Multilingualen(model: MaxTextModel):
  setupDataset(model, c4_mutil_hp)

def setupC4En(model: MaxTextModel):
  setupDataset(model, c4_en_hp)

def setupC4Mlperf(model: MaxTextModel):
  setupDataset(model, c4_mlperf_hp)
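A hypothetical usage sketch (not part of this file): each helper mutates the given MaxTextModel in place, renaming it and rewriting its dataset-related tuning_params.

# Illustrative only; retargets an existing benchmark config at the c4/multilingual-en dataset.
setupC4Multilingualen(llama3_1_8b_8192_c4en)
print(llama3_1_8b_8192_c4en.model_name)  # -> "llama3_1_8b_8192_c4en_c4multien"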

def setupConvHParams(model: MaxTextModel, params: ConvHParams, num_devices: int):
  """Derives step-based tuning params from the sample-based convergence hyper-parameters."""
  gbs = params.global_batch_size
  total_steps = params.total_tokens_to_train / gbs
  model.tuning_params["per_device_batch_size"] = float(gbs / num_devices)
  model.tuning_params["learning_rate"] = params.learning_rate
  model.tuning_params["warmup_steps_fraction"] = float(params.warmup_samples) / gbs / total_steps
  model.tuning_params["learning_rate_schedule_steps"] = int(params.decay_end_samples / gbs)
  model.tuning_params["steps"] = int(total_steps)
  eval_samples = model.tuning_params["eval_steps"]
  model.tuning_params["eval_steps"] = int(math.floor(eval_samples / gbs))
  model.tuning_params["eval_interval"] = int(math.ceil(params.eval_interval / gbs))
  model.tuning_params["enable_checkpointing"] = True
  model.tuning_params["checkpoint_period"] = int(math.ceil(1000 * 512 / gbs))
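A worked example of the arithmetic above, using hypothetical ConvHParams values chosen to mirror the 1000-step configs in this file (field names are taken from the function body; the real ConvHParams may carry more fields):

# Assume: global_batch_size = 512, total_tokens_to_train = 512_000 (samples, as the division by gbs implies),
#         warmup_samples = 51_200, decay_end_samples = 512_000, eval_interval = 51_200,
#         and eval_steps preset by setupDataset to 36 * 512 = 18_432 samples.
# Then:
#   steps                        = 512_000 / 512          = 1000
#   warmup_steps_fraction        = 51_200 / 512 / 1000    = 0.1
#   learning_rate_schedule_steps = 512_000 / 512          = 1000
#   eval_steps                   = floor(18_432 / 512)    = 36
#   eval_interval                = ceil(51_200 / 512)     = 100
#   checkpoint_period            = ceil(1000 * 512 / 512) = 1000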