Merge pull request #1264 from AI-Hypercomputer:msingh-bm
PiperOrigin-RevId: 725771475
maxtext authors committed Feb 11, 2025
2 parents 22840a3 + 3fbb8ef commit 39b5b74
Showing 6 changed files with 107 additions and 101 deletions.
67 changes: 28 additions & 39 deletions MaxText/inference_mlperf/README.md
@@ -102,76 +102,65 @@ export HUGGING_FACE_TOKEN=<your_hugging_face_token>
 huggingface-cli login --token $HUGGING_FACE_TOKEN
 ```
 
-### Offline Server - Test Run
+### Run Offline Benchmarks
 
-#### For trillium
+#### LLama2-70b:
 ```
-cd ~/maxtext/MaxText/inference_mlperf
-export TOKENIZER_PATH="/home/${USER}/maxtext/assets/tokenizer.llama2"
-export BATCH_AND_PREFILL_LEN="1024,20"
-export MAXENGINE_ARGS="model_name=llama2-70b tokenizer_path=${TOKENIZER_PATH} quantization=int8 quantize_kvcache=True load_parameters_path=${SAVE_QUANT_PARAMS_PATH} checkpoint_is_quantized=True"
-
-bash ./llama_offline_run.sh -p -t
+cd ~/maxtext/MaxText/inference_mlperf/trillium
 ```
-#### Mixtral-8x7b:
-```
-cd ~/maxtext/MaxText/inference_mlperf
-export TOKENIZER_PATH="/home/${USER}/maxtext/assets/tokenizer.mistral-v1"
-export BATCH_AND_PREFILL_LEN="2048,18"
-export MAXENGINE_ARGS="model_name=mixtral-8x7b tokenizer_path=${TOKENIZER_PATH} quantization=int8 quantize_kvcache=True load_parameters_path=${SAVE_QUANT_PARAMS_PATH} checkpoint_is_quantized=True megablox=False sparse_matmul=False capacity_factor=1 model_call_mode=inference"
 
-bash ./mixtral_offline_run.sh -p -t
+##### Test Run
+```
+bash benchmarks_llama2-70b-trillium_2x4.sh -b=performance -t
 ```
 
-### Offline Benchmarks
-
-#### For v5e
+##### Performance Only:
 ```
-export BATCH_AND_PREFILL_LEN="256,80|512,40|1024,20"
-export MAXENGINE_ARGS="model_name=llama2-70b tokenizer_path=${TOKENIZER_PATH} quantization=int8 quantize_kvcache=True load_parameters_path=${SAVE_QUANT_PARAMS_PATH} checkpoint_is_quantized=True compute_axis_order=0,1,2,3 ar_cache_axis_order=0,1,2,3"
+bash benchmarks_llama2-70b-trillium_2x4.sh -b=performance
 ```
 
-#### For v6
-#### LLama2-70b:
+##### Accuracy Only:
 ```
-export BATCH_AND_PREFILL_LEN="256,216|512,108|1024,54"
-export MAXENGINE_ARGS="model_name=llama2-70b tokenizer_path=${TOKENIZER_PATH} quantization=int8 quantize_kvcache=True load_parameters_path=${SAVE_QUANT_PARAMS_PATH} checkpoint_is_quantized=True compute_axis_order=0,2,1,3 ar_cache_axis_order=0,2,1,3"
+bash benchmarks_llama2-70b-trillium_2x4.sh -b=accuracy
 ```
-#### Mixtral-8x7b:
-export BATCH_AND_PREFILL_LEN="256,144|512,72|2048,18"
-export MAXENGINE_ARGS="model_name=mixtral-8x7b tokenizer_path=${TOKENIZER_PATH} quantization=int8 quantize_kvcache=True load_parameters_path=${SAVE_QUANT_PARAMS_PATH} checkpoint_is_quantized=True megablox=False sparse_matmul=False capacity_factor=1 model_call_mode=inference compute_axis_order=0,2,1,3 ar_cache_axis_order=0,2,1,3"
 
-#### Run offline performance benchmark
-#### LLama2-70b:
+##### Audit Only:
 ```
-bash ./llama_offline_run.sh -p
+bash benchmarks_llama2-70b-trillium_2x4.sh -b=audit
 ```
-#### Mixtral-8x7b:
+
+##### Run all benchmarks:
 ```
-bash ./mixtral_offline_run.sh -p
+bash benchmarks_llama2-70b-trillium_2x4.sh -b=all
 ```
 
-#### Run offline accuracy benchmark
-#### LLama2-70b:
-```
-bash ./llama_offline_run.sh -a
-```
 #### Mixtral-8x7b:
 ```
-bash ./mixtral_offline_run.sh -a
+cd ~/maxtext/MaxText/inference_mlperf
+export PREFILL_LENS_AND_PER_DEVICE_BATCH_SIZES="256,144|512,72|2048,18"
+export MAXENGINE_ARGS="model_name=mixtral-8x7b tokenizer_path=${TOKENIZER_PATH} quantization=int8 quantize_kvcache=True load_parameters_path=${SAVE_QUANT_PARAMS_PATH} checkpoint_is_quantized=True megablox=False sparse_matmul=False capacity_factor=1 model_call_mode=inference compute_axis_order=0,2,1,3 ar_cache_axis_order=0,2,1,3"
 ```
 
-#### Run offline audit benchmark
-#### LLama2-70b:
+##### Test Run
 ```
-bash ./llama_offline_run.sh -d
+bash ./mixtral_offline_run.sh -t
 ```
-#### Mixtral-8x7b:
+
+##### Performance Only:
+```
+bash ./mixtral_offline_run.sh
+```
+
+##### Accuracy Only:
+```
+bash ./mixtral_offline_run.sh -a
+```
+
+##### Audit Only:
 ```
 bash ./mixtral_offline_run.sh -d
 ```
 
-
 ### Profiling
 
 ```
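The runners read their defaults from the environment, so any of the settings above can be overridden per run. A minimal sketch, assuming the tokenizer and quantized-checkpoint setup from the earlier README steps is already in place (the single-bucket value here is illustrative, not a tuned configuration):

```
# Run only the 2048-token prefill bucket as a quick test.
export PREFILL_LENS_AND_PER_DEVICE_BATCH_SIZES="2048,18"
bash ./mixtral_offline_run.sh -t
```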
46 changes: 31 additions & 15 deletions MaxText/inference_mlperf/llama_offline_run.sh
@@ -6,16 +6,16 @@
 # enable profiling using -p option and capture using
 # tensorboard --logdir /tmp/tensorboard/
 
-run_name="test_int8_kv_bs_216-108-54"
+run_name="llama_offline_benchmarks"
 dry_run=false
 skip_warmup=false
 test_run=false
 enable_profiler=false
-enable_batch_prefill=false
 performance=true
 audit=false
 accuracy=false
 fast_eval=false
+enable_batch_prefill=false
 
 for arg in "$@"; do
   case $arg in
@@ -38,9 +38,9 @@ done
 
 
 if "$dry_run"; then
-    cmd=echo
+    CMD=echo
 else
-    cmd=''
+    CMD=''
 fi
 
 SKIP_WARMUP_OPTION=""
@@ -63,23 +63,34 @@ if [ -z "$TOKENIZER_PATH" ]; then
 fi
 
 BATCH_STR=""
-if [ -z "$BATCH_AND_PREFILL_LEN" ];
+if [ -z "$PREFILL_LENS_AND_PER_DEVICE_BATCH_SIZES" ];
 then
-  BATCH_AND_PREFILL_LEN="256,216|512,108|1024,54"
+  PREFILL_LENS_AND_PER_DEVICE_BATCH_SIZES="256,216|512,108|1024,54"
 fi
 
 if [ -z "$TOK_OUTLEN_MULTIPLIER" ];
 then
   TOK_OUTLEN_MULTIPLIER="2.5"
 fi
 
+if [ -z "$MODEL_NAME" ];
+then
+  MODEL_NAME="llama2-70b"
+fi
+
+if [ -z "$HF_CKPT" ];
+then
+  HF_CKPT="meta-llama/Llama-2-70b-chat-hf"
+fi
+
+
+
 if [ -z "$MAXENGINE_ARGS" ];
 then
-  CHECKPOINT="gs://msingh-bkt/checkpoints/quant_llama2-70b-chat/mlperf_070924/int8_"
-  BASE_CFG="model_name=llama2-70b tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${CHECKPOINT}"
+  CHECKPOINT="gs://msingh-bkt/checkpoints/quant_${MODEL_NAME}-chat/mlperf_070924/int8_"
+  BASE_CFG="model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${CHECKPOINT}"
   QUANT_CFG="quantization=int8 quantize_kvcache=True checkpoint_is_quantized=True"
-  LAYOUT_CFG="compute_axis_order=0,1,2,3 ar_cache_axis_order=0,1,2,3"
-  MAXENGINE_ARGS="${BASE_CFG} ${QUANT_CFG} ${LAYOUT_CFG}"
+  MAXENGINE_ARGS="${BASE_CFG} ${QUANT_CFG}"
 fi
 
 if [ -z "$BASEDIR" ];
@@ -122,10 +133,10 @@ run_loadgen() {
   echo "TOTAL_SAMPLE_COUNT: ${TOTAL_SAMPLE_COUNT}"
   echo "OUTPUT_LOG_DIR: ${OUTPUT_LOG_DIR}"
   echo "USER_CONFIG: ${USER_CONFIG}"
-  echo "BATCH_AND_PREFILL_LEN: ${BATCH_AND_PREFILL_LEN}"
+  echo "PREFILL_LENS_AND_PER_DEVICE_BATCH_SIZES: ${PREFILL_LENS_AND_PER_DEVICE_BATCH_SIZES}"
   echo "MAXENGINE_ARGS: ${MAXENGINE_ARGS}"
-
-  ${cmd} python -m offline_mode \
+  echo
+  ${CMD} python -m offline_mode \
     --mlperf_test_mode=${TEST_MODE} \
     --input_mode tokenized \
     --output_mode tokenized \
@@ -134,7 +145,7 @@ run_loadgen() {
     --audit_conf ${AUDIT_CONF} \
     --total_sample_count ${TOTAL_SAMPLE_COUNT} \
     --dataset_path ${DATASET_PATH} \
-    --prefill_lengths_and_batch_sizes ${BATCH_AND_PREFILL_LEN} \
+    --prefill_lengths_and_per_device_batch_sizes ${PREFILL_LENS_AND_PER_DEVICE_BATCH_SIZES} \
     --maxengine_args "${MAXENGINE_ARGS}" \
     --output_log_dir ${OUTPUT_LOG_DIR} \
     --tok_outlen_multiplier ${TOK_OUTLEN_MULTIPLIER} \
@@ -161,15 +172,20 @@ run_loadgen_accuracy () {
   AUDIT_CONF="no_audit"
   run_loadgen
 
+  if [ $dry_run ] ; then
+    touch ${OUTPUT_ACCURACY_JSON_PATH}
+  fi
+
   # Eval Run
   if [ -e ${OUTPUT_ACCURACY_JSON_PATH} ]; then
     if [ "${FAST_EVAL:-false}" = "true" ] || "$fast_eval"; then
       EVAL_SCRIPT="evaluate-accuracy-fast.py"
     else
       EVAL_SCRIPT="evaluate-accuracy.py"
     fi
 
+    echo
     ${CMD} python3 ${EVAL_SCRIPT} \
       --checkpoint-path ${HF_CKPT} \
       --tokenizer-path ${TOKENIZER_PATH} \
       --mlperf-accuracy-file ${OUTPUT_ACCURACY_JSON_PATH} \
       --dataset-file ${DATASET_PATH} 2>&1 | tee ${OUTPUT_LOG_DIR}/evaluate_offline_accuracy_log.log
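The `CMD=echo` / `CMD=''` switch above is a common shell dry-run idiom: every command is written as `${CMD} actual-command ...`, so the dry-run flag makes the script print each command instead of executing it. A standalone sketch of the pattern (the `-n` flag here is illustrative, not the script's real option parsing):

```
#!/usr/bin/env bash
# Dry-run idiom: CMD=echo prints commands; CMD='' executes them.
dry_run=false
for arg in "$@"; do
  case $arg in
    -n) dry_run=true ;;
  esac
done

if "$dry_run"; then
  CMD=echo
else
  CMD=''
fi

# With -n this prints the invocation; without it, it runs.
${CMD} python -m offline_mode --mlperf_test_mode=performance
```

The added `touch ${OUTPUT_ACCURACY_JSON_PATH}` under dry run serves the same goal: it fakes the accuracy output file so the eval step can also be exercised as a printed command.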
8 changes: 4 additions & 4 deletions MaxText/inference_mlperf/mixtral_offline_run.sh
@@ -56,9 +56,9 @@ if [ -z "$TOKENIZER_PATH" ]; then
 fi
 
 BATCH_STR=""
-if [ -z "$BATCH_AND_PREFILL_LEN" ];
+if [ -z "$PREFILL_LENS_AND_PER_DEVICE_BATCH_SIZES" ];
 then
-  BATCH_AND_PREFILL_LEN="256,144|512,72|2048,18"
+  PREFILL_LENS_AND_PER_DEVICE_BATCH_SIZES="256,144|512,72|2048,18"
 fi
 
 if [ -z "$TOK_OUTLEN_MULTIPLIER" ];
@@ -117,7 +117,7 @@ run_loadgen() {
   echo "TOTAL_SAMPLE_COUNT: ${TOTAL_SAMPLE_COUNT}"
   echo "OUTPUT_LOG_DIR: ${OUTPUT_LOG_DIR}"
   echo "USER_CONFIG: ${USER_CONFIG}"
-  echo "BATCH_AND_PREFILL_LEN: ${BATCH_AND_PREFILL_LEN}"
+  echo "PREFILL_LENS_AND_PER_DEVICE_BATCH_SIZES: ${PREFILL_LENS_AND_PER_DEVICE_BATCH_SIZES}"
   echo "MAXENGINE_ARGS: ${MAXENGINE_ARGS}"
 
   ${cmd} python -m offline_mode \
@@ -129,7 +129,7 @@
     --audit_conf ${AUDIT_CONF} \
     --total_sample_count ${TOTAL_SAMPLE_COUNT} \
     --dataset_path ${DATASET_PATH} \
-    --prefill_lengths_and_batch_sizes ${BATCH_AND_PREFILL_LEN} \
+    --prefill_lengths_and_per_device_batch_sizes ${PREFILL_LENS_AND_PER_DEVICE_BATCH_SIZES} \
     --maxengine_args "${MAXENGINE_ARGS}" \
     --output_log_dir ${OUTPUT_LOG_DIR} \
     --tok_outlen_multiplier ${TOK_OUTLEN_MULTIPLIER} \
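Both runner scripts use `if [ -z "$VAR" ]` blocks so that every default is environment-overridable. For reference, bash parameter expansion can express the same default-if-unset pattern in one line (a condensed equivalent, not how the scripts are written):

```
# ':=' assigns the default only when the variable is unset or empty.
: "${PREFILL_LENS_AND_PER_DEVICE_BATCH_SIZES:=256,144|512,72|2048,18}"
echo "PREFILL_LENS_AND_PER_DEVICE_BATCH_SIZES=${PREFILL_LENS_AND_PER_DEVICE_BATCH_SIZES}"
```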
8 changes: 4 additions & 4 deletions MaxText/inference_mlperf/offline_mode.py
@@ -62,7 +62,7 @@
   "performance",
   "performance, accuracy, submission",
 )
-flags.DEFINE_string("api_url", None, "SAX published model path.", required=False)
+flags.DEFINE_string("api_url", None, "published model path.", required=False)
 flags.DEFINE_string("dataset_path", None, "", required=False)
 flags.DEFINE_bool("is_stream", False, "", required=False)
 flags.DEFINE_string(
@@ -119,7 +119,7 @@
     required=False,
 )
 flags.DEFINE_string(
-    "prefill_lengths_and_batch_sizes",
+    "prefill_lengths_and_per_device_batch_sizes",
     "256,80|512,40|1024,20",
     "List of prefill lengths and batch sizes to use for each engine. Format len_1,bs_1|len_2,bs_2|..",
     required=False,
@@ -198,7 +198,7 @@ def pad_tokens(tokens):
 
 def _init_query_batches():
   query_batches = {}
-  len_batch_str = FLAGS.prefill_lengths_and_batch_sizes.split("|")
+  len_batch_str = FLAGS.prefill_lengths_and_per_device_batch_sizes.split("|")
   len_batch = []
   for lb in len_batch_str:
     l, b = lb.split(",")
@@ -449,7 +449,7 @@ def main(argv):
   log.info(f"Dataset len {len(dataset)}, estimated counts by bucket {estimated_counts_by_bucket}")
 
   rows = list(dataset.iterrows())
-  len_batch_str = FLAGS.prefill_lengths_and_batch_sizes
+  len_batch_str = FLAGS.prefill_lengths_and_per_device_batch_sizes
   log.info(f"Prefill lengths and Batch sizes: {len_batch_str}")
   log.info(f"Maxengine args: {FLAGS.maxengine_args}")
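The renamed flag keeps the same value format: `|`-separated `length,batch` pairs, one per prefill bucket, as the help string above describes. A small bash sketch of how such a string decomposes into per-engine settings (mirroring what `_init_query_batches` does in Python; the variable names are illustrative):

```
PAIRS="256,80|512,40|1024,20"
IFS='|' read -ra BUCKETS <<< "${PAIRS}"
for bucket in "${BUCKETS[@]}"; do
  IFS=',' read -r prefill_len per_device_batch <<< "${bucket}"
  echo "prefill_length=${prefill_len} per_device_batch_size=${per_device_batch}"
done
```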
1 change: 0 additions & 1 deletion MaxText/inference_mlperf/requirements.txt
@@ -5,4 +5,3 @@ absl-py==1.4.0
 rouge-score==0.1.2
 sentencepiece==0.1.99
 accelerate==0.21.0
-orbax-checkpoint==0.5.20