
Commit e68c56d

Author: maxtext authors
Merge pull request #1069 from AI-Hypercomputer:fix_moe_test
PiperOrigin-RevId: 700849430
Parents: 7331e13 + b1c4730

1 file changed (+12, −6)

end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh

Lines changed: 12 additions & 6 deletions
@@ -31,15 +31,21 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/unscanned_ckpt/checkpoints/0/item
 # TODO(ranran): add decoding test for megablox implementation
 python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=unscanned_decoding per_device_batch_size=1 model_name=mixtral-8x7b async_checkpointing=false tokenizer_path=assets/tokenizer.mistral-v1 ici_tensor_parallelism=4 ici_fsdp_parallelism=16 max_prefill_predict_length=11 max_target_length=24 prompt="[INST] I love to [/INST]" megablox=False scan_layers=false
 
+# Run decoding with converted ckpt - dropping implementation
+python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=unscanned_decoding per_device_batch_size=1 model_name=mixtral-8x7b async_checkpointing=false tokenizer_path=assets/tokenizer.mistral-v1 ici_tensor_parallelism=4 ici_fsdp_parallelism=16 max_prefill_predict_length=11 max_target_length=24 prompt="[INST] I love to [/INST]" megablox=False scan_layers=false capacity_factor=1.25
+
 # Test whether the forward pass logits match the golden logits - matmul implementation
-python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=matmul_forward_pass_test per_device_batch_size=8 model_name=mixtral-8x7b tokenizer_path=assets/tokenizer.mistral-v1 ici_tensor_parallelism=4 ici_fsdp_parallelism=16 max_prefill_predict_length=11 max_target_length=11 dataset_type=synthetic dtype=float32 megablox=False scan_layers=false --atol=3 --rtol=1 --token_size=4
+python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=matmul_forward_pass_test per_device_batch_size=4 model_name=mixtral-8x7b tokenizer_path=assets/tokenizer.mistral-v1 ici_tensor_parallelism=4 ici_fsdp_parallelism=16 max_prefill_predict_length=11 max_target_length=11 dataset_type=synthetic dtype=float32 megablox=False scan_layers=false --atol=3 --rtol=1 --token_size=4
 
 # Test whether the forward pass logits match the golden logits - megablox implementation
 # TODO(ranran): investigate the root cause of the excessive tolerance
-python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=megablox_forward_pass_test per_device_batch_size=8 model_name=mixtral-8x7b tokenizer_path=assets/tokenizer.mistral-v1 ici_fsdp_parallelism=64 max_prefill_predict_length=11 max_target_length=11 dataset_type=synthetic dtype=bfloat16 weight_dtype=bfloat16 scan_layers=false --atol=20 --rtol=10 --token_size=4
+python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=megablox_forward_pass_test per_device_batch_size=4 model_name=mixtral-8x7b tokenizer_path=assets/tokenizer.mistral-v1 ici_fsdp_parallelism=64 max_prefill_predict_length=11 max_target_length=11 dataset_type=synthetic dtype=bfloat16 weight_dtype=bfloat16 scan_layers=false --atol=20 --rtol=10 --token_size=4
+
+# Run pre-training - megablox implementation
+python3 MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} run_name=megablox_pre_training per_device_batch_size=4 enable_checkpointing=false model_name=mixtral-8x7b ici_fsdp_parallelism=64 steps=5 max_target_length=1024 async_checkpointing=false tokenizer_path=assets/tokenizer.mistral-v1 attention=flash dtype=bfloat16 weight_dtype=bfloat16
 
-# Run fine-tuning - megablox implementation
-python3 MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} load_parameters_path=${SCANNED_CHECKPOINT} run_name=fine_tuning per_device_batch_size=8 model_name=mixtral-8x7b ici_fsdp_parallelism=64 steps=10 max_target_length=1024 async_checkpointing=false tokenizer_path=assets/tokenizer.mistral-v1 checkpoint_period=5 attention=flash dtype=bfloat16 weight_dtype=bfloat16
+# Run pre-training - matmul implementation
+python3 MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} run_name=matmul_pre_training per_device_batch_size=4 enable_checkpointing=false model_name=mixtral-8x7b ici_fsdp_parallelism=64 steps=5 max_target_length=1024 async_checkpointing=false tokenizer_path=assets/tokenizer.mistral-v1 attention=flash dtype=bfloat16 weight_dtype=bfloat16 megablox=False
 
-# Run pre-training without load_parameters_path - megablox implementation
-python3 MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} run_name=pre_training per_device_batch_size=8 enable_checkpointing=false model_name=mixtral-8x7b ici_fsdp_parallelism=64 steps=5 max_target_length=1024 async_checkpointing=false tokenizer_path=assets/tokenizer.mistral-v1 attention=flash dtype=bfloat16 weight_dtype=bfloat16
+# Run pre-training - dropping implementation
+python3 MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} run_name=dropping_pre_training per_device_batch_size=4 enable_checkpointing=false model_name=mixtral-8x7b ici_fsdp_parallelism=64 steps=5 max_target_length=1024 async_checkpointing=false tokenizer_path=assets/tokenizer.mistral-v1 attention=flash dtype=bfloat16 weight_dtype=bfloat16 megablox=False capacity_factor=1
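After this change the script exercises three MoE code paths for mixtral-8x7b pre-training, selected only by flags: megablox left at its default, the dense matmul path (megablox=False), and the token-dropping path (megablox=False with capacity_factor=1). The sketch below is not part of the commit; it is a minimal refactoring of the three pre-training invocations from the diff through a hypothetical run_pretraining helper, and assumes BASE_OUTPUT_PATH and DATASET_PATH are exported earlier in the script as in the hunk header.

#!/bin/bash
# Sketch only (not in the commit): factor the three MoE pre-training variants
# through one hypothetical helper. Common flags are copied verbatim from the
# diff above; BASE_OUTPUT_PATH and DATASET_PATH must already be exported.
run_pretraining() {
  local run_name=$1; shift
  python3 MaxText/train.py MaxText/configs/base.yml \
    base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} \
    run_name=${run_name} per_device_batch_size=4 enable_checkpointing=false \
    model_name=mixtral-8x7b ici_fsdp_parallelism=64 steps=5 max_target_length=1024 \
    async_checkpointing=false tokenizer_path=assets/tokenizer.mistral-v1 \
    attention=flash dtype=bfloat16 weight_dtype=bfloat16 "$@"
}

run_pretraining megablox_pre_training                                    # megablox (default)
run_pretraining matmul_pre_training   megablox=False                     # dense matmul path
run_pretraining dropping_pre_training megablox=False capacity_factor=1   # token-dropping path

Keeping the per-variant differences down to run_name and the trailing flags makes it easy to see that only the MoE routing implementation changes between the three runs.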
