Skip to content

Commit

Permalink
add other scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi committed Jul 23, 2024
1 parent 9e7db0d commit 8c9ad0e
Show file tree
Hide file tree
Showing 7 changed files with 455 additions and 2 deletions.
56 changes: 56 additions & 0 deletions mlperf/accuracy_run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#!/usr/bin/env bash
# Runs the MLPerf accuracy evaluation for Mixtral-8x7B against a previously
# generated mlperf_log_accuracy.json (see the commented-out offline_mode
# invocation below for how that log is produced).
me=$(basename "$0")

BASEDIR=mlperf
API_URL=0.0.0.0:9000
USER_CONFIG=$BASEDIR/user.conf
DATA_DISK_DIR=$BASEDIR/data
TOTAL_SAMPLE_COUNT=1000
DATASET_PATH=$BASEDIR/data/mixtral_15k_data.pkl

# HF model id
TOKENIZER_PATH="mistralai/Mixtral-8x7B-Instruct-v0.1"
LOADGEN_RUN_TYPE=offline-performance
# NOTE(review): MODEL_NAME, DATASET_TYPE and LOADGEN_RUN_TIMESTAMP were never
# set in the original script, so the log dir name collapsed to
# "--offline-performance-". Default them here (overridable via the
# environment) — confirm the intended values.
MODEL_NAME=${MODEL_NAME:-mixtral}
DATASET_TYPE=${DATASET_TYPE:-full}
LOADGEN_RUN_TIMESTAMP=${LOADGEN_RUN_TIMESTAMP:-$(date +%Y%m%d%H%M%S)}
# Bug fix: OUTPUT_LOG_ID must be computed BEFORE OUTPUT_LOG_DIR references it;
# the original defined them in the opposite order and expanded an unset var.
OUTPUT_LOG_ID=${MODEL_NAME}-${DATASET_TYPE}-${LOADGEN_RUN_TYPE}-${LOADGEN_RUN_TIMESTAMP}
OUTPUT_LOG_DIR=${DATA_DISK_DIR}/logs/${OUTPUT_LOG_ID}

mkdir -p "${OUTPUT_LOG_DIR}" && cp "../${USER_CONFIG}" "${OUTPUT_LOG_DIR}"

OUTPUT_ACCURACY_JSON_PATH=${OUTPUT_LOG_DIR}/mlperf_log_accuracy.json

CACHE_LENGTH=1024
INPUT_SIZE=512
OUTPUT_SIZE=512
CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/

LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
# makes subsequent runs faster
export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache2"
export LIBTPU_INIT_ARGS

pushd ..
# python -m mlperf.offline_mode \
#   --model_name=mixtral \
#   --max_cache_length=$CACHE_LENGTH \
#   --max_decode_length=$OUTPUT_SIZE \
#   --context_length=$INPUT_SIZE \
#   --checkpoint_path=$CHECKPOINT_PATH/model.safetensors \
#   --tokenizer_path=$CHECKPOINT_PATH/tokenizer.model \
#   --quantize_weights=1 \
#   --quantize_type=int8_per_channel \
#   --quantize_kv_cache=1 \
#   --scenario Offline \
#   --input_mode tokenized \
#   --output_mode tokenized \
#   --mlperf_conf $BASEDIR/mlperf.conf \
#   --user_conf ${USER_CONFIG} \
#   --audit_conf no_audit \
#   --total_sample_count ${TOTAL_SAMPLE_COUNT} \
#   --dataset_path ${DATASET_PATH} \
#   --output_log_dir ${OUTPUT_LOG_DIR} 2>&1 | tee ${OUTPUT_LOG_DIR}/server_accuracy_log.log

python -m mlperf.evaluate_accuracy \
  --checkpoint-path "${TOKENIZER_PATH}" \
  --mlperf-accuracy-file "${OUTPUT_ACCURACY_JSON_PATH}" \
  --dataset-file "${DATASET_PATH}" 2>&1 | tee "${OUTPUT_LOG_DIR}/evaluate_offline_accuracy_log.log"
popd
252 changes: 252 additions & 0 deletions mlperf/evaluate_accuracy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
import argparse
from transformers import AutoTokenizer
import nltk
import evaluate
import numpy as np
import pandas as pd
import json
import re

import logging
logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger("evaluate_accuracy.py")

def get_args():
  """Declare the command-line interface and parse sys.argv.

  Returns:
    argparse.Namespace with checkpoint_path, mlperf_accuracy_file,
    dataset_file, n_workers, verbose, and dtype.
  """
  parser = argparse.ArgumentParser()
  # The three required path flags share the same shape; declare them together.
  required_flags = (
      ("--checkpoint-path", "Path to Mixtral-8x7b-Instruct checkpoint"),
      ("--mlperf-accuracy-file", "path to mlperf_log_accuracy.json"),
      ("--dataset-file", "path to processed validation dataset"),
  )
  for flag, help_text in required_flags:
    parser.add_argument(flag, required=True, help=help_text)
  parser.add_argument(
      "--n_workers",
      default=2,
      type=int,
      help="Number of workers used for the MBXP evaluation",
  )
  parser.add_argument("--verbose", action="store_true", help="verbose messages")
  parser.add_argument(
      "--dtype",
      default="int64",
      help="dtype of the accuracy log",
      choices=["int32", "int64", "float"],
  )
  return parser.parse_args()


def get_groundtruth(processed_dataset_file):
  """Load the pickled, preprocessed validation dataset.

  Args:
    processed_dataset_file: path to a pandas pickle file.

  Returns:
    The unpickled pandas object (DataFrame with the ground-truth columns).
  """
  return pd.read_pickle(processed_dataset_file)


# Functions for evaluating GSM8K
def find_numbers(x: str) -> list[str]:
  """Finds all numbers in a string."""
  # Numbers may be negative (leading hyphen), carry comma thousand
  # separators, and have a decimal point between digits.
  number_pattern = re.compile(
      r"-?[\d,]*\.?\d+",
      re.MULTILINE | re.DOTALL | re.IGNORECASE,
  )
  return number_pattern.findall(x)


def find_number(x: str, answer_delimiter: str = "The answer is") -> str:
  """Finds the most relevant number in a string.

  Prefers the first number after `answer_delimiter` when the model used
  that phrasing; otherwise falls back to the last number in the string.
  Returns "" when no number is present.
  """
  if answer_delimiter in x:
    after_delimiter = x.split(answer_delimiter)[-1]
    candidates = find_numbers(after_delimiter)
    if candidates:
      return candidates[0]

  # Fallback: the last number anywhere in the string.
  candidates = find_numbers(x)
  return candidates[-1] if candidates else ""


def maybe_remove_comma(x: str) -> str:
  """Drop thousand-separator commas. Example: 5,600 -> 5600."""
  return "".join(x.split(","))


def try_float(x: str):
  """Parse x as a float, returning None when it is not a valid number.

  Args:
    x: the string (or other float()-convertible value) to parse.

  Returns:
    float value of x, or None when conversion fails.
  """
  try:
    return float(x)
  except (TypeError, ValueError):
    # Narrowed from `except BaseException`, which would also have swallowed
    # KeyboardInterrupt/SystemExit. float() only raises these two for bad
    # input (ValueError for bad strings, TypeError for non-numeric objects).
    return None


# Functions for evaluating OpenOrca


def postprocess_text(preds, targets):
  """Normalize prediction and target text for rougeLSum scoring.

  Strips surrounding whitespace and joins sentences with newlines, since
  rougeLSum expects a newline after each sentence.
  """

  def _to_rouge_format(texts):
    return ["\n".join(nltk.sent_tokenize(text.strip())) for text in texts]

  return _to_rouge_format(preds), _to_rouge_format(targets)


# Functions for MBXP


def create_mbxp_dict(row, response):
  """Build an MBXP evaluation record from a dataset row and a model response.

  The row's "id" is "<lang>_<entry_point>"; split on the first underscore
  only, since entry point names may themselves contain underscores.
  """
  language, entry = row["id"].split("_", 1)
  return {
      "lang": language,
      "prompt": row["input"],
      "test_code": row["gt_output"],
      "entry_point": entry,
      "response": response,
  }


def main():
  """Score an mlperf accuracy log against the mixed ground-truth dataset.

  Loads mlperf_log_accuracy.json, routes each prediction to the metric for
  its benchmark slice — GSM8K (exact-match on the extracted final number),
  OpenOrca (ROUGE), MBXP (code eval, currently disabled) — and prints a
  combined result dict.
  """
  args = get_args()
  checkpoint_path = args.checkpoint_path
  metric = evaluate.load("rouge")
  nltk.download("punkt")

  tokenizer = AutoTokenizer.from_pretrained(
      checkpoint_path,
      model_max_length=2048,
      padding_side="left",
      use_fast=False,
  )

  data = get_groundtruth(args.dataset_file)
  # "dataset" tags each sample with its benchmark slice; "gt_output" holds
  # the reference answer for that sample.
  query_types, gt_outputs = data["dataset"], data["gt_output"]

  target_required_GSM8K = []
  target_required_OpenOrca = []
  results_MBXP = []
  preds_token_GSM8K = []
  preds_token_OpenOrca = []

  # The accuracy log stores generated token ids as a hex string; decode
  # with the dtype the harness was run with.
  eval_dtype = np.int64
  if args.dtype == "int32":
    eval_dtype = np.int32
  elif args.dtype == "float":
    eval_dtype = np.float32

  with open(args.mlperf_accuracy_file, "r") as f:
    results = json.load(f)

  seen = set()
  gen_tok_len = 0
  gen_num = 0
  for pred in results:
    gen_num += 1
    qsl_idx = pred["qsl_idx"]
    # LoadGen may log the same sample more than once; score each qsl_idx once.
    if qsl_idx in seen:
      continue
    seen.add(qsl_idx)

    # Hex-decode once for all three branches (deduplicated from the
    # original, which repeated it per branch).
    pred_tokens = np.frombuffer(bytes.fromhex(pred["data"]), eval_dtype)
    gen_tok_len += len(pred_tokens)

    query_type = query_types.iloc[qsl_idx]
    if query_type == "GSM8K":
      target_required_GSM8K.append(gt_outputs.iloc[qsl_idx])
      preds_token_GSM8K.append(pred_tokens)
    elif query_type == "OpenOrca":
      target_required_OpenOrca.append(gt_outputs.iloc[qsl_idx])
      preds_token_OpenOrca.append(pred_tokens)
    else:
      # Everything else is treated as an MBXP coding sample.
      pred_str = tokenizer.decode(pred_tokens, skip_special_tokens=True)
      results_MBXP.append(create_mbxp_dict(data.iloc[qsl_idx], pred_str))

  # OpenOrca metric: ROUGE over decoded text.
  preds_decoded_text = tokenizer.batch_decode(
      preds_token_OpenOrca, skip_special_tokens=True
  )

  preds, targets = postprocess_text(
      preds_decoded_text, target_required_OpenOrca
  )

  if preds:
    result = metric.compute(
        predictions=preds,
        references=targets,
        use_stemmer=True,
        use_aggregator=False,
    )
    result = {k: round(np.mean(v) * 100, 4) for k, v in result.items()}
    prediction_lens = [len(pred) for pred in preds]
  else:
    result = {}
    prediction_lens = []

  # GSM8K metric: exact match on the final extracted number. The model's
  # answer is cut at "\nQ:" in case it started generating a new question.
  # (A leftover `ipdb.set_trace()` debugging breakpoint was removed here.)
  preds_decoded_text = tokenizer.batch_decode(
      preds_token_GSM8K, skip_special_tokens=True
  )
  pred_nums = [
      maybe_remove_comma(find_number(pred_text.split("\nQ:")[0]))
      for pred_text in preds_decoded_text
  ]
  gsm8k_total = len(target_required_GSM8K)
  correct = 0
  for target_str, pred_num in zip(target_required_GSM8K, pred_nums):
    parsed_pred = try_float(pred_num)
    if parsed_pred is None:
      continue
    correct += try_float(target_str) == parsed_pred

  # Guard against ZeroDivisionError when the log has no GSM8K samples.
  result["gsm8k"] = 100.0 * correct / gsm8k_total if gsm8k_total else 0.0

  # MBXP metric
  # from evaluate_mbxp import evaluate_mbxp

  # if results_MBXP:
  #   result['mbxp'] = evaluate_mbxp(results_MBXP, args.n_workers)
  # else:
  #   result['mbxp'] = 0

  result = {
      **result,
      "gen_len": np.sum(prediction_lens),
      "gen_num": gen_num,
      "gen_tok_len": gen_tok_len,
      # Guard against an empty accuracy log (gen_num == 0).
      "tokens_per_sample": round(gen_tok_len / gen_num, 1) if gen_num else 0.0,
  }

  print("\nResults\n")
  print(result)


# Script entry point: run the accuracy evaluation when invoked directly.
if __name__ == "__main__":
  main()
41 changes: 41 additions & 0 deletions mlperf/install.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/usr/bin/env bash
# Installs the Python dependencies for the MLPerf evaluation and downloads
# model weights plus the mixtral / llama70b datasets into $DATA_DISK_DIR.
# Fail fast on any download or install error instead of continuing with a
# partially-populated data dir (the original had no error handling).
set -euo pipefail

DATA_DISK_DIR=data

mkdir -p "$DATA_DISK_DIR"

pip install -U "huggingface_hub[cli]"
pip install \
  transformers \
  nltk==3.8.1 \
  evaluate==0.4.0 \
  absl-py==1.4.0 \
  rouge-score==0.1.2 \
  sentencepiece==0.1.99 \
  accelerate==0.21.0

# install loadgen
pip install mlperf-loadgen


pushd "$DATA_DISK_DIR"

# model weights
gcloud storage cp gs://sixiang_gcp/mixtral-instruct-quantized ./ --recursive
# NOTE: uncomment one so you don't download too much weights to your box
# gcloud storage cp gs://sixiang_gcp/llama2-70b/llama2-70b/ ./ --recursive

# Get mixtral data (the %2F in the object name is a URL-encoded "/";
# rename to a stable local filename expected by the run scripts)
wget https://inference.mlcommons-storage.org/mixtral_8x7b%2F2024.06.06_mixtral_15k_v4.pkl
mv mixtral_8x7b%2F2024.06.06_mixtral_15k_v4.pkl mixtral_15k_data.pkl
wget https://inference.mlcommons-storage.org/mixtral_8x7b%2F2024.06.06_mixtral_15k_calibration_v4.pkl
mv mixtral_8x7b%2F2024.06.06_mixtral_15k_calibration_v4.pkl mixtral_15k_calibration_data.pkl

# Get llama70b data
gcloud storage cp \
  gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.calibration_1000.pkl \
  processed-calibration-data.pkl
gcloud storage cp \
  gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.sampled_24576.pkl \
  processed-data.pkl
popd
2 changes: 1 addition & 1 deletion mlperf/mixtral_run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ me=$(basename "$0")
BASEDIR=mlperf
USER_CONFIG=$BASEDIR/user.conf
DATA_DISK_DIR=$BASEDIR/data
TOTAL_SAMPLE_COUNT=1000
TOTAL_SAMPLE_COUNT=900

# HF model id
TOKENIZER_PATH="mistralai/Mixtral-8x7B-Instruct-v0.1"
Expand Down
Loading

0 comments on commit 8c9ad0e

Please sign in to comment.