
Commit ee01072

Author: maxtext authors
Merge pull request #1059 from AI-Hypercomputer:patemotter_acc_eval
PiperOrigin-RevId: 700018264
2 parents ea717d9 + 7eca5f9 commit ee01072

2 files changed: +214 −14 lines changed
Lines changed: 196 additions & 0 deletions
@@ -0,0 +1,196 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import evaluate
import json
import nltk
import numpy as np
import pandas as pd
import tqdm

from multiprocessing import Pool, cpu_count
from functools import partial
from transformers import AutoTokenizer
from typing import List, Dict, Tuple


def split_data(preds: List[str], refs: List[str], num_chunks: int) -> List[Tuple[List[str], List[str]]]:
  """Split predictions and references into roughly equal chunks"""
  chunk_size = len(preds) // num_chunks + (1 if len(preds) % num_chunks else 0)
  chunks = []

  for i in range(0, len(preds), chunk_size):
    chunk_preds = preds[i : i + chunk_size]
    chunk_refs = refs[i : i + chunk_size]
    chunks.append((chunk_preds, chunk_refs))

  return chunks


def compute_rouge_chunk(chunk: Tuple[List[str], List[str]], metric) -> Dict:
  """Compute ROUGE scores for a chunk of data"""
  preds, refs = chunk
  return metric.compute(predictions=preds, references=refs, use_stemmer=True, use_aggregator=False)


def aggregate_rouge_scores(chunk_results: List[Dict]) -> Dict:
  """Aggregate ROUGE scores from chunks"""
  # Concatenate all scores
  all_scores = {}
  for scores in chunk_results:
    for metric, values in scores.items():
      if metric not in all_scores:
        all_scores[metric] = []
      all_scores[metric].extend(values)

  return {metric: round(np.mean(values) * 100, 4) for metric, values in all_scores.items()}


def get_args():
  parser = argparse.ArgumentParser()
  parser.add_argument("--checkpoint-path", required=True, help="Path to Llama2-70b-hf-chat checkpoint")
  parser.add_argument("--mlperf-accuracy-file", required=True, help="path to mlperf_log_accuracy.json")
  parser.add_argument("--dataset-file", required=True, help="path to processed openorca validation set")
  parser.add_argument("--verbose", action="store_true", help="verbose messages")
  parser.add_argument("--dtype", default="int64", help="dtype of the accuracy log", choices=["int32", "int64", "float"])
  parser.add_argument("--num-workers", type=int, default=None, help="Number of worker processes (default: CPU count)")
  args = parser.parse_args()
  return args


def get_groundtruth(processed_dataset_file):
  data = pd.read_pickle(processed_dataset_file)
  return data["output"]


def process_batch(batch, tokenizer, eval_dtype):
  """Process a batch of predictions"""
  preds_token_ids = []
  seen = set()
  gen_tok_len = 0
  target_indices = []

  for pred in batch:
    qsl_idx = pred["qsl_idx"]
    if qsl_idx in seen:
      continue

    seen.add(qsl_idx)
    target_indices.append(qsl_idx)

    pred_data = np.frombuffer(bytes.fromhex(pred["data"]), eval_dtype)
    if pred_data[0] > 32000 or pred_data[0] < 0:
      pred_data = np.concatenate([[1], pred_data[1:]])

    gen_tok_len += len(pred_data)
    preds_token_ids.append(pred_data)

  # Batch decode predictions
  preds_decoded = tokenizer.batch_decode(preds_token_ids, skip_special_tokens=True)
  return preds_decoded, target_indices, gen_tok_len


def postprocess_text(pred, target):
  """Process a single prediction-target pair"""
  pred = pred.strip()
  target = target.strip()

  # rougeLSum expects newline after each sentence
  pred = "\n".join(nltk.sent_tokenize(pred))
  target = "\n".join(nltk.sent_tokenize(target))

  return pred, target


def chunk_list(lst, n):
  """Split list into n roughly equal chunks"""
  chunk_size = len(lst) // n + (1 if len(lst) % n else 0)
  return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]


def main():
  args = get_args()
  num_workers = args.num_workers or cpu_count()
  print(f"Using {num_workers} worker processes")

  print(f"Loading checkpoint from {args.checkpoint_path}")
  tokenizer = AutoTokenizer.from_pretrained(
      args.checkpoint_path,
      model_max_length=2048,
      padding_side="left",
      use_fast=False,
  )

  metric = evaluate.load("rouge")
  nltk.download("punkt", quiet=True)

  print(f"Getting groundtruth from {args.dataset_file}")
  targets = get_groundtruth(args.dataset_file)

  eval_dtype = {"int32": np.int32, "int64": np.int64, "float": np.float32}[args.dtype]

  print(f"Loading accuracy log from {args.mlperf_accuracy_file}")
  with open(args.mlperf_accuracy_file, "r") as f:
    results = json.load(f)

  # Split results into chunks for parallel processing
  result_chunks = chunk_list(results, num_workers)

  # Process predictions in parallel
  process_func = partial(process_batch, tokenizer=tokenizer, eval_dtype=eval_dtype)
  total_gen_tok_len = 0
  all_preds = []
  all_target_indices = []

  print("Processing predictions...")
  with Pool(num_workers) as pool:
    for preds, target_indices, gen_tok_len in tqdm.tqdm(pool.imap(process_func, result_chunks), total=len(result_chunks)):
      all_preds.extend(preds)
      all_target_indices.extend(target_indices)
      total_gen_tok_len += gen_tok_len

  target_required = [targets[idx] for idx in all_target_indices]

  # Parallel postprocessing of texts
  print("Post-processing texts...")
  with Pool(num_workers) as pool:
    processed_pairs = list(tqdm.tqdm(pool.starmap(postprocess_text, zip(all_preds, target_required)), total=len(all_preds)))
  preds, refs = zip(*processed_pairs)

  # Split data into chunks for parallel ROUGE computation
  print("Computing ROUGE scores...")
  data_chunks = split_data(preds, refs, num_workers)
  with Pool(num_workers) as pool:
    chunk_results = list(
        tqdm.tqdm(pool.imap(partial(compute_rouge_chunk, metric=metric), data_chunks), total=len(data_chunks))
    )
  rouge_scores = aggregate_rouge_scores(chunk_results)

  prediction_lens = [len(pred) for pred in preds]
  gen_num = len(preds)
  result = {
      **rouge_scores,
      "gen_len": np.sum(prediction_lens),
      "gen_num": gen_num,
      "gen_tok_len": total_gen_tok_len,
      "tokens_per_sample": round(total_gen_tok_len / gen_num, 1),
  }

  print("\nResults\n")
  print(result)


if __name__ == "__main__":
  main()
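
The new evaluator takes the same command-line flags as the existing evaluate-accuracy.py invocation in the run script, plus --num-workers to control the process pool. A minimal invocation sketch, assuming the file is saved as evaluate-accuracy-fast.py (the name the updated run script below selects) and that the accuracy-log and dataset paths shown are placeholders:

# Hypothetical standalone run; paths are placeholders.
python3 evaluate-accuracy-fast.py \
  --checkpoint-path meta-llama/Llama-2-70b-chat-hf \
  --mlperf-accuracy-file /path/to/mlperf_log_accuracy.json \
  --dataset-file /path/to/processed_openorca_validation.pkl \
  --dtype int64 \
  --num-workers 16   # optional; defaults to the machine's CPU count

All flags come from get_args() above; --verbose is also accepted.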

MaxText/inference_mlperf/llama_offline_run.sh

Lines changed: 18 additions & 14 deletions
@@ -14,17 +14,18 @@ enable_profiler=false
 performance=true
 audit=false
 accuracy=false
+fast_eval=false
 
-
-while getopts "ntspdar:" opt
+while getopts "ntspdarfr:" opt
 do
 case "$opt" in
 n ) dry_run=true ;;
-t ) test_run=true ;;
+t ) test_run=true ;;
 s ) skip_warmup=true ;;
 p ) enable_profiler=true ;;
 d ) audit=true ;;
 a ) accuracy=true ;;
+f ) fast_eval=true ;;
 r ) run_name="$OPTARG" ;;
 ? ) helpFunction ;; # Print helpFunction in case parameter is non-existent
 esac
@@ -101,13 +102,11 @@ export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache2"
 export LIBTPU_INIT_ARGS
 
 run_loadgen() {
-
 OUTPUT_LOG_ID=llama70b-${run_name}-${DATASET_TYPE}-${LOADGEN_RUN_TYPE}-${LOADGEN_RUN_TYPE}_${LOADGEN_RUN_TIMESTAMP}
 OUTPUT_LOG_DIR=${DATA_DISK_DIR}/logs/${OUTPUT_LOG_ID}
 mkdir -p ${OUTPUT_LOG_DIR} && cp ${USER_CONFIG} ${OUTPUT_LOG_DIR}
 OUTPUT_ACCURACY_JSON_PATH=${OUTPUT_LOG_DIR}/mlperf_log_accuracy.json
 
-
 echo "LOADGEN_RUN_TIMESTAMP: ${LOADGEN_RUN_TIMESTAMP}"
 echo "DATASET_PATH: ${DATASET_PATH}"
 echo "TOTAL_SAMPLE_COUNT: ${TOTAL_SAMPLE_COUNT}"
@@ -118,19 +117,18 @@ run_loadgen() {
 
 ${cmd} python -m offline_mode \
   --mlperf_test_mode=${TEST_MODE} \
-  --input_mode tokenized \
+  --input_mode tokenized \
   --output_mode tokenized \
-  --mlperf_conf $BASEDIR/mlperf.conf \
-  --user_conf ${USER_CONFIG} \
-  --audit_conf ${AUDIT_CONF} \
-  --total_sample_count ${TOTAL_SAMPLE_COUNT} \
-  --dataset_path ${DATASET_PATH} \
+  --mlperf_conf $BASEDIR/mlperf.conf \
+  --user_conf ${USER_CONFIG} \
+  --audit_conf ${AUDIT_CONF} \
+  --total_sample_count ${TOTAL_SAMPLE_COUNT} \
+  --dataset_path ${DATASET_PATH} \
   --prefill_lengths_and_batch_sizes ${BATCH_AND_PREFILL_LEN} \
   --maxengine_args "${MAXENGINE_ARGS}" \
-  --output_log_dir ${OUTPUT_LOG_DIR} \
+  --output_log_dir ${OUTPUT_LOG_DIR} \
   --tok_outlen_multiplier ${TOK_OUTLEN_MULTIPLIER} \
   ${SKIP_WARMUP_OPTION} ${PROFILER_OPTION} 2>&1 | tee ${OUTPUT_LOG_DIR}/${LOADGEN_RUN_TYPE}_log.log
-
 }
 
 run_loadgen_performance () {
@@ -155,7 +153,13 @@ run_loadgen_accuracy () {
 
 # Eval Run
 if [ -e ${OUTPUT_ACCURACY_JSON_PATH} ]; then
-  ${CMD} python3 evaluate-accuracy.py \
+  if [ "${FAST_EVAL:-false}" = "true" ] || "$fast_eval"; then
+    EVAL_SCRIPT="evaluate-accuracy-fast.py"
+  else
+    EVAL_SCRIPT="evaluate-accuracy.py"
+  fi
+
+  ${CMD} python3 ${EVAL_SCRIPT} \
     --checkpoint-path meta-llama/Llama-2-70b-chat-hf \
     --mlperf-accuracy-file ${OUTPUT_ACCURACY_JSON_PATH} \
     --dataset-file ${DATASET_PATH} 2>&1 | tee ${OUTPUT_LOG_DIR}/evaluate_offline_accuracy_log.log
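
With these changes, an accuracy run can opt into the faster evaluator either by passing the new -f flag or by exporting FAST_EVAL=true. A minimal sketch of such a run, assuming the script is launched as before and that my_run is a placeholder run name:

# Accuracy run (-a) with the fast evaluator selected via -f; -r names the run.
bash llama_offline_run.sh -a -f -r my_run

# Equivalent selection through the FAST_EVAL environment variable checked in the script.
FAST_EVAL=true bash llama_offline_run.sh -a -r my_run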
