MMLU Benchmarks #1163

Open · wants to merge 6 commits into base: main
91 changes: 91 additions & 0 deletions MaxText/benchmarks/mmlu/mmlu_categories.py
@@ -0,0 +1,91 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Dictionary mapping each subject to its respective subcategories.
# Subcategories represent more specific topics which are later grouped into broader categories.

subcategories = {
    # STEM Subjects
    "abstract_algebra": ["math"],
    "college_mathematics": ["math"],
    "elementary_mathematics": ["math"],
    "high_school_mathematics": ["math"],
    "high_school_statistics": ["math"],
    "college_chemistry": ["chemistry"],
    "high_school_chemistry": ["chemistry"],
    "college_physics": ["physics"],
    "astronomy": ["physics"],
    "conceptual_physics": ["physics"],
    "high_school_physics": ["physics"],
    "electrical_engineering": ["engineering"],
    "college_biology": ["biology"],
    "high_school_biology": ["biology"],
    "college_computer_science": ["computer science"],
    "high_school_computer_science": ["computer science"],
    "computer_security": ["computer science"],
    "machine_learning": ["computer science"],
    # Humanities
    "formal_logic": ["philosophy"],
    "logical_fallacies": ["philosophy"],
    "moral_disputes": ["philosophy"],
    "moral_scenarios": ["philosophy"],
    "philosophy": ["philosophy"],
    "world_religions": ["philosophy"],
    "high_school_european_history": ["history"],
    "high_school_us_history": ["history"],
    "high_school_world_history": ["history"],
    "prehistory": ["history"],
    "international_law": ["law"],
    "jurisprudence": ["law"],
    "professional_law": ["law"],
    # Social Sciences
    "econometrics": ["economics"],
    "high_school_macroeconomics": ["economics"],
    "high_school_microeconomics": ["economics"],
    "public_relations": ["politics"],
    "security_studies": ["politics"],
    "us_foreign_policy": ["politics"],
    "high_school_government_and_politics": ["politics"],
    "high_school_psychology": ["psychology"],
    "professional_psychology": ["psychology"],
    "sociology": ["culture"],
    "human_sexuality": ["culture"],
    "high_school_geography": ["geography"],
    # Other (Business, Health, Miscellaneous)
    "business_ethics": ["business"],
    "management": ["business"],
    "marketing": ["business"],
    "professional_accounting": ["other"],
    "global_facts": ["other"],
    "miscellaneous": ["other"],
    "anatomy": ["health"],
    "clinical_knowledge": ["health"],
    "college_medicine": ["health"],
    "human_aging": ["health"],
    "medical_genetics": ["health"],
    "nutrition": ["health"],
    "professional_medicine": ["health"],
    "virology": ["health"],
}

# Dictionary mapping broader categories to their corresponding subcategories.
# Categories help in grouping subcategories into broader disciplines for analysis.

categories = {
    "STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering"],
    "humanities": ["history", "philosophy", "law"],
    "social sciences": ["politics", "culture", "economics", "geography", "psychology"],
    "other (business, health, misc.)": ["other", "business", "health"],
}
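
# Example (illustrative, not part of the original file): resolving a subject to its broad
# category by composing the two tables above.
#   subcategories["astronomy"]                                      # -> ["physics"]
#   [cat for cat, subs in categories.items() if "physics" in subs]  # -> ["STEM"]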
210 changes: 210 additions & 0 deletions MaxText/benchmarks/mmlu/mmlu_eval.py
@@ -0,0 +1,210 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This is a simple script for MMLU benchmark for a trained checkpoint.

To get optimal performace the prompt template needs to be adjusted (e.g. CoT or 5-shot prompt) per model.


To run the MMLU benchmark:
python3 MaxText/benchmarks/mmlu/mmlu_eval.py MaxText/configs/base.yml \
tokenizer_path=assets/tokenizer_llama3.tiktoken \
load_parameters_path=check_point_path model_name=llama3.1-8b \
max_prefill_predict_length=1024 max_target_length=2048 ici_tensor_parallelism=4 per_device_batch_size=1
"""

import collections
import re
import sys

from absl import flags
import datasets
import jax
import max_logging
import max_utils
import maxengine
from mmlu_categories import categories
from mmlu_categories import subcategories
import pyconfig
from tqdm import tqdm


ASCII_UPPERCASE_A = ord("A") # ASCII value for uppercase 'A'

DEFAULT_PROMPT_TEMPLATE = """The following are multiple choice questions (with answers) about {subject}.

{question}
{choices}
Answer:"""


_PROMPT_TEMPLATE = flags.DEFINE_string(
    "prompt_template",
    default=DEFAULT_PROMPT_TEMPLATE,
    help="prompt template",
)


def construct_prompt(subject, question, choices):
  subject = subject.replace("_", " ")
  choices_text = "\n".join(f"{chr(ASCII_UPPERCASE_A + idx)}. {choice}" for idx, choice in enumerate(choices))
  prompt = _PROMPT_TEMPLATE.value.format(subject=subject, question=question, choices=choices_text)
  return prompt
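

# For illustration (assumed example, using the default template above):
#   construct_prompt("college_physics", "What is the SI unit of force?", ["Joule", "Newton", "Watt", "Pascal"])
# returns:
#   The following are multiple choice questions (with answers) about college physics.
#
#   What is the SI unit of force?
#   A. Joule
#   B. Newton
#   C. Watt
#   D. Pascal
#   Answer: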


def parse_answer(output):
  match = re.search(r"Answer:\s*([A-D])|(?:The answer is)\s*([A-D])", output, re.IGNORECASE)
  # Normalize to uppercase so case-insensitive matches (e.g. "answer: c") compare correctly
  # against the uppercase answer letters used below.
  predicted_answer = (match.group(1) or match.group(2)).upper() if match else None
  return predicted_answer
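

# Illustrative behaviour (assumed), given the regex above:
#   parse_answer("... Answer: C")         -> "C"
#   parse_answer("... The answer is b.")  -> "B"
#   parse_answer("I am not sure.")        -> None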


def main(config):
  engine = maxengine.MaxEngine(config)
  params = engine.load_params()

  metadata = engine.get_tokenizer()
  tokenizer = engine.build_tokenizer(metadata)

  max_prefill_predict_length = getattr(config, "max_prefill_predict_length", 1024)
  max_target_length = getattr(config, "max_target_length", 2048)

  # Initialize counters for overall and per-subject accuracies
  correct_count = 0
  total_count = 0
  subject_correct = collections.defaultdict(int)
  subject_total = collections.defaultdict(int)
  subcat_correct = collections.defaultdict(int)
  subcat_total = collections.defaultdict(int)

  mmlu_test_ds = datasets.load_dataset("lighteval/mmlu", "all", split="test")
  for idx, example in enumerate(tqdm(mmlu_test_ds, desc="Evaluating MMLU dataset")):
    subject = example["subject"]
    question = example["question"]
    choices = example["choices"]
    label = example["answer"]
    prompt = construct_prompt(subject, question, choices)

    # Tokenize the input
    tokens, true_length = tokenizer.encode(prompt, is_bos=True, prefill_lengths=[max_prefill_predict_length])
    if true_length > max_prefill_predict_length:
      max_logging.log(
          f"Warning: Prompt length {true_length} exceeds max prefill length {max_prefill_predict_length}. Truncating."
      )
      tokens = tokens[:max_prefill_predict_length]
      true_length = max_prefill_predict_length
    assert config.quantization != "fp8", "fp8 on NVIDIA GPUs is not supported in decode.py yet"

    # Perform prefill
    prefill_result, first_token = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
    slot = 0

    # Initialize decode state
    decode_state = engine.init_decode_state()
    decode_state = engine.insert(prefill_result, decode_state, slot=slot)

    steps = range(max_prefill_predict_length, max_target_length)
    sampled_tokens = [first_token.get_result_at_slot(slot).tokens.item()]

    predicted_answer = ""

    for _ in steps:
      # Decode generated tokens so far
      output = tokenizer.decode(sampled_tokens)
      predicted_answer = parse_answer(prompt + output)
      if predicted_answer:
        break

      # Generate next token
      decode_state, sampled_token = engine.generate(params, decode_state)
      sampled_tokens.append(sampled_token.get_result_at_slot(slot).tokens.item())
      if sampled_tokens[-1] == tokenizer.eos_id:
        break

    if not predicted_answer:
      max_logging.log(f"Could not extract an answer from the model's output for example {total_count + 1}")
    elif predicted_answer not in {chr(ASCII_UPPERCASE_A + idx) for idx in range(len(choices))}:
      max_logging.log(f"Invalid or missing predicted answer for subject '{subject}' in example {total_count + 1}")

    # Convert the label index to the corresponding letter
    correct_answer = chr(ASCII_UPPERCASE_A + label)

    # Update accuracy for overall and per-subject
    if predicted_answer == correct_answer:
      correct_count += 1
      subject_correct[subject] += 1
    total_count += 1
    subject_total[subject] += 1

    if idx % 50 == 0:
      max_logging.log(f" Accuracy: {correct_count / total_count:.4f}")

  # Final accuracy
  if total_count > 0:
    accuracy = correct_count / total_count
    max_logging.log(f"\nFinal accuracy on MMLU dataset: {accuracy:.4f}")
  else:
    max_logging.log("No valid predictions were made.")

  # Calculate subject accuracies
  subject_acc = {subject: subject_correct[subject] / subject_total[subject] for subject in subject_total}

  # Map subject accuracies to subcategories
  for subject in subject_acc:
    if subject in subcategories:
      subcat_labels = subcategories[subject]
      for subcat_label in subcat_labels:
        subcat_correct[subcat_label] += subject_correct[subject]
        subcat_total[subcat_label] += subject_total[subject]
    else:
      max_logging.log(f"Warning: Subject '{subject}' not found in subcategories.")

  # Calculate subcategory accuracies
  max_logging.log("\nSubcategory Accuracies:")
  for subcat_label in subcat_total:
    acc = subcat_correct[subcat_label] / subcat_total[subcat_label]
    max_logging.log(f"Accuracy for subcategory '{subcat_label}': {acc:.4f}")

  # Calculate and print category accuracies
  cat_correct = collections.defaultdict(int)
  cat_total = collections.defaultdict(int)

  for category_name, subcat_labels in categories.items():
    for subcat_label in subcat_labels:
      cat_correct[category_name] += subcat_correct[subcat_label]
      cat_total[category_name] += subcat_total[subcat_label]

  max_logging.log("\nCategory Accuracies:")
  for category_name in cat_total:
    if cat_total[category_name] > 0:
      acc = cat_correct[category_name] / cat_total[category_name]
      max_logging.log(f"Accuracy for category '{category_name}': {acc:.4f}")
    else:
      max_logging.log(f"Accuracy for category '{category_name}': No data available.")

def validate_config(config):
  assert not config.load_full_state_path, (
      "Decode doesn't operate on full states! Convert to a parameter checkpoint first"
      " using generate_param_only_checkpoint."
  )


if __name__ == "__main__":
  jax.config.update("jax_default_prng_impl", "unsafe_rbg")
  flags.FLAGS(sys.argv)
  pyconfig.initialize(sys.argv)
  cfg = pyconfig.config
  validate_config(cfg)
  max_utils.print_system_information()
  main(cfg)