diff --git a/evaluation/evaluate_text_generation.py b/evaluation/evaluate_text_generation.py
index 55f164c2f..192a64fae 100644
--- a/evaluation/evaluate_text_generation.py
+++ b/evaluation/evaluate_text_generation.py
@@ -1,4 +1,5 @@
 import numpy as np
+import torch
 from datasets import load_dataset
 from sacrebleu import corpus_bleu
 from transformers import pipeline
@@ -11,31 +12,54 @@ def sacrebleu_score(hypotheses, references):
     return corpus_bleu(hypotheses, [references]).score
 
 
+def _process_data(dataset_name, split):
+    """Load the requested split and build a KeyValueDataset from the
+    source/summary columns expected for each supported dataset."""
+    if dataset_name == "xsum":
+        hf_dataset = load_dataset(dataset_name, "3.0.0", split=split)
+        dataset = KeyValueDataset.from_huggingface(
+            hf_dataset, TaskType.TEXT_TO_TEXT_GENERATION, ["document", "summary"]
+        )
+        return dataset
+    elif dataset_name == "cnn_dailymail":
+        hf_dataset = load_dataset(dataset_name, "3.0.0", split=split)
+        dataset = KeyValueDataset.from_huggingface(
+            hf_dataset, TaskType.TEXT_TO_TEXT_GENERATION, ["article", "highlights"]
+        )
+        return dataset
+    elif dataset_name == "big_patent":
+        hf_dataset = load_dataset(dataset_name, split=split)
+        dataset = KeyValueDataset.from_huggingface(
+            hf_dataset, TaskType.TEXT_TO_TEXT_GENERATION, ["description", "abstract"]
+        )
+        return dataset
+    elif dataset_name == "billsum":
+        hf_dataset = load_dataset(dataset_name, split=split)
+        dataset = KeyValueDataset.from_huggingface(
+            hf_dataset, TaskType.TEXT_TO_TEXT_GENERATION, ["text", "summary"]
+        )
+        return dataset
+    raise ValueError(f"Unsupported dataset: {dataset_name}")
+
+
 def evaluate(
-    operation, evaluate_filter, model_name, dataset_name, split="test[:20%]"
-):
+    operation, evaluate_filter, model_name, dataset_name,
+    split="test[:20%]", is_cuda=torch.cuda.is_available(),
+):
     # load model
     if model_name is None:
-        model_name = "sshleifer/distilbart-xsum-12-6"
+        model_name = "sshleifer/distilbart-xsum-12-6"  # default model
     # load test set
     if dataset_name is None:
-        dataset_name = "xsum"
+        dataset_name = "xsum"  # default dataset
     print(
         f"Loading <{dataset_name}> dataset to evaluate <{model_name}> model."
     )
-    hf_dataset = (
-        load_dataset(dataset_name, "3.0.0", split=split)
-        if dataset_name == "xsum"
-        else load_dataset(dataset_name, split=split)
-    )
-
-    dataset = KeyValueDataset.from_huggingface(
-        hf_dataset, TaskType.TEXT_TO_TEXT_GENERATION, ["document", "summary"]
-    )
 
     summarization_pipeline = pipeline(
-        "summarization", model=model_name, tokenizer=model_name
+        "summarization",
+        model=model_name,
+        tokenizer=model_name,
+        device=0 if is_cuda else -1,  # GPU 0 when available, otherwise CPU
     )
+    # percent = f"[{split.split('[')[-1]}" if "[" in split else ""
+    # if dataset_name == "wikihow": split = "all[:1%]"  # f"all{percent}"
+
+    dataset = _process_data(dataset_name, split)
 
     print(
         f"Here is the performance of the model {model_name} on the {split} split of the {dataset_name} dataset"
     )
@@ -55,20 +79,16 @@ def evaluate(
 
 
 def filter_performance(dataset, summarization_pipeline, filter):
+    """Evaluate performance on the filtered dataset."""
     print("Here is the performance of the model on the filtered set")
     filtered_dataset = dataset.apply_filter(filter, subfields=["document"])
     return performance_on_dataset(filtered_dataset, summarization_pipeline)
 
 
-"""
-Evaluates performance on the original set
-and on the perturbed set.
-""" - - def transformation_performance( dataset, summarization_pipeline, transformation ): + '''Evaluates performance on the original set and on the perturbed set.''' performance = performance_on_dataset( dataset, summarization_pipeline ) # 15.989 BLEU @@ -83,11 +103,13 @@ def transformation_performance( def performance_on_dataset(dataset, summarization_pipeline): + '''Evaluate performance on a given dataset.''' references = [] raw_hypotheses = [] print(f"Length of Evaluation dataset is {len(dataset)}") - for example in dataset: + for i,example in enumerate(dataset): + print(i) article, gold_summary = example max_len = ( len(gold_summary.split(" ")) + 10 diff --git a/evaluation/leaderboard_wrapper.py b/evaluation/leaderboard_wrapper.py index eeaa4def1..9a69d1562 100644 --- a/evaluation/leaderboard_wrapper.py +++ b/evaluation/leaderboard_wrapper.py @@ -26,6 +26,11 @@ ("roberta-large-mnli", "multi_nli"), ("textattack/roberta-base-imdb", "imdb"), ], + "TEXT_TO_TEXT_GENERATION": [ + ("mrm8488/bert-small2bert-small-finetuned-cnn_daily_mail-summarization", "cnn_dailymail"), + ("google/pegasus-billsum", "billsum"), + ("google/bigbird-pegasus-large-bigpatent", "big_patent"), + ], "TEXT_TAGGING": [], "DIALOGUE_ACT_TO_TEXT": [], "TABLE_TO_TEXT": [],