Commit e44e507

test-maxtext.sh: support user-defined XLA flags (#763)
1. Added support for user-defined XLA flags.
2. Changed the source of the PGO converter script to JAX main.
1 parent 46ab5a1 commit e44e507
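
With this change, the script's default XLA flag set can be replaced or extended from the caller's environment instead of being edited in place. A minimal usage sketch (the output path is illustrative; both example flags appear in the script's default flag set):

```bash
# Replace the built-in default flag set entirely: the script only falls back to
# its defaults when BASE_XLA_FLAGS is unset (via ${BASE_XLA_FLAGS:-...}).
BASE_XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=false" \
  test-maxtext.sh --output /tmp/maxtext-test --steps 10

# Append extra flags on top of the defaults: the script exports
# XLA_FLAGS="$BASE_XLA_FLAGS ${XLA_FLAGS:-}", so a pre-set XLA_FLAGS is kept.
XLA_FLAGS="--xla_gpu_enable_triton_gemm=true" \
  test-maxtext.sh --output /tmp/maxtext-test --steps 10
```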

3 files changed: 41 additions & 11 deletions

.github/container/test-maxtext.sh

Lines changed: 34 additions & 10 deletions
@@ -13,10 +13,11 @@ usage() {
 echo ""
 echo " OPTIONS DESCRIPTION"
 echo " -a, --additional-args Additional fiddle args to pass to MaxText/train.py"
+echo " --mem-fraction Specify the percentage of memory to preallocate for XLA. Example: 0.90, 0.85, 0.65"
+echo " --model-name Specify model to run. Example: llama2-7b, default"
+echo " --attn-type Specify the attention type. Example: dot_product, cudnn_flash_te"
 echo " -b, --batch-per-gpu Batch size per GPU, defaults to 2."
 echo " --dtype Batch size, defaults to bfloat16."
-echo " --enable-te If set, will run with env var ENABLE_TE=1."
-echo " --enable-fused-attn If set, will run with env var NVTE_FUSED_ATTN=1."
 echo " -s, --steps Number of steps to run, defaults to 500."
 echo " --multiprocess Enable the multiprocess GPU mode."
 echo " -o, --output NAME Name for the output folder, a temporary folder will be created if none specified."
@@ -29,15 +30,18 @@ usage() {
 exit $1
 }

-args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-fused-attn,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@")
+args=$(getopt -o a:b:s:o:n:h --long additional-args:,mem-fraction:,model-name:,attn-type:,batch-per-gpu:,dtype:,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@")
 if [[ $? -ne 0 ]]; then
 exit $1
 fi

 # Default arguments
 HARDWARE='gpu'
 OUTPUT=$(mktemp -d)
+MEM_FRACTION=0.65

+MODEL_NAME='llama2-7b'
+ATTN_TYPE='dot_product'
 BATCH_PER_GPU=2
 DTYPE="bfloat16"
 STEPS=10
@@ -46,7 +50,6 @@ FSDP=1
 TP=1
 PP=1
 NODES=1
-ENABLE_TE=0
 ENABLE_FUSED_ATTN=0
 ADDITIONAL_ARGS=""

@@ -57,6 +60,18 @@ while [ : ]; do
 ADDITIONAL_ARGS="$2"
 shift 2
 ;;
+--mem-fraction)
+MEM_FRACTION="$2"
+shift 2
+;;
+--model-name)
+MODEL_NAME="$2"
+shift 2
+;;
+--attn-type)
+ATTN_TYPE="$2"
+shift 2
+;;
 -b | --batch-per-gpu)
 BATCH_PER_GPU="$2"
 shift 2
@@ -130,13 +145,20 @@ else
 ici_DP=$DP
 fi

+if [ $ATTN_TYPE -eq 'cudnn_flash_te' ]
+then
+ENABLE_FUSED_ATTN=1
+fi
+
+print_var MEM_FRACTION
+print_var MODEL_NAME
+print_var ATTN_TYPE
 print_var BATCH_PER_GPU
 print_var DTYPE
 print_var STEPS
 print_var NGPUS
 print_var HARDWARE
 print_var OUTPUT
-print_var ENABLE_TE
 print_var ENABLE_FUSED_ATTN
 print_var DP
 print_var ici_DP
@@ -152,10 +174,10 @@ pushd ${MAXTEXT_DIR}
 set -ex

 export NVTE_FUSED_ATTN=${ENABLE_FUSED_ATTN}
-export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65
+export XLA_PYTHON_CLIENT_MEM_FRACTION=${MEM_FRACTION}
 export CUDA_DEVICE_MAX_CONNECTIONS=1

-export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
+export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true
 --xla_gpu_enable_async_all_gather=true
 --xla_gpu_enable_async_reduce_scatter=true
 --xla_gpu_enable_triton_gemm=false
@@ -173,12 +195,14 @@ export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
 --xla_gpu_enable_triton_softmax_fusion=false
 --xla_gpu_enable_all_gather_combine_by_dim=false
 --xla_gpu_enable_reduce_scatter_combine_by_dim=false
---xla_disable_hlo_passes=rematerialization"
+--xla_disable_hlo_passes=rematerialization}
+
+export XLA_FLAGS="$BASE_XLA_FLAGS ${XLA_FLAGS:-}"

 RUN_NAME="logdir" ## the RUN_NAME cannot be changed

-RUN_SETTINGS="MaxText/train.py MaxText/configs/base.yml run_name=${RUN_NAME} logits_via_embedding=true decoder_block=default \
-steps=$STEPS per_device_batch_size=2 base_emb_dim=2560 base_mlp_dim=8192 remat_policy=minimal attention=dot_product\
+RUN_SETTINGS="MaxText/train.py MaxText/configs/base.yml run_name=${RUN_NAME} logits_via_embedding=true decoder_block=${MODEL_NAME} \
+steps=$STEPS per_device_batch_size=${BATCH_PER_GPU} base_emb_dim=2560 base_mlp_dim=8192 remat_policy=minimal attention=${ATTN_TYPE}\
 base_num_query_heads=8 base_num_kv_heads=8 base_num_decoder_layers=8 head_dim=128 enable_checkpointing=false\
 base_output_directory=$OUTPUT dataset_path=local dataset_type=synthetic hardware=$HARDWARE\
 dcn_fsdp_parallelism=1 ici_fsdp_parallelism=$FSDP\
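
For reference, a local invocation that exercises the new options might look like the sketch below (the output path is illustrative; option names and example values come from the help text above):

```bash
# Run the llama2-7b config with the cuDNN flash attention path and a larger
# preallocation fraction; per the diff above, cudnn_flash_te is intended to
# enable fused attention (NVTE_FUSED_ATTN=1).
test-maxtext.sh \
  --output /tmp/maxtext-test \
  --model-name llama2-7b \
  --attn-type cudnn_flash_te \
  --mem-fraction 0.90 \
  --batch-per-gpu 2 \
  --steps 10
```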

.github/workflows/_test_maxtext.yaml

Lines changed: 6 additions & 0 deletions
@@ -117,6 +117,9 @@ jobs:
 test-maxtext.sh \
 --output /output/${{ steps.meta.outputs.TEST_CASE_NAME }} \
 --dtype bfloat16 \
+--mem-fraction 0.65 \
+--model-name default \
+--attn-type dot_product \
 --batch-per-gpu 2 \
 --steps 10 \
 --pipeline-parallel ${{ matrix.PARALLEL_CONFIG[0] }} \
@@ -267,6 +270,9 @@ jobs:
 test-maxtext.sh \
 --output /output/${{ steps.meta.outputs.TEST_CASE_NAME }} \
 --dtype bfloat16 \
+--mem-fraction 0.65 \
+--model-name default \
+--attn-type dot_product \
 --batch-per-gpu 2 \
 --steps 10 \
 --pipeline-parallel ${{ matrix.PARALLEL_CONFIG[0] }} \

rosetta/docs/PGLE.md

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=false
 ```
 The main reason to do this is to not have any overlaps so that we can get exact costs for different ops.

-2. **Generate protobuf**: Once we have the nsys profile generated, we pass it to the python script provided [here [pgo_nsys_converter.py]](https://github.com/abhinavgoel95/jax/blob/patch-1/jax/tools/pgo_nsys_converter.py) to generate the pbtxt file. A sample pbtxt file would look like this:
+2. **Generate protobuf**: Once we have the nsys profile generated, we pass it to the python script provided [here [pgo_nsys_converter.py]](https://github.com/google/jax/blob/main/jax/tools/pgo_nsys_converter.py) to generate the pbtxt file. A sample pbtxt file would look like this:
 ```
 ...
 costs { name: "all-gather-start.1" cost_us: 7040.5215 }
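
With the converter now referenced from JAX main, it can be pulled straight from the upstream repository. A hedged sketch of an invocation (the argument names are assumptions, not taken from this diff; check the script's --help for the exact interface):

```bash
# Illustrative: clone upstream JAX and run the converter on an nsys profile.
git clone --depth 1 https://github.com/google/jax.git
python jax/jax/tools/pgo_nsys_converter.py \
  --profile_path profile.nsys-rep \
  --post_process \
  --pbtxt_file_name profile.pbtxt
```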
