-
Notifications
You must be signed in to change notification settings - Fork 295
/
64b.sh
49 lines (42 loc) · 2.01 KB
/
64b.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# Announce which config script is being executed.
printf '%s\n' "Running 64b.sh"

# 64B parameter model.
# Per the original notes, this config works out of the box for any number of
# v5e-256 slices.
#
# Command flags (passed as KEY=VALUE arguments):
#   OUTPUT_PATH   Required, unless base_output_directory is already set in base.yml
#   DATASET_PATH  Required, unless dataset_path is already set in base.yml
#   RUN_NAME      Required, unless run_name is already set in base.yml or
#                 running with XPK/GKE
#
# Example invocation:
#   bash MaxText/configs/v5e/64b.sh RUN_NAME="<your_run_name>" OUTPUT_PATH="gs://<your_output_path>" DATASET_PATH="gs://<your_dataset_path>"
#
# Example AOT compile:
#   bash MaxText/configs/v5e/64b.sh EXECUTABLE=train_compile.py M_COMPILE_TOPOLOGY=v5e-256 M_COMPILE_TOPOLOGY_NUM_SLICES=2

# Abort as soon as any command exits with a non-zero status.
set -o errexit

# Defaults; both can be overridden by KEY=VALUE arguments parsed later on.
EXECUTABLE="train.py" # or train_compile.py
RUN_PREFLIGHT="true"
export EXECUTABLE RUN_PREFLIGHT
# Set environment variables from KEY=VALUE command-line arguments.
#
# Each argument is split at its FIRST '=' only, so the value is preserved
# verbatim, including embedded or trailing '=' characters (an
# `IFS='=' read` split silently drops a single trailing '=').
# An argument without any '=' exports that name with an empty value.
# No helper variables are used, so no argument name can collide with one.
#
# Arguments: zero or more KEY=VALUE strings.
# Returns:   non-zero if `export` rejects a name (e.g. an empty key).
export_key_value_args() {
  while (( $# > 0 )); do
    case "$1" in
      *=*) export "${1%%=*}=${1#*=}" ;;
      *)   export "$1=" ;;
    esac
    shift
  done
}
export_key_value_args "$@"
# Two ways of supplying the run name are supported:
#   1) 'RUN_NAME' is passed at runtime and is copied into 'M_RUN_NAME' here.
#   2) 'M_RUN_NAME' is propagated directly (e.g. within an Airflow sweeping
#      workflow), in which case it is left untouched.
if [[ -n "${RUN_NAME}" ]]; then
  M_RUN_NAME="${RUN_NAME}"
  export M_RUN_NAME
fi
# Run preflight.sh (per the original note, it sets up network optimizations)
# only when RUN_PREFLIGHT is exactly "true".
case "${RUN_PREFLIGHT}" in
  true) bash preflight.sh ;;
esac
# Train
# Flags exported through LIBTPU_INIT_ARGS for the training process.
export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# Arguments for the MaxText entry point, one per array element. Using an
# array (instead of backslash continuations) guarantees every override stays
# a separate word, and quoting keeps paths with spaces or glob characters
# intact.
train_args=(
  MaxText/configs/base.yml
  steps=15
  per_device_batch_size=2
  enable_checkpointing=false
  remat_policy=full
  global_parameter_scale=64
  max_target_length=2048
  "base_output_directory=${OUTPUT_PATH}"
  "dataset_path=${DATASET_PATH}"
  use_iota_embed=true
  reuse_example_batch=1
  dataset_type=synthetic
  attention='flash'
  gcs_metrics=true
)

# EXECUTABLE selects the script under MaxText/ (train.py or train_compile.py).
python3 "MaxText/${EXECUTABLE}" "${train_args[@]}"