Skip to content

Commit

Permalink
Update checkpointing test for async
Browse files Browse the repository at this point in the history
  • Loading branch information
anfals committed Jan 18, 2025
1 parent 1c47a6d commit b96e814
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
1 change: 1 addition & 0 deletions MaxText/tests/integration_tests/checkpointing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def run_checkpointing(attention_type):
"False", # collect_stack_trace
"grain", # dataset_type
attention_type,
"False", # async_checkpointing"
]

subprocess.run(command, check=True, cwd="..")
Expand Down
5 changes: 3 additions & 2 deletions end_to_end/test_checkpointing.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ ATTENTION=${6}
if [ -z "${6}" ]; then
ATTENTION='autoselected'
fi
ASYNC_CHECKPOINTING=${7:-true}
eval_metrics=checkpoint_save_restore
model_params=" base_emb_dim=384 base_num_query_heads=8 base_num_kv_heads=8 base_mlp_dim=192 base_num_decoder_layers=8 head_dim=128"
CMD_DATA=""
Expand All @@ -37,13 +38,13 @@ fi
#Train
CMD1="python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME steps=5 max_target_length=128 per_device_batch_size=1\
metrics_file=saved_metrics.txt checkpoint_period=3 base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH\
async_checkpointing=false collect_stack_trace=$COLLECT_STACK_TRACE attention=$ATTENTION"
async_checkpointing=$ASYNC_CHECKPOINTING collect_stack_trace=$COLLECT_STACK_TRACE attention=$ATTENTION"
CMD1+=$model_params
CMD1+=$CMD_DATA

CMD2="python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME steps=5 max_target_length=128 per_device_batch_size=1\
metrics_file=restored_metrics.txt base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH\
async_checkpointing=false collect_stack_trace=$COLLECT_STACK_TRACE attention=$ATTENTION"
async_checkpointing=$ASYNC_CHECKPOINTING collect_stack_trace=$COLLECT_STACK_TRACE attention=$ATTENTION"
CMD2+=$model_params
CMD2+=$CMD_DATA

Expand Down

0 comments on commit b96e814

Please sign in to comment.