@@ -13,10 +13,11 @@ usage() {
     echo " "
     echo "    OPTIONS                   DESCRIPTION"
     echo "    -a, --additional-args     Additional fiddle args to pass to MaxText/train.py"
+    echo "    --mem-fraction            Specify the fraction of memory to preallocate for XLA. Example: 0.90, 0.85, 0.65"
+    echo "    --model-name              Specify model to run. Example: llama2-7b, default"
+    echo "    --attn-type               Specify the attention type. Example: dot_product, cudnn_flash_te"
     echo "    -b, --batch-per-gpu       Batch size per GPU, defaults to 2."
     echo "    --dtype                   Data type, defaults to bfloat16."
-    echo "    --enable-te               If set, will run with env var ENABLE_TE=1."
-    echo "    --enable-fused-attn       If set, will run with env var NVTE_FUSED_ATTN=1."
     echo "    -s, --steps               Number of steps to run, defaults to 500."
     echo "    --multiprocess            Enable the multiprocess GPU mode."
     echo "    -o, --output NAME         Name for the output folder, a temporary folder will be created if none specified."
@@ -29,15 +30,18 @@ usage() {
     exit $1
 }
 
-args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-fused-attn,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@")
+args=$(getopt -o a:b:s:o:n:h --long additional-args:,mem-fraction:,model-name:,attn-type:,batch-per-gpu:,dtype:,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@")
 if [[ $? -ne 0 ]]; then
     exit $1
 fi
 
 # Default arguments
 HARDWARE='gpu'
 OUTPUT=$(mktemp -d)
+MEM_FRACTION=0.65
 
+MODEL_NAME='llama2-7b'
+ATTN_TYPE='dot_product'
 BATCH_PER_GPU=2
 DTYPE="bfloat16"
 STEPS=10
@@ -46,7 +50,6 @@
 TP=1
 PP=1
 NODES=1
-ENABLE_TE=0
 ENABLE_FUSED_ATTN=0
 ADDITIONAL_ARGS=""
@@ -57,6 +60,18 @@ while [ : ]; do
         ADDITIONAL_ARGS="$2"
         shift 2
         ;;
+    --mem-fraction)
+        MEM_FRACTION="$2"
+        shift 2
+        ;;
+    --model-name)
+        MODEL_NAME="$2"
+        shift 2
+        ;;
+    --attn-type)
+        ATTN_TYPE="$2"
+        shift 2
+        ;;
     -b | --batch-per-gpu)
         BATCH_PER_GPU="$2"
         shift 2
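With the three new long options accepted by getopt and handled in the case block above, a typical invocation could look like the sketch below. The script file name (`test-maxtext.sh`) and the concrete values are illustrative assumptions, not taken from this diff; the option names and defaults are the ones added above.

    # Hypothetical run: larger preallocation fraction, cuDNN flash attention, batch 4 per GPU.
    # (script name assumed; any path to this run script works)
    bash test-maxtext.sh \
        --model-name llama2-7b \
        --attn-type cudnn_flash_te \
        --mem-fraction 0.90 \
        -b 4 -s 100 -o /tmp/maxtext-logs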
@@ -130,13 +145,20 @@ else
     ici_DP=$DP
 fi
 
+if [ "$ATTN_TYPE" = 'cudnn_flash_te' ]
+then
+    ENABLE_FUSED_ATTN=1
+fi
+
+print_var MEM_FRACTION
+print_var MODEL_NAME
+print_var ATTN_TYPE
 print_var BATCH_PER_GPU
 print_var DTYPE
 print_var STEPS
 print_var NGPUS
 print_var HARDWARE
 print_var OUTPUT
-print_var ENABLE_TE
 print_var ENABLE_FUSED_ATTN
 print_var DP
 print_var ici_DP
@@ -152,10 +174,10 @@ pushd ${MAXTEXT_DIR}
 set -ex
 
 export NVTE_FUSED_ATTN=${ENABLE_FUSED_ATTN}
-export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65
+export XLA_PYTHON_CLIENT_MEM_FRACTION=${MEM_FRACTION}
 export CUDA_DEVICE_MAX_CONNECTIONS=1
 
-export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
+export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true
         --xla_gpu_enable_async_all_gather=true
         --xla_gpu_enable_async_reduce_scatter=true
         --xla_gpu_enable_triton_gemm=false
@@ -173,12 +195,14 @@ export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
         --xla_gpu_enable_triton_softmax_fusion=false
         --xla_gpu_enable_all_gather_combine_by_dim=false
         --xla_gpu_enable_reduce_scatter_combine_by_dim=false
-        --xla_disable_hlo_passes=rematerialization"
+        --xla_disable_hlo_passes=rematerialization}
+
+export XLA_FLAGS="$BASE_XLA_FLAGS ${XLA_FLAGS:-}"
 
 RUN_NAME="logdir"  ## the RUN_NAME cannot be changed
 
-RUN_SETTINGS="MaxText/train.py MaxText/configs/base.yml run_name=${RUN_NAME} logits_via_embedding=true decoder_block=default \
-    steps=$STEPS per_device_batch_size=2 base_emb_dim=2560 base_mlp_dim=8192 remat_policy=minimal attention=dot_product \
+RUN_SETTINGS="MaxText/train.py MaxText/configs/base.yml run_name=${RUN_NAME} logits_via_embedding=true decoder_block=${MODEL_NAME} \
+    steps=$STEPS per_device_batch_size=${BATCH_PER_GPU} base_emb_dim=2560 base_mlp_dim=8192 remat_policy=minimal attention=${ATTN_TYPE} \
     base_num_query_heads=8 base_num_kv_heads=8 base_num_decoder_layers=8 head_dim=128 enable_checkpointing=false \
     base_output_directory=$OUTPUT dataset_path=local dataset_type=synthetic hardware=$HARDWARE \
     dcn_fsdp_parallelism=1 ici_fsdp_parallelism=$FSDP \
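The last two hunks replace the hard-coded XLA_FLAGS string with a BASE_XLA_FLAGS default that only applies when that variable is not already set (`${BASE_XLA_FLAGS:-...}`), and then append any caller-supplied XLA_FLAGS on top of it. A minimal sketch of how a caller could exercise that override path (script name and flag values are illustrative assumptions, not from this diff):

    # Replace the default base flag set entirely and append one extra XLA flag on top.
    # (script name assumed; flag values are only examples)
    BASE_XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=false" \
    XLA_FLAGS="--xla_dump_to=/tmp/hlo_dumps" \
    bash test-maxtext.sh --attn-type dot_product -s 10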