From fdc7cdf8d781f7b48be2702ec42652542b2ddfe8 Mon Sep 17 00:00:00 2001 From: "shengzhe.li" Date: Thu, 25 Jul 2024 03:47:26 +0900 Subject: [PATCH] Add prediction logging for classification and retrieval --- src/jmteb/evaluators/classification/data.py | 7 ++ .../evaluators/classification/evaluator.py | 17 ++++- src/jmteb/evaluators/retrieval/data.py | 26 ++++++++ src/jmteb/evaluators/retrieval/evaluator.py | 64 +++++++++++++++---- .../test_classification_evaluator.py | 19 ++++++ tests/evaluator/test_retrieval_evaluator.py | 19 ++++++ 6 files changed, 138 insertions(+), 14 deletions(-) diff --git a/src/jmteb/evaluators/classification/data.py b/src/jmteb/evaluators/classification/data.py index 5885471..ba5eb8d 100644 --- a/src/jmteb/evaluators/classification/data.py +++ b/src/jmteb/evaluators/classification/data.py @@ -13,6 +13,13 @@ class ClassificationInstance: label: int +@dataclass +class ClassificationPrediction: + text: str + label: int + prediction: int + + class ClassificationDataset(ABC): @abstractmethod def __len__(self): diff --git a/src/jmteb/evaluators/classification/evaluator.py b/src/jmteb/evaluators/classification/evaluator.py index dbe2d8e..457d949 100644 --- a/src/jmteb/evaluators/classification/evaluator.py +++ b/src/jmteb/evaluators/classification/evaluator.py @@ -11,7 +11,7 @@ from jmteb.evaluators.base import EmbeddingEvaluator, EvaluationResults from .classifiers import Classifier, KnnClassifier, LogRegClassifier -from .data import ClassificationDataset +from .data import ClassificationDataset, ClassificationPrediction class ClassificationEvaluator(EmbeddingEvaluator): @@ -28,6 +28,7 @@ class ClassificationEvaluator(EmbeddingEvaluator): The first one is specified as the main index. classifiers (dict[str, Classifier]): classifiers to be evaluated. prefix (str | None): prefix for sentences. Defaults to None. + log_predictions (bool): whether to log predictions of each datapoint. """ def __init__( @@ -38,6 +39,7 @@ def __init__( average: str = "macro", classifiers: dict[str, Classifier] | None = None, prefix: str | None = None, + log_predictions: bool = False, ) -> None: self.train_dataset = train_dataset self.val_dataset = val_dataset @@ -52,6 +54,7 @@ def __init__( if average_name.strip().lower() in ("micro", "macro", "samples", "weighted", "binary") ] or ["macro"] self.prefix = prefix + self.log_predictions = log_predictions self.main_metric = f"{self.average[0]}_f1" def __call__( @@ -119,6 +122,7 @@ def __call__( "val_scores": val_results, "test_scores": test_results, }, + predictions=self._format_predictions(self.test_dataset, y_pred) if self.log_predictions else None, ) @staticmethod @@ -128,3 +132,14 @@ def _compute_metrics(y_pred: np.ndarray, y_true: list[int], average: list[float] for average_method in average: classifier_results[f"{average_method}_f1"] = f1_score(y_true, y_pred, average=average_method) return classifier_results + + @staticmethod + def _format_predictions(dataset: ClassificationDataset, y_pred: np.ndarray) -> list[ClassificationPrediction]: + texts = [item.text for item in dataset] + y_true = [item.label for item in dataset] + y_pred = y_pred.tolist() + assert len(texts) == len(y_true) == len(y_pred) + return [ + ClassificationPrediction(text=text, label=label, prediction=pred) + for text, label, pred in zip(texts, y_true, y_pred) + ] diff --git a/src/jmteb/evaluators/retrieval/data.py b/src/jmteb/evaluators/retrieval/data.py index 70c69a4..4c8c30b 100644 --- a/src/jmteb/evaluators/retrieval/data.py +++ b/src/jmteb/evaluators/retrieval/data.py @@ -21,6 +21,13 @@ class RetrievalDoc: text: str +@dataclass +class RetrievalPrediction: + query: str + relevant_docs: list[RetrievalDoc] + predicted_relevant_docs: list[RetrievalDoc] + + class RetrievalQueryDataset(ABC): @abstractmethod def __len__(self): @@ -46,6 +53,23 @@ def __getitem__(self, idx) -> RetrievalDoc: def __eq__(self, __value: object) -> bool: return False + def _build_idx_docid_mapping(self, dataset_attr_name: str = "dataset") -> None: + self.idx_to_docid: dict = {} + self.docid_to_idx: dict = {} + id_key: str = getattr(self, "id_key", None) + dataset = getattr(self, dataset_attr_name) + if id_key: + for idx, doc_dict in enumerate(dataset): + self.idx_to_docid[idx] = doc_dict[id_key] + self.docid_to_idx[doc_dict[id_key]] = idx + elif isinstance(dataset[0], RetrievalDoc): + for idx, doc in enumerate(dataset): + doc: RetrievalDoc + self.idx_to_docid[idx] = doc.id + self.docid_to_idx[doc.id] = idx + else: + raise ValueError(f"Invalid dataset type: list[{type(dataset[0])}]") + class HfRetrievalQueryDataset(RetrievalQueryDataset): def __init__( @@ -124,6 +148,7 @@ def __init__(self, path: str, split: str, name: str | None = None, id_key: str = self.dataset = datasets.load_dataset(path, split=split, name=name, trust_remote_code=True) self.id_key = id_key self.text_key = text_key + self._build_idx_docid_mapping() def __len__(self): return len(self.dataset) @@ -150,6 +175,7 @@ def __init__(self, filename: str, id_key: str = "docid", text_key: str = "text") self.dataset = corpus self.id_key = id_key self.text_key = text_key + self._build_idx_docid_mapping() def __len__(self): return len(self.dataset) diff --git a/src/jmteb/evaluators/retrieval/evaluator.py b/src/jmteb/evaluators/retrieval/evaluator.py index c7edc59..90549a8 100644 --- a/src/jmteb/evaluators/retrieval/evaluator.py +++ b/src/jmteb/evaluators/retrieval/evaluator.py @@ -15,7 +15,13 @@ from jmteb.embedders.base import TextEmbedder from jmteb.evaluators.base import EmbeddingEvaluator, EvaluationResults -from .data import RetrievalDocDataset, RetrievalQueryDataset +from .data import ( + RetrievalDoc, + RetrievalDocDataset, + RetrievalPrediction, + RetrievalQuery, + RetrievalQueryDataset, +) T = TypeVar("T") @@ -33,6 +39,7 @@ class RetrievalEvaluator(EmbeddingEvaluator): accuracy_at_k (list[int] | None): accuracy in top k hits. query_prefix (str | None): prefix for queries. Defaults to None. doc_prefix (str | None): prefix for documents. Defaults to None. + log_predictions (bool): whether to log predictions of each datapoint. Defaults to False. """ def __init__( @@ -45,6 +52,7 @@ def __init__( ndcg_at_k: list[int] | None = None, query_prefix: str | None = None, doc_prefix: str | None = None, + log_predictions: bool = False, ) -> None: self.val_query_dataset = val_query_dataset self.test_query_dataset = test_query_dataset @@ -59,6 +67,7 @@ def __init__( self.query_prefix = query_prefix self.doc_prefix = doc_prefix + self.log_predictions = log_predictions def __call__( self, @@ -103,7 +112,7 @@ def __call__( val_results = {} for dist_name, dist_func in dist_functions.items(): - val_results[dist_name] = self._compute_metrics( + val_results[dist_name], _ = self._compute_metrics( query_dataset=self.val_query_dataset, query_embeddings=val_query_embeddings, doc_embeddings=doc_embeddings, @@ -112,14 +121,13 @@ def __call__( sorted_val_results = sorted(val_results.items(), key=lambda res: res[1][self.main_metric], reverse=True) optimal_dist_name = sorted_val_results[0][0] - test_results = { - optimal_dist_name: self._compute_metrics( - query_dataset=self.test_query_dataset, - query_embeddings=test_query_embeddings, - doc_embeddings=doc_embeddings, - dist_func=dist_functions[optimal_dist_name], - ) - } + test_scores, test_predictions = self._compute_metrics( + query_dataset=self.test_query_dataset, + query_embeddings=test_query_embeddings, + doc_embeddings=doc_embeddings, + dist_func=dist_functions[optimal_dist_name], + ) + test_results = {optimal_dist_name: test_scores} return EvaluationResults( metric_name=self.main_metric, @@ -129,6 +137,7 @@ def __call__( "val_scores": val_results, "test_scores": test_results, }, + predictions=test_predictions, ) def _compute_metrics( @@ -137,9 +146,9 @@ def _compute_metrics( query_embeddings: np.ndarray | Tensor, doc_embeddings: np.ndarray | Tensor, dist_func: Callable[[Tensor, Tensor], Tensor], - ) -> dict[str, dict[str, float]]: + ) -> tuple[dict[str, dict[str, float]], list[RetrievalPrediction]]: results: dict[str, float] = {} - + predictions: list[RetrievalPrediction] = [] if self.log_predictions else None with tqdm.tqdm(total=len(doc_embeddings), desc="Retrieval doc chunks") as pbar: top_k_indices_chunks: list[np.ndarray] = [] top_k_scores_chunks: list[np.ndarray] = [] @@ -173,13 +182,42 @@ def _compute_metrics( golden_doc_ids = [item.relevant_docs for item in query_dataset] retrieved_doc_ids = [[self.doc_dataset[i].id for i in indices] for indices in sorted_top_k_indices] + predictions = ( + self._format_predictions(query_dataset, self.doc_dataset, retrieved_doc_ids) + if self.log_predictions + else None + ) + for k in self.accuracy_at_k: results[f"accuracy@{k}"] = accuracy_at_k(golden_doc_ids, retrieved_doc_ids, k) for k in self.ndcg_at_k: results[f"ndcg@{k}"] = ndcg_at_k(golden_doc_ids, retrieved_doc_ids, k) results[f"mrr@{self.max_top_k}"] = mrr_at_k(golden_doc_ids, retrieved_doc_ids, self.max_top_k) - return results + return results, predictions + + @staticmethod + def _format_predictions( + query_dataset: RetrievalQueryDataset, + doc_dataset: RetrievalDocDataset, + retrieved_doc_ids: list[list], + ) -> list[RetrievalPrediction]: + predictions = [] + for q, pred_docids in zip(query_dataset, retrieved_doc_ids): + q: RetrievalQuery + golden_docs: list[RetrievalDoc] = [ + doc_dataset[doc_dataset.docid_to_idx[docid]] for docid in q.relevant_docs + ] + pred_docs: list[RetrievalDoc] = [ + doc_dataset[doc_dataset.docid_to_idx[pred_docid]] for pred_docid in pred_docids + ] + prediction = RetrievalPrediction( + query=q.query, + relevant_docs=golden_docs, + predicted_relevant_docs=pred_docs, + ) + predictions.append(prediction) + return predictions def accuracy_at_k(relevant_docs: list[list[T]], top_hits: list[list[T]], k: int) -> float: diff --git a/tests/evaluator/test_classification_evaluator.py b/tests/evaluator/test_classification_evaluator.py index bce9964..761198e 100644 --- a/tests/evaluator/test_classification_evaluator.py +++ b/tests/evaluator/test_classification_evaluator.py @@ -90,3 +90,22 @@ def test_classification_jsonl_dataset_equal(): assert dummy_jsonl_dataset_1 == dummy_jsonl_dataset_2 dummy_jsonl_dataset_2.label_key = "LABEL" assert dummy_jsonl_dataset_1 != dummy_jsonl_dataset_2 + + +def test_classification_prediction_logging(embedder): + dataset = DummyClassificationDataset() + evaluator = ClassificationEvaluator( + train_dataset=dataset, + val_dataset=dataset, + test_dataset=dataset, + classifiers={ + "logreg": LogRegClassifier(), + "knn": KnnClassifier(k=2, distance_metric="cosine"), + }, + log_predictions=True, + ) + results = evaluator(model=embedder) + assert isinstance(results.predictions, list) + assert [p.text for p in results.predictions] == [d.text for d in dataset] + assert [p.label for p in results.predictions] == [d.label for d in dataset] + assert all([isinstance(p.prediction, int) for p in results.predictions]) diff --git a/tests/evaluator/test_retrieval_evaluator.py b/tests/evaluator/test_retrieval_evaluator.py index fa52c52..d76d65d 100644 --- a/tests/evaluator/test_retrieval_evaluator.py +++ b/tests/evaluator/test_retrieval_evaluator.py @@ -8,6 +8,7 @@ from jmteb.evaluators.retrieval.data import ( JsonlRetrievalDocDataset, JsonlRetrievalQueryDataset, + RetrievalPrediction, ) EXPECTED_OUTPUT_DICT_KEYS = {"val_scores", "test_scores", "optimal_distance_metric"} @@ -19,6 +20,7 @@ class DummyDocDataset(RetrievalDocDataset): def __init__(self, prefix: str = ""): self._items = [RetrievalDoc(id=str(i), text=f"{prefix}dummy document {i}") for i in range(30)] + self._build_idx_docid_mapping("_items") def __len__(self): return len(self._items) @@ -60,6 +62,23 @@ def test_retrieval_evaluator(embedder): assert any(score.startswith(metric) for metric in ["accuracy", "mrr", "ndcg"]) +def test_retrieval_evaluator_with_predictions(embedder): + dummy_query_dataset = DummyQueryDataset() + dummy_doc_dataset = DummyDocDataset() + evaluator = RetrievalEvaluator( + val_query_dataset=dummy_query_dataset, + test_query_dataset=dummy_query_dataset, + doc_dataset=dummy_doc_dataset, + accuracy_at_k=[1, 3, 5, 10], + ndcg_at_k=[1, 3, 5], + doc_chunk_size=3, + log_predictions=True, + ) + results = evaluator(model=embedder) + assert [p.query for p in results.predictions] == [q.query for q in dummy_query_dataset] + assert all([isinstance(p, RetrievalPrediction) for p in results.predictions]) + + def test_retrieval_evaluator_with_prefix(embedder): evaluator_with_prefix = RetrievalEvaluator( val_query_dataset=DummyQueryDataset(),