-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
455 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
#!/usr/bin/env bash
# Scores an MLPerf offline accuracy log for Mixtral 8x7B with
# mlperf.evaluate_accuracy.
#
# All paths below start with $BASEDIR and are resolved from the parent
# directory, which the script enters with `pushd ..` before touching any of
# them.
me=$(basename "$0")

BASEDIR=mlperf
API_URL=0.0.0.0:9000
USER_CONFIG=$BASEDIR/user.conf
DATA_DISK_DIR=$BASEDIR/data
TOTAL_SAMPLE_COUNT=1000
DATASET_PATH=$BASEDIR/data/mixtral_15k_data.pkl

# HF model id
TOKENIZER_PATH="mistralai/Mixtral-8x7B-Instruct-v0.1"
LOADGEN_RUN_TYPE=offline-performance
# The log ID must be assigned before OUTPUT_LOG_DIR is built from it.
# MODEL_NAME, DATASET_TYPE and LOADGEN_RUN_TIMESTAMP are not set by this
# script; they are read from the environment and expand to empty strings when
# unset. An OUTPUT_LOG_ID already present in the environment takes precedence.
OUTPUT_LOG_ID=${OUTPUT_LOG_ID:-${MODEL_NAME}-${DATASET_TYPE}-${LOADGEN_RUN_TYPE}-${LOADGEN_RUN_TIMESTAMP}}
OUTPUT_LOG_DIR=${DATA_DISK_DIR}/logs/${OUTPUT_LOG_ID}

OUTPUT_ACCURACY_JSON_PATH=${OUTPUT_LOG_DIR}/mlperf_log_accuracy.json

CACHE_LENGTH=1024
INPUT_SIZE=512
OUTPUT_SIZE=512
CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/

LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
# makes subsequent runs faster
export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache2"
export LIBTPU_INIT_ARGS

# Stop here if the parent directory cannot be entered: every path below is
# relative to it.
pushd .. || exit 1

# Create the log directory after the directory change so it is the same
# directory the commands below read from and write to.
mkdir -p "${OUTPUT_LOG_DIR}" && cp "${USER_CONFIG}" "${OUTPUT_LOG_DIR}"

# python -m mlperf.offline_mode \
#   --model_name=mixtral \
#   --max_cache_length=$CACHE_LENGTH \
#   --max_decode_length=$OUTPUT_SIZE \
#   --context_length=$INPUT_SIZE \
#   --checkpoint_path=$CHECKPOINT_PATH/model.safetensors \
#   --tokenizer_path=$CHECKPOINT_PATH/tokenizer.model \
#   --quantize_weights=1 \
#   --quantize_type=int8_per_channel \
#   --quantize_kv_cache=1 \
#   --scenario Offline \
#   --input_mode tokenized \
#   --output_mode tokenized \
#   --mlperf_conf $BASEDIR/mlperf.conf \
#   --user_conf ${USER_CONFIG} \
#   --audit_conf no_audit \
#   --total_sample_count ${TOTAL_SAMPLE_COUNT} \
#   --dataset_path ${DATASET_PATH} \
#   --output_log_dir ${OUTPUT_LOG_DIR} 2>&1 | tee ${OUTPUT_LOG_DIR}/server_accuracy_log.log

python -m mlperf.evaluate_accuracy \
  --checkpoint-path "${TOKENIZER_PATH}" \
  --mlperf-accuracy-file "${OUTPUT_ACCURACY_JSON_PATH}" \
  --dataset-file "${DATASET_PATH}" 2>&1 | tee "${OUTPUT_LOG_DIR}/evaluate_offline_accuracy_log.log"
popd
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,252 @@ | ||
import argparse | ||
from transformers import AutoTokenizer | ||
import nltk | ||
import evaluate | ||
import numpy as np | ||
import pandas as pd | ||
import json | ||
import re | ||
|
||
import logging | ||
logging.basicConfig(level=logging.DEBUG) | ||
log = logging.getLogger("evaluate_accuracy.py") | ||
|
||
def get_args():
    """Parse the command-line options of the accuracy evaluation script.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``, with the
        attributes ``checkpoint_path``, ``mlperf_accuracy_file``,
        ``dataset_file``, ``n_workers``, ``verbose`` and ``dtype``.
    """
    # The three mandatory path options share the same shape, so they are
    # registered from a table (in the original declaration order).
    required_paths = (
        ("--checkpoint-path", "Path to Mixtral-8x7b-Instruct checkpoint"),
        ("--mlperf-accuracy-file", "path to mlperf_log_accuracy.json"),
        ("--dataset-file", "path to processed validation dataset"),
    )

    cli = argparse.ArgumentParser()
    for flag, description in required_paths:
        cli.add_argument(flag, required=True, help=description)

    cli.add_argument(
        "--n_workers",
        type=int,
        default=2,
        help="Number of workers used for the MBXP evaluation",
    )
    cli.add_argument("--verbose", action="store_true", help="verbose messages")
    cli.add_argument(
        "--dtype",
        choices=["int32", "int64", "float"],
        default="int64",
        help="dtype of the accuracy log",
    )
    return cli.parse_args()
|
||
|
||
def get_groundtruth(processed_dataset_file):
    """Load the processed validation dataset from a pickle file.

    Args:
        processed_dataset_file: Path handed to ``pandas.read_pickle``.

    Returns:
        Whatever object the pickle contains.

    NOTE: unpickling executes code embedded in the file, so only load
    dataset files from a trusted source.
    """
    return pd.read_pickle(processed_dataset_file)
|
||
|
||
# Functions for evaluating GSM8K | ||
def find_numbers(x: str) -> list[str]: | ||
"""Finds all numbers in a string.""" | ||
# Search for number, possibly negative (hyphen), with thousand separators | ||
# (comma), and with a decimal point (period inbetween digits). | ||
numbers = re.compile( | ||
r"-?[\d,]*\.?\d+", | ||
re.MULTILINE | re.DOTALL | re.IGNORECASE, | ||
).findall(x) | ||
return numbers | ||
|
||
|
||
def find_number(x: str, answer_delimiter: str = "The answer is") -> str:
    """Pick the number in ``x`` most likely to be the final answer.

    Args:
        x: Text to search.
        answer_delimiter: Phrase that introduces the answer.

    Returns:
        The first number after the last piece split off by
        ``answer_delimiter`` when the delimiter is present and a number
        follows it; otherwise the last number anywhere in ``x``; otherwise
        an empty string.
    """
    if answer_delimiter in x:
        after_delimiter = x.split(answer_delimiter)[-1]
        following = find_numbers(after_delimiter)
        if following:
            return following[0]

    # No usable delimiter: fall back to the last number in the whole text.
    candidates = find_numbers(x)
    return candidates[-1] if candidates else ""
|
||
|
||
def maybe_remove_comma(x: str) -> str:
    """Drop every comma from ``x`` (for example ``5,600`` becomes ``5600``)."""
    pieces = x.split(",")
    return "".join(pieces)
|
||
|
||
def try_float(x: str):
    """Convert ``x`` to ``float``, returning ``None`` when it cannot be parsed.

    Args:
        x: Value to convert, normally a string.

    Returns:
        ``float(x)``, or ``None`` if the conversion raises ``TypeError``,
        ``ValueError`` or ``OverflowError``.
    """
    try:
        return float(x)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures are swallowed; KeyboardInterrupt and
        # SystemExit must keep propagating.
        return None
|
||
|
||
# Functions for evaluating OpenOrca | ||
|
||
|
||
def postprocess_text(preds, targets):
    """Prepare predictions and references for rougeLSum scoring.

    Each text is stripped of surrounding whitespace and split into sentences
    with ``nltk.sent_tokenize``; the sentences are re-joined with newlines
    because rougeLSum expects a newline after each sentence.

    Returns:
        A ``(preds, targets)`` tuple of two new lists.
    """

    def _one_sentence_per_line(text):
        return "\n".join(nltk.sent_tokenize(text.strip()))

    cleaned_preds = [_one_sentence_per_line(text) for text in preds]
    cleaned_targets = [_one_sentence_per_line(text) for text in targets]
    return cleaned_preds, cleaned_targets
|
||
|
||
# Functions for MBXP | ||
|
||
|
||
def create_mbxp_dict(row, response):
    """Build the record describing one MBXP sample and its model response.

    Args:
        row: Mapping with the keys ``"id"``, ``"input"`` and ``"gt_output"``.
            ``row["id"]`` is split at its first underscore into the language
            and the entry point.
        response: The decoded model output for this sample.

    Returns:
        A dict with the keys ``lang``, ``prompt``, ``test_code``,
        ``entry_point`` and ``response``.
    """
    id_parts = row["id"].split("_", 1)
    lang, entry_point = id_parts

    record = {"lang": lang}
    record["prompt"] = row["input"]
    record["test_code"] = row["gt_output"]
    record["entry_point"] = entry_point
    record["response"] = response
    return record
|
||
|
||
def main():
    """Score an MLPerf accuracy log against the processed validation set.

    Reads the command-line options, loads the ground-truth dataset (columns
    ``dataset`` and ``gt_output`` are used) and the accuracy log, a JSON list
    whose entries carry ``qsl_idx`` and a hex-encoded token buffer in
    ``data``. OpenOrca samples are scored with ROUGE, GSM8K samples by exact
    numeric match of the extracted answer. MBXP responses are decoded and
    collected but not scored. The resulting metrics dict is printed.
    """
    args = get_args()
    metric = evaluate.load("rouge")
    nltk.download("punkt")

    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path,
        model_max_length=2048,
        padding_side="left",
        use_fast=False,
    )

    data = get_groundtruth(args.dataset_file)
    query_types, gt_outputs = data["dataset"], data["gt_output"]

    target_required_GSM8K = []
    target_required_OpenOrca = []
    results_MBXP = []
    preds_token_GSM8K = []
    preds_token_OpenOrca = []

    # numpy dtype used to reinterpret the hex-encoded token buffers.
    dtype_by_name = {"int32": np.int32, "int64": np.int64, "float": np.float32}
    eval_dtype = dtype_by_name.get(args.dtype, np.int64)

    with open(args.mlperf_accuracy_file, "r") as f:
        results = json.load(f)

    seen = set()
    gen_tok_len = 0
    gen_num = 0
    for entry in results:
        # gen_num counts every log entry, including repeated qsl_idx values;
        # only the first entry per qsl_idx is decoded and scored.
        gen_num += 1
        qsl_idx = entry["qsl_idx"]
        if qsl_idx in seen:
            continue
        seen.add(qsl_idx)

        query_type = query_types.iloc[qsl_idx]
        tokens = np.frombuffer(bytes.fromhex(entry["data"]), eval_dtype)
        gen_tok_len += len(tokens)

        if query_type == "GSM8K":
            target_required_GSM8K.append(gt_outputs.iloc[qsl_idx])
            preds_token_GSM8K.append(tokens)
        elif query_type == "OpenOrca":
            target_required_OpenOrca.append(gt_outputs.iloc[qsl_idx])
            preds_token_OpenOrca.append(tokens)
        else:
            # Every other query type is treated as MBXP.
            pred_str = tokenizer.decode(tokens, skip_special_tokens=True)
            results_MBXP.append(create_mbxp_dict(data.iloc[qsl_idx], pred_str))

    # OpenOrca metric
    preds_decoded_text = tokenizer.batch_decode(
        preds_token_OpenOrca, skip_special_tokens=True
    )
    preds, targets = postprocess_text(
        preds_decoded_text, target_required_OpenOrca
    )

    if preds:
        result = metric.compute(
            predictions=preds,
            references=targets,
            use_stemmer=True,
            use_aggregator=False,
        )
        result = {k: round(np.mean(v) * 100, 4) for k, v in result.items()}
        prediction_lens = [len(pred) for pred in preds]
    else:
        result = {}
        prediction_lens = []

    # GSM8K metric
    preds_decoded_text = tokenizer.batch_decode(
        preds_token_GSM8K, skip_special_tokens=True
    )
    # Only the text before the first "\nQ:" is searched for the answer.
    pred_nums = [
        maybe_remove_comma(find_number(pred_text.split("\nQ:")[0]))
        for pred_text in preds_decoded_text
    ]
    gsm8k_total = len(target_required_GSM8K)
    correct = 0
    for target, pred_num in zip(target_required_GSM8K, pred_nums):
        predicted = try_float(pred_num)
        if predicted is None:
            continue
        reference = try_float(target)
        correct += int(reference == predicted)

    # Report 0.0 instead of dividing by zero when the log has no GSM8K samples.
    result["gsm8k"] = 100.0 * correct / gsm8k_total if gsm8k_total else 0.0

    # MBXP metric
    # from evaluate_mbxp import evaluate_mbxp

    # if results_MBXP:
    #   result['mbxp'] = evaluate_mbxp(results_MBXP, args.n_workers)
    # else:
    #   result['mbxp'] = 0

    result = {
        **result,
        "gen_len": np.sum(prediction_lens),
        "gen_num": gen_num,
        "gen_tok_len": gen_tok_len,
        # An empty accuracy log yields 0.0 rather than a ZeroDivisionError.
        "tokens_per_sample": round(gen_tok_len / gen_num, 1) if gen_num else 0.0,
    }

    print("\nResults\n")
    print(result)
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
#!/usr/bin/env bash
# Installs the Python packages used by the MLPerf scripts and downloads the
# model weights and datasets into $DATA_DISK_DIR.

DATA_DISK_DIR=data

mkdir -p "$DATA_DISK_DIR"

pip install -U "huggingface_hub[cli]"
pip install \
  transformers \
  nltk==3.8.1 \
  evaluate==0.4.0 \
  absl-py==1.4.0 \
  rouge-score==0.1.2 \
  sentencepiece==0.1.99 \
  accelerate==0.21.0

# install loadgen
pip install mlperf-loadgen

# Stop if the data directory cannot be entered; otherwise everything below
# would be downloaded into the current directory instead.
pushd "$DATA_DISK_DIR" || exit 1

# model weights
gcloud storage cp gs://sixiang_gcp/mixtral-instruct-quantized ./ --recursive
# NOTE: uncomment only the weights you need, to avoid downloading more than
# necessary onto your machine
# gcloud storage cp gs://sixiang_gcp/llama2-70b/llama2-70b/ ./ --recursive

# Get mixtral data
# -O writes straight to the final name, so the result does not depend on how
# wget derives a local file name from the "%2F" in the URL.
wget -O mixtral_15k_data.pkl \
  https://inference.mlcommons-storage.org/mixtral_8x7b%2F2024.06.06_mixtral_15k_v4.pkl
wget -O mixtral_15k_calibration_data.pkl \
  https://inference.mlcommons-storage.org/mixtral_8x7b%2F2024.06.06_mixtral_15k_calibration_v4.pkl

# Get llama70b data
gcloud storage cp \
  gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.calibration_1000.pkl \
  processed-calibration-data.pkl
gcloud storage cp \
  gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.sampled_24576.pkl \
  processed-data.pkl
popd
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.