Return evaluation results to callers #71

Merged · 2 commits · Oct 26, 2023
13 changes: 9 additions & 4 deletions dalm/eval/eval_rag.py
@@ -11,7 +11,9 @@
     PreTrainedTokenizer,
 )
 
+from dalm.eval.eval_results import EvalResults
 from dalm.eval.utils import (
+    calc_eval_results,
     construct_search_index,
     evaluate_retriever_on_batch,
     get_passage_embeddings,
@@ -180,7 +182,7 @@ def evaluate_rag(
     top_k: int = 10,
     evaluate_generator: bool = True,
     retriever_is_autoregressive: bool = False,
-) -> None:
+) -> EvalResults:
     """Runs rag evaluation. See `dalm eval-rag --help for details on params"""
     test_dataset = load_dataset(dataset_or_path)
     selected_torch_dtype: Final[torch.dtype] = torch.float16 if torch_dtype == "float16" else torch.bfloat16
@@ -254,8 +256,9 @@ def evaluate_rag(
             generated_answers_for_eval.extend(batch_answers)
 
     if not evaluate_generator:
-        print_eval_results(len(processed_datasets), batch_precision, batch_recall, total_hit)
-        return
+        eval_results = calc_eval_results(len(processed_datasets), batch_precision, batch_recall, total_hit)
+        print_eval_results(eval_results)
+        return eval_results
 
     # TODO: imperative style code, refactor in future but works for now
     # If there are any leftover batches to query
@@ -275,9 +278,11 @@
             if generated_answer_string == answer:
                 total_em_hit += 1
 
-    print_eval_results(len(processed_datasets), batch_precision, batch_recall, total_hit)
+    eval_results = calc_eval_results(len(processed_datasets), batch_precision, batch_recall, total_hit)
+    print_eval_results(eval_results)
     print("Generator evaluation:")
     print("Exact match:", total_em_hit / len(processed_datasets))
+    return eval_results
 
 
 def main() -> None:
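With this change, `evaluate_rag` returns an `EvalResults` object instead of only printing metrics. A minimal hedged sketch of a downstream caller: the `eval_kwargs` placeholder and the 0.8 threshold are illustrative only, and the dataset/model arguments `evaluate_rag` requires are not shown in this diff.

```python
# Hedged sketch of a caller; `eval_kwargs` stands in for the dataset and
# model arguments evaluate_rag requires (not shown in this diff).
from dalm.eval.eval_rag import evaluate_rag
from dalm.eval.eval_results import EvalResults

eval_kwargs: dict = {}  # fill in dataset_or_path, model paths, top_k, ...

results: EvalResults = evaluate_rag(**eval_kwargs)

# The caller can now gate a pipeline on the returned metrics.
if results.hit_rate < 0.8:  # illustrative threshold only
    raise SystemExit(f"Hit rate too low: {results.hit_rate:.3f}")
```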
8 changes: 8 additions & 0 deletions dalm/eval/eval_results.py
@@ -0,0 +1,8 @@
+from pydantic import BaseModel
+
+
+class EvalResults(BaseModel):
+    total_examples: int
+    recall: float
+    precision: float
+    hit_rate: float
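For reference, the new model behaves like any pydantic v1 `BaseModel`. A brief sketch of constructing and serializing it; the field values below are made up, and real numbers come from `calc_eval_results`.

```python
from dalm.eval.eval_results import EvalResults

# Example values only; real numbers come from calc_eval_results.
results = EvalResults(total_examples=100, recall=0.91, precision=0.87, hit_rate=0.83)

# Pydantic v1 API (matching the pin added in pyproject.toml below).
print(results.dict())  # {'total_examples': 100, 'recall': 0.91, ...}
print(results.json())  # JSON string, convenient for logging or dashboards
```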
9 changes: 7 additions & 2 deletions dalm/eval/eval_retriever_only.py
@@ -10,6 +10,8 @@
 
 from datasets import Dataset
 from torch.utils.data import DataLoader
+from dalm.eval.eval_results import EvalResults
+
 
 from dalm.eval.utils import (
     construct_search_index,
@@ -18,6 +20,7 @@
     get_passage_embeddings,
     evaluate_retriever_on_batch,
     print_eval_results,
+    calc_eval_results,
 )
 from dalm.models.retriever_only_base_model import AutoModelForSentenceEmbedding
 from dalm.utils import load_dataset
@@ -112,7 +115,7 @@ def evaluate_retriever(
     torch_dtype: Literal["float16", "bfloat16"] = "float16",
     top_k: int = 10,
     is_autoregressive: bool = False,
-) -> None:
+) -> EvalResults:
     """Runs rag evaluation. See `dalm eval-retriever --help for details on params"""
     test_dataset = load_dataset(dataset_or_path)
     selected_torch_dtype: Final[torch.dtype] = torch.float16 if torch_dtype == "float16" else torch.bfloat16
@@ -170,7 +173,9 @@ def evaluate_retriever(
         batch_recall.extend(_batch_recall)
         total_hit += _total_hit
 
-    print_eval_results(len(processed_datasets), batch_precision, batch_recall, total_hit)
+    eval_results = calc_eval_results(len(processed_datasets), batch_precision, batch_recall, total_hit)
+    print_eval_results(eval_results)
+    return eval_results
 
 
 def main() -> None:
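Because `evaluate_retriever` and `evaluate_rag` now return the same `EvalResults` type, downstream code can treat both entry points uniformly. A small hedged example of such a shared consumer; the `log_metrics` helper and the `run_name` label are hypothetical, not part of the dalm API.

```python
from dalm.eval.eval_results import EvalResults

def log_metrics(results: EvalResults, run_name: str) -> None:
    # Hypothetical helper: formats the shared metrics from either entry point.
    print(
        f"[{run_name}] n={results.total_examples} "
        f"precision={results.precision:.3f} "
        f"recall={results.recall:.3f} "
        f"hit_rate={results.hit_rate:.3f}"
    )
```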
18 changes: 13 additions & 5 deletions dalm/eval/utils.py
@@ -10,6 +10,8 @@
 from tqdm.auto import tqdm
 from transformers import PreTrainedTokenizer, default_data_collator
 
+from dalm.eval.eval_results import EvalResults
+
 logger = logging.getLogger(__name__)
 
 
@@ -270,18 +272,24 @@ def evaluate_retriever_on_batch(
     return batch_precision, batch_recall, total_hit, top_passages
 
 
-def print_eval_results(
+def calc_eval_results(
     total_examples: int,
     precisions: list[float],
     recalls: list[float],
     total_hit: int,
-) -> None:
+) -> EvalResults:
     precision = sum(precisions) / total_examples
     recall = sum(recalls) / total_examples
     hit_rate = total_hit / float(total_examples)
 
+    return EvalResults(total_examples=total_examples, recall=recall, precision=precision, hit_rate=hit_rate)
+
+
+def print_eval_results(
+    eval_results: EvalResults,
+) -> None:
     logger.info("Retriever results:")
-    logger.info(f"Recall: {recall}")
-    logger.info(f"Precision: {precision}")
-    logger.info(f"Hit Rate: {hit_rate}")
+    logger.info(f"Recall: {eval_results.recall}")
+    logger.info(f"Precision: {eval_results.precision}")
+    logger.info(f"Hit Rate: {eval_results.hit_rate}")
     logger.info("*************")
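The refactor splits metric computation (`calc_eval_results`) from presentation (`print_eval_results`), which makes the computation easy to exercise in isolation. A small hand-checked sketch of the averaging it performs, using made-up per-example values:

```python
from dalm.eval.utils import calc_eval_results, print_eval_results

# Four examples with made-up per-example precision/recall and 3 retrieval hits.
results = calc_eval_results(
    total_examples=4,
    precisions=[1.0, 0.5, 0.0, 1.0],  # sums to 2.5 -> precision 0.625
    recalls=[1.0, 1.0, 0.0, 1.0],     # sums to 3.0 -> recall 0.75
    total_hit=3,                      # 3 / 4 -> hit_rate 0.75
)
assert results.precision == 0.625 and results.recall == 0.75 and results.hit_rate == 0.75
print_eval_results(results)  # logs Recall / Precision / Hit Rate at INFO level
```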
1 change: 1 addition & 0 deletions pyproject.toml
@@ -27,6 +27,7 @@ dependencies = [
     "diffusers",
     "bitsandbytes",
     "typer>=0.9.0,<1.0",
+    "pydantic==1.10.9", # Sync w/ other platform components
 ]
 
 [project.scripts]
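Since the pin targets pydantic 1.x, whose `.dict()`/`.json()` API differs from pydantic 2.x, callers may want a quick runtime check against an accidental upgrade. A minimal sketch, not part of this PR:

```python
# Sanity check that the installed pydantic matches the 1.x pin in pyproject.toml.
import pydantic

assert pydantic.VERSION.startswith("1."), (
    f"dalm expects pydantic 1.x (pinned to 1.10.9), found {pydantic.VERSION}"
)
```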