diff --git a/mlperf/accuracy_run.sh b/mlperf/accuracy_run.sh
new file mode 100644
index 0000000..b940ae7
--- /dev/null
+++ b/mlperf/accuracy_run.sh
@@ -0,0 +1,56 @@
+#!/usr/bin/env bash
+me=$(basename "$0")
+
+BASEDIR=mlperf
+API_URL=0.0.0.0:9000
+USER_CONFIG=$BASEDIR/user.conf
+DATA_DISK_DIR=$BASEDIR/data
+TOTAL_SAMPLE_COUNT=1000
+DATASET_PATH=$BASEDIR/data/mixtral_15k_data.pkl
+
+# HF model id
+TOKENIZER_PATH="mistralai/Mixtral-8x7B-Instruct-v0.1"
+LOADGEN_RUN_TYPE=offline-performance
+OUTPUT_LOG_ID=${MODEL_NAME}-${DATASET_TYPE}-${LOADGEN_RUN_TYPE}-${LOADGEN_RUN_TIMESTAMP}
+OUTPUT_LOG_DIR=${DATA_DISK_DIR}/logs/${OUTPUT_LOG_ID}
+
+mkdir -p ${OUTPUT_LOG_DIR} && cp ../${USER_CONFIG} ${OUTPUT_LOG_DIR}
+
+OUTPUT_ACCURACY_JSON_PATH=${OUTPUT_LOG_DIR}/mlperf_log_accuracy.json
+
+CACHE_LENGTH=1024
+INPUT_SIZE=512
+OUTPUT_SIZE=512
+CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/
+
+LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
+# makes subsequent runs faster
+export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache2"
+export LIBTPU_INIT_ARGS
+
+pushd ..
+# python -m mlperf.offline_mode \
+#   --model_name=mixtral \
+#   --max_cache_length=$CACHE_LENGTH \
+#   --max_decode_length=$OUTPUT_SIZE \
+#   --context_length=$INPUT_SIZE \
+#   --checkpoint_path=$CHECKPOINT_PATH/model.safetensors \
+#   --tokenizer_path=$CHECKPOINT_PATH/tokenizer.model \
+#   --quantize_weights=1 \
+#   --quantize_type=int8_per_channel \
+#   --quantize_kv_cache=1 \
+#   --scenario Offline \
+#   --input_mode tokenized \
+#   --output_mode tokenized \
+#   --mlperf_conf $BASEDIR/mlperf.conf \
+#   --user_conf ${USER_CONFIG} \
+#   --audit_conf no_audit \
+#   --total_sample_count ${TOTAL_SAMPLE_COUNT} \
+#   --dataset_path ${DATASET_PATH} \
+#   --output_log_dir ${OUTPUT_LOG_DIR} 2>&1 | tee ${OUTPUT_LOG_DIR}/server_accuracy_log.log
+
+python -m mlperf.evaluate_accuracy \
+  --checkpoint-path ${TOKENIZER_PATH} \
+  --mlperf-accuracy-file ${OUTPUT_ACCURACY_JSON_PATH} \
+  --dataset-file ${DATASET_PATH} 2>&1 | tee ${OUTPUT_LOG_DIR}/evaluate_offline_accuracy_log.log
+popd
\ No newline at end of file
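
The evaluation step above reads LoadGen's mlperf_log_accuracy.json, where each record carries a qsl_idx and the generated tokens as a hex-encoded byte blob. A minimal sketch of decoding one record, mirroring what mlperf/evaluate_accuracy.py (added below) does; the log path is a hypothetical placeholder and int64 is the dtype this patch defaults to:

import json

import numpy as np
from transformers import AutoTokenizer

# Path produced by the run above (OUTPUT_ACCURACY_JSON_PATH); adjust to your run id.
accuracy_log = "mlperf/data/logs/<run-id>/mlperf_log_accuracy.json"
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")

with open(accuracy_log) as f:
  records = json.load(f)

first = records[0]
# The "data" field is a hex string holding the raw bytes of the output token ids.
token_ids = np.frombuffer(bytes.fromhex(first["data"]), np.int64)
print(first["qsl_idx"], tokenizer.decode(token_ids, skip_special_tokens=True))
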
"float"], + ) + args = parser.parse_args() + return args + + +def get_groundtruth(processed_dataset_file): + data = pd.read_pickle(processed_dataset_file) + return data + + +# Functions for evaluating GSM8K +def find_numbers(x: str) -> list[str]: + """Finds all numbers in a string.""" + # Search for number, possibly negative (hyphen), with thousand separators + # (comma), and with a decimal point (period inbetween digits). + numbers = re.compile( + r"-?[\d,]*\.?\d+", + re.MULTILINE | re.DOTALL | re.IGNORECASE, + ).findall(x) + return numbers + + +def find_number(x: str, answer_delimiter: str = "The answer is") -> str: + """Finds the most relevant number in a string.""" + # If model uses the answer delimiter, then select the first number following + # that format. + if answer_delimiter in x: + answer = x.split(answer_delimiter)[-1] + numbers = find_numbers(answer) + if numbers: + return numbers[0] + + # In general, select the last number in the string. + numbers = find_numbers(x) + if numbers: + return numbers[-1] + return "" + + +def maybe_remove_comma(x: str) -> str: + # Example: 5,600 -> 5600 + return x.replace(",", "") + + +def try_float(x: str): + try: + ret = float(x) + except BaseException: + ret = None + return ret + + +# Functions for evaluating OpenOrca + + +def postprocess_text(preds, targets): + preds = [pred.strip() for pred in preds] + targets = [target.strip() for target in targets] + + # rougeLSum expects newline after each sentence + preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] + targets = ["\n".join(nltk.sent_tokenize(target)) for target in targets] + + return preds, targets + + +# Functions for MBXP + + +def create_mbxp_dict(row, response): + lang, entry_point = row["id"].split("_", 1) + return { + "lang": lang, + "prompt": row["input"], + "test_code": row["gt_output"], + "entry_point": entry_point, + "response": response, + } + + +def main(): + + args = get_args() + dataset_path = args.dataset_file + checkpoint_path = args.checkpoint_path + metric = evaluate.load("rouge") + nltk.download("punkt") + + tokenizer = AutoTokenizer.from_pretrained( + checkpoint_path, + model_max_length=2048, + padding_side="left", + use_fast=False, + ) + + data = get_groundtruth(args.dataset_file) + query_types, gt_outputs = data["dataset"], data["gt_output"] + + target_required_GSM8K = [] + target_required_OpenOrca = [] + results_MBXP = [] + preds_token_GSM8K = [] + preds_token_OpenOrca = [] + preds_token_MBXP = [] + + eval_dtype = np.int64 + if args.dtype == "int32": + eval_dtype = np.int32 + elif args.dtype == "float": + eval_dtype = np.float32 + + with open(args.mlperf_accuracy_file, "r") as f: + results = json.load(f) + + seen = set() + gen_tok_len = 0 + gen_num = 0 + for pred in results: + gen_num += 1 + qsl_idx = pred["qsl_idx"] + if qsl_idx in seen: + continue + + seen.add(qsl_idx) + + query_type = query_types.iloc[qsl_idx] + if query_type == "GSM8K": + target = gt_outputs.iloc[qsl_idx] + target_required_GSM8K.append(target) + pred = np.frombuffer(bytes.fromhex(pred["data"]), eval_dtype) + gen_tok_len += len(pred) + preds_token_GSM8K.append(pred) + elif query_type == "OpenOrca": + target = gt_outputs.iloc[qsl_idx] + target_required_OpenOrca.append(target) + pred = np.frombuffer(bytes.fromhex(pred["data"]), eval_dtype) + preds_token_OpenOrca.append(pred) + gen_tok_len += len(pred) + else: + target = data.iloc[qsl_idx] + pred = np.frombuffer(bytes.fromhex(pred["data"]), eval_dtype) + pred_str = tokenizer.decode(pred, skip_special_tokens=True) + 
diff --git a/mlperf/install.sh b/mlperf/install.sh
new file mode 100644
index 0000000..3a8f037
--- /dev/null
+++ b/mlperf/install.sh
@@ -0,0 +1,41 @@
+#!/usr/bin/env bash
+
+DATA_DISK_DIR=data
+
+mkdir -p $DATA_DISK_DIR
+
+pip install -U "huggingface_hub[cli]"
+pip install \
+  transformers \
+  nltk==3.8.1 \
+  evaluate==0.4.0 \
+  absl-py==1.4.0 \
+  rouge-score==0.1.2 \
+  sentencepiece==0.1.99 \
+  accelerate==0.21.0
+
+# install loadgen
+pip install mlperf-loadgen
+
+
+pushd $DATA_DISK_DIR
+
+# model weights
+gcloud storage cp gs://sixiang_gcp/mixtral-instruct-quantized ./ --recursive
+# NOTE: uncomment only what you need so you don't download too many weights to your box
+# gcloud storage cp gs://sixiang_gcp/llama2-70b/llama2-70b/ ./ --recursive
+
+# Get mixtral data
+wget https://inference.mlcommons-storage.org/mixtral_8x7b%2F2024.06.06_mixtral_15k_v4.pkl
+mv mixtral_8x7b%2F2024.06.06_mixtral_15k_v4.pkl mixtral_15k_data.pkl
+wget https://inference.mlcommons-storage.org/mixtral_8x7b%2F2024.06.06_mixtral_15k_calibration_v4.pkl
+mv mixtral_8x7b%2F2024.06.06_mixtral_15k_calibration_v4.pkl mixtral_15k_calibration_data.pkl
+
+# Get llama70b data
+gcloud storage cp \
+  gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.calibration_1000.pkl \
+  processed-calibration-data.pkl
+gcloud storage cp \
+  gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.sampled_24576.pkl \
+  processed-data.pkl
+popd
diff --git a/mlperf/mixtral_run.sh b/mlperf/mixtral_run.sh
index f2304ca..7750492 100755
--- a/mlperf/mixtral_run.sh
+++ b/mlperf/mixtral_run.sh
@@ -4,7 +4,7 @@ me=$(basename "$0")
 BASEDIR=mlperf
 USER_CONFIG=$BASEDIR/user.conf
 DATA_DISK_DIR=$BASEDIR/data
-TOTAL_SAMPLE_COUNT=1000
+TOTAL_SAMPLE_COUNT=900
 
 # HF model id
 TOKENIZER_PATH="mistralai/Mixtral-8x7B-Instruct-v0.1"
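
After install.sh finishes, a quick sanity check of the downloaded pickle (illustrative; the path matches DATASET_PATH in the run scripts, and the columns are the ones evaluate_accuracy.py relies on):

import pandas as pd

data = pd.read_pickle("mlperf/data/mixtral_15k_data.pkl")
print(len(data))                        # roughly 15k rows, per the dataset name
print(data["dataset"].value_counts())   # GSM8K / OpenOrca / MBXP mix
print(data[["id", "input", "gt_output"]].head(1))
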
diff --git a/mlperf/mlperf.conf b/mlperf/mlperf.conf
new file mode 100644
index 0000000..e9ae205
--- /dev/null
+++ b/mlperf/mlperf.conf
@@ -0,0 +1,98 @@
+# The format of this config file is 'key = value'.
+# The key has the format 'model.scenario.key'. Value is mostly int64_t.
+# Model may be '*' as a wildcard. In that case the value applies to all models.
+# All times are in milliseconds.
+
+# Set performance_sample_count for each model.
+# User can optionally set this to higher values in user.conf.
+resnet50.*.performance_sample_count_override = 1024
+ssd-mobilenet.*.performance_sample_count_override = 256
+retinanet.*.performance_sample_count_override = 64
+bert.*.performance_sample_count_override = 10833
+dlrm.*.performance_sample_count_override = 204800
+dlrm-v2.*.performance_sample_count_override = 204800
+rnnt.*.performance_sample_count_override = 2513
+gptj.*.performance_sample_count_override = 13368
+llama2-70b.*.performance_sample_count_override = 24576
+stable-diffusion-xl.*.performance_sample_count_override = 5000
+# set to 0 to let the entire sample set be the performance sample set
+3d-unet.*.performance_sample_count_override = 0
+
+# Set seeds. The seeds will be distributed two weeks before the submission.
+*.*.qsl_rng_seed = 3066443479025735752
+*.*.sample_index_rng_seed = 10688027786191513374
+*.*.schedule_rng_seed = 14962580496156340209
+# Set seeds for TEST_05. The seeds will be distributed two weeks before the submission.
+*.*.test05_qsl_rng_seed = 16799458546791641818
+*.*.test05_sample_index_rng_seed = 5453809927556429288
+*.*.test05_schedule_rng_seed = 5435552105434836064
+
+
+*.SingleStream.target_latency_percentile = 90
+*.SingleStream.min_duration = 600000
+
+*.MultiStream.target_latency_percentile = 99
+*.MultiStream.samples_per_query = 8
+*.MultiStream.min_duration = 600000
+*.MultiStream.min_query_count = 662
+retinanet.MultiStream.target_latency = 528
+
+# 3D-UNet uses equal issue mode because it has non-uniform inputs
+3d-unet.*.sample_concatenate_permutation = 1
+
+# LLM benchmarks have non-uniform inputs and outputs, and use equal issue mode for all latency scenarios
+gptj.*.sample_concatenate_permutation = 1
+llama2-70b.*.sample_concatenate_permutation = 1
+mixtral-8x7b.*.sample_concatenate_permutation = 1
+
+*.Server.target_latency = 10
+*.Server.target_latency_percentile = 99
+*.Server.target_duration = 0
+*.Server.min_duration = 600000
+resnet50.Server.target_latency = 15
+retinanet.Server.target_latency = 100
+bert.Server.target_latency = 130
+dlrm.Server.target_latency = 60
+dlrm-v2.Server.target_latency = 60
+rnnt.Server.target_latency = 1000
+gptj.Server.target_latency = 20000
+stable-diffusion-xl.Server.target_latency = 20000
+# Llama2-70b benchmarks measure token latencies
+llama2-70b.*.use_token_latencies = 1
+mixtral-8x7b.*.use_token_latencies = 1
+# gptj benchmark infers token latencies
+gptj.*.infer_token_latencies = 1
+gptj.*.token_latency_scaling_factor = 69
+# Only ttft and tpot are tracked for the llama2-70b & mixtral-8x7b benchmarks, therefore target_latency = 0
+llama2-70b.Server.target_latency = 0
+llama2-70b.Server.ttft_latency = 2000
+llama2-70b.Server.tpot_latency = 200
+
+mixtral-8x7b.Server.target_latency = 0
+mixtral-8x7b.Server.ttft_latency = 2000
+mixtral-8x7b.Server.tpot_latency = 200
+
+*.Offline.target_latency_percentile = 90
+*.Offline.min_duration = 600000
+
+# In the Offline scenario, we always have one query, but LoadGen maps this to
+# min_sample_count internally. If the dataset size is larger than 24576,
+# we limit min_query_count to 24576; otherwise we use the dataset size
+# as the limit.
+
+resnet50.Offline.min_query_count = 24576
+retinanet.Offline.min_query_count = 24576
+dlrm-v2.Offline.min_query_count = 24576
+bert.Offline.min_query_count = 10833
+gptj.Offline.min_query_count = 13368
+rnnt.Offline.min_query_count = 2513
+3d-unet.Offline.min_query_count = 43
+stable-diffusion-xl.Offline.min_query_count = 5000
+llama2-70b.Offline.min_query_count = 1000
+mixtral-8x7b.Offline.min_query_count = 15000
+
+# These fields should be defined and overridden by user.conf.
+*.SingleStream.target_latency = 10
+*.MultiStream.target_latency = 80
+*.Server.target_qps = 1.0
+*.Offline.target_qps = 4.0
diff --git a/mlperf/offline_mode.py b/mlperf/offline_mode.py
index b71eebf..8711f85 100644
--- a/mlperf/offline_mode.py
+++ b/mlperf/offline_mode.py
@@ -383,7 +383,7 @@ def main(argv):
         "Accuracy run will generate the accuracy logs, but the evaluation of the log is not completed yet"
     )
   elif FLAGS.mlperf_test_mode == "submission":
-    settings.mode = lg.TestMode.Submission
+    settings.mode = lg.TestMode.SubmissionRun
     settings.print_timestamps = True
   else:
     settings.mode = lg.TestMode.PerformanceOnly
diff --git a/mlperf/user.conf b/mlperf/user.conf
new file mode 100644
index 0000000..95ef75e
--- /dev/null
+++ b/mlperf/user.conf
@@ -0,0 +1,6 @@
+mixtral-8x7b.Server.target_qps = 2.0
+mixtral-8x7b.Offline.target_qps = 100.0
+
+# send unique queries
+mixtral-8x7b.Offline.performance_issue_unique = 1
+
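
For reference, a rough sketch (not part of this patch) of how LoadGen consumes the two config files above; offline_mode.py wires them in via its --mlperf_conf and --user_conf flags, so the exact plumbing there may differ:

import mlperf_loadgen as lg

settings = lg.TestSettings()
settings.scenario = lg.TestScenario.Offline
# Benchmark-wide defaults first, then per-submitter overrides from user.conf.
settings.FromConfig("mlperf/mlperf.conf", "mixtral-8x7b", "Offline")
settings.FromConfig("mlperf/user.conf", "mixtral-8x7b", "Offline")
settings.mode = lg.TestMode.AccuracyOnly  # or PerformanceOnly / SubmissionRun, as in offline_mode.py
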