12
12
13
13
# Fail fast on any error and echo each command for CI log debugging.
set -ex

# Model size under test; interpolated into checkpoint/output GCS paths
# (e.g. ${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_ckpt/...).
# NOTE(review): original value was ' 8x22b' (leading space), which would
# yield paths containing "/ 8x22b/" — dropping the stray space.
MODEL_VARIATION='8x22b'
15
- PREDICT_LEN=7
16
- ATOL=60.0
17
- RTOL=10.0
18
15
19
16
# Guard: enter the fallback branch when the caller did not set BASE_OUTPUT_PATH.
# Fix: the original tested [ -z " ${BASE_OUTPUT_PATH} " ] — the spaces inside
# the quotes make the string always non-empty, so -z could never succeed and
# the default-path branch below was unreachable.
if [ -z "${BASE_OUTPUT_PATH}" ]; then
20
17
# Non-Googlers please remember to point BASE_OUTPUT_PATH to GCS buckets that you own, this script uses internal buckets for testing.
@@ -29,65 +26,7 @@ export SCANNED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_ckpt/0/
29
26
30
27
# Repo-bundled Mistral v3 tokenizer, consumed by the train/decode runs below.
export TOKENIZER_PATH='assets/tokenizer.mistral-v3'
31
28
32
- # Run decoding with converted ckpt - matmul implementation
33
- python3 MaxText/decode.py MaxText/configs/base.yml \
34
- load_parameters_path=${SCANNED_CHECKPOINT} run_name=scanned_decoding \
35
- per_device_batch_size=1 model_name=mixtral-8x22b async_checkpointing=false \
36
- tokenizer_path=${TOKENIZER_PATH} ici_tensor_parallelism=1 \
37
- ici_fsdp_parallelism=-1 max_prefill_predict_length=64 max_target_length=64 \
38
- prompt=" [INST] I love to [/INST]" megablox=False weight_dtype=float16
39
-
40
- # TODO(rdyro): add decoding test for megablox implementation
41
- # python3 MaxText/decode.py MaxText/configs/base.yml \
42
- # load_parameters_path=${SCANNED_CHECKPOINT} run_name=scanned_decoding \
43
- # per_device_batch_size=1 model_name=mixtral-8x22b async_checkpointing=false \
44
- # tokenizer_path=${TOKENIZER_PATH} ici_tensor_parallelism=1 \
45
- # ici_fsdp_parallelism=-1 max_prefill_predict_length=16 max_target_length=24 \
46
- # prompt="[INST] I love to [/INST]" megablox=True weight_dtype=float16
47
-
48
- # Test whether the forward pass logits match the golden logits - matmul implementation
49
- python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml \
50
- base_output_directory=${BASE_OUTPUT_PATH} \
51
- load_parameters_path=${SCANNED_CHECKPOINT} run_name=matmul_forward_pass_test \
52
- per_device_batch_size=1 model_name=mixtral-8x22b \
53
- tokenizer_path=${TOKENIZER_PATH} ici_tensor_parallelism=1 \
54
- ici_fsdp_parallelism=-1 max_prefill_predict_length=$PREDICT_LEN max_target_length=$PREDICT_LEN \
55
- dataset_type=synthetic dtype=bfloat16 weight_dtype=float16 megablox=False \
56
- --atol=$ATOL --rtol=$RTOL --token_size=$PREDICT_LEN
57
- # TODO(rdyro): figure out the reason for numerical mismatch for some tokens
58
-
59
- python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml \
60
- base_output_directory=${BASE_OUTPUT_PATH} \
61
- load_parameters_path=${SCANNED_CHECKPOINT} run_name=matmul_forward_pass_test \
62
- per_device_batch_size=1 model_name=mixtral-8x22b \
63
- tokenizer_path=${TOKENIZER_PATH} ici_tensor_parallelism=1 \
64
- ici_fsdp_parallelism=-1 max_prefill_predict_length=$PREDICT_LEN max_target_length=$PREDICT_LEN \
65
- dataset_type=synthetic dtype=bfloat16 weight_dtype=bfloat16 megablox=False \
66
- --atol=$ATOL --rtol=$RTOL --token_size=$PREDICT_LEN
67
- # TODO(rdyro): figure out the reason for numerical mismatch for some tokens
68
-
69
- # Test whether the forward pass logits match the golden logits - megablox implementation
70
- python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml \
71
- base_output_directory=${BASE_OUTPUT_PATH} \
72
- load_parameters_path=${SCANNED_CHECKPOINT} run_name=megablox_forward_pass_test \
73
- per_device_batch_size=1 model_name=mixtral-8x22b \
74
- tokenizer_path=${TOKENIZER_PATH} ici_tensor_parallelism=1 \
75
- ici_fsdp_parallelism=-1 max_prefill_predict_length=$PREDICT_LEN max_target_length=$PREDICT_LEN \
76
- dataset_type=synthetic dtype=bfloat16 weight_dtype=bfloat16 megablox=True \
77
- --atol=$ATOL --rtol=$RTOL --token_size=$PREDICT_LEN
78
- # TODO(rdyro): figure out the reason for numerical mismatch for some tokens
79
-
80
- python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml \
81
- base_output_directory=${BASE_OUTPUT_PATH} \
82
- load_parameters_path=${SCANNED_CHECKPOINT} run_name=megablox_forward_pass_test \
83
- per_device_batch_size=1 model_name=mixtral-8x22b \
84
- tokenizer_path=${TOKENIZER_PATH} ici_tensor_parallelism=1 \
85
- ici_fsdp_parallelism=-1 max_prefill_predict_length=$PREDICT_LEN max_target_length=$PREDICT_LEN \
86
- dataset_type=synthetic dtype=bfloat16 weight_dtype=float16 megablox=True \
87
- --atol=$ATOL --rtol=$RTOL --token_size=$PREDICT_LEN
88
- # TODO(rdyro): figure out the reason for numerical mismatch for some tokens
89
-
90
- # training
29
+ # TODO(ranran): enable the fine-tuning, decoding, and forward_pass_logit_checker tests once b/380148614 has been fixed
91
30
92
31
# Run pre-training without load_parameters_path - megablox implementation
93
32
python3 MaxText/train.py MaxText/configs/base.yml \
@@ -97,12 +36,3 @@ python3 MaxText/train.py MaxText/configs/base.yml \
97
36
steps=5 max_target_length=1024 async_checkpointing=false \
98
37
tokenizer_path=${TOKENIZER_PATH} attention=flash dtype=bfloat16 \
99
38
weight_dtype=bfloat16 megablox=True
100
-
101
- # Run fine-tuning - megablox implementation
102
- python3 MaxText/train.py MaxText/configs/base.yml \
103
- base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} \
104
- load_parameters_path=${SCANNED_CHECKPOINT} run_name=fine_tuning \
105
- per_device_batch_size=1 model_name=mixtral-8x22b ici_tensor_parallelism=1 \
106
- ici_fsdp_parallelism=-1 steps=10 max_target_length=1024 \
107
- async_checkpointing=false tokenizer_path=${TOKENIZER_PATH} checkpoint_period=100 \
108
- attention=flash dtype=bfloat16 weight_dtype=bfloat16 megablox=False
0 commit comments