Add prediction logging for classification and retrieval
lsz05 committed Jul 24, 2024
1 parent fe590c1 commit fdc7cdf
Showing 6 changed files with 138 additions and 14 deletions.
7 changes: 7 additions & 0 deletions src/jmteb/evaluators/classification/data.py
@@ -13,6 +13,13 @@ class ClassificationInstance:
label: int


@dataclass
class ClassificationPrediction:
text: str
label: int
prediction: int


class ClassificationDataset(ABC):
@abstractmethod
def __len__(self):
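The new ClassificationPrediction dataclass is a plain record pairing each test text with its gold label and the classifier's output. As a minimal sketch (not part of this commit; the example texts and labels are made up), such records can be dumped to JSONL for error analysis:

import json
from dataclasses import asdict

from jmteb.evaluators.classification.data import ClassificationPrediction

# Hypothetical records; in practice they come from EvaluationResults.predictions.
records = [
    ClassificationPrediction(text="良い映画だった", label=1, prediction=1),
    ClassificationPrediction(text="期待外れだった", label=0, prediction=1),
]
with open("classification_predictions.jsonl", "w", encoding="utf-8") as f:
    for record in records:
        f.write(json.dumps(asdict(record), ensure_ascii=False) + "\n")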
17 changes: 16 additions & 1 deletion src/jmteb/evaluators/classification/evaluator.py
@@ -11,7 +11,7 @@
from jmteb.evaluators.base import EmbeddingEvaluator, EvaluationResults

from .classifiers import Classifier, KnnClassifier, LogRegClassifier
from .data import ClassificationDataset
from .data import ClassificationDataset, ClassificationPrediction


class ClassificationEvaluator(EmbeddingEvaluator):
@@ -28,6 +28,7 @@ class ClassificationEvaluator(EmbeddingEvaluator):
The first one is specified as the main index.
classifiers (dict[str, Classifier]): classifiers to be evaluated.
prefix (str | None): prefix for sentences. Defaults to None.
log_predictions (bool): whether to log predictions of each datapoint.
"""

def __init__(
@@ -38,6 +39,7 @@ def __init__(
average: str = "macro",
classifiers: dict[str, Classifier] | None = None,
prefix: str | None = None,
log_predictions: bool = False,
) -> None:
self.train_dataset = train_dataset
self.val_dataset = val_dataset
@@ -52,6 +54,7 @@ def __init__(
if average_name.strip().lower() in ("micro", "macro", "samples", "weighted", "binary")
] or ["macro"]
self.prefix = prefix
self.log_predictions = log_predictions
self.main_metric = f"{self.average[0]}_f1"

def __call__(
@@ -119,6 +122,7 @@ def __call__(
"val_scores": val_results,
"test_scores": test_results,
},
predictions=self._format_predictions(self.test_dataset, y_pred) if self.log_predictions else None,
)

@staticmethod
@@ -128,3 +132,14 @@ def _compute_metrics(y_pred: np.ndarray, y_true: list[int], average: list[float]
for average_method in average:
classifier_results[f"{average_method}_f1"] = f1_score(y_true, y_pred, average=average_method)
return classifier_results

@staticmethod
def _format_predictions(dataset: ClassificationDataset, y_pred: np.ndarray) -> list[ClassificationPrediction]:
texts = [item.text for item in dataset]
y_true = [item.label for item in dataset]
y_pred = y_pred.tolist()
assert len(texts) == len(y_true) == len(y_pred)
return [
ClassificationPrediction(text=text, label=label, prediction=pred)
for text, label, pred in zip(texts, y_true, y_pred)
]
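For orientation, a usage sketch of the new flag. The dataset variables and embedder are placeholders (the test added below exercises the same path with dummy fixtures); the module path follows the file layout of this commit:

# Sketch only: `train_ds`, `val_ds`, `test_ds` and `embedder` are placeholders.
from jmteb.evaluators.classification.evaluator import ClassificationEvaluator

evaluator = ClassificationEvaluator(
    train_dataset=train_ds,
    val_dataset=val_ds,
    test_dataset=test_ds,
    log_predictions=True,  # opt in; results.predictions stays None otherwise
)
results = evaluator(model=embedder)
# results.predictions is a list[ClassificationPrediction] aligned with the test set.
for p in results.predictions:
    if p.label != p.prediction:
        print(p.text, p.label, p.prediction)  # inspect misclassified examples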
26 changes: 26 additions & 0 deletions src/jmteb/evaluators/retrieval/data.py
@@ -21,6 +21,13 @@ class RetrievalDoc:
text: str


@dataclass
class RetrievalPrediction:
query: str
relevant_docs: list[RetrievalDoc]
predicted_relevant_docs: list[RetrievalDoc]


class RetrievalQueryDataset(ABC):
@abstractmethod
def __len__(self):
@@ -46,6 +53,23 @@ def __getitem__(self, idx) -> RetrievalDoc:
def __eq__(self, __value: object) -> bool:
return False

def _build_idx_docid_mapping(self, dataset_attr_name: str = "dataset") -> None:
self.idx_to_docid: dict = {}
self.docid_to_idx: dict = {}
id_key: str = getattr(self, "id_key", None)
dataset = getattr(self, dataset_attr_name)
if id_key:
for idx, doc_dict in enumerate(dataset):
self.idx_to_docid[idx] = doc_dict[id_key]
self.docid_to_idx[doc_dict[id_key]] = idx
elif isinstance(dataset[0], RetrievalDoc):
for idx, doc in enumerate(dataset):
doc: RetrievalDoc
self.idx_to_docid[idx] = doc.id
self.docid_to_idx[doc.id] = idx
else:
raise ValueError(f"Invalid dataset type: list[{type(dataset[0])}]")


class HfRetrievalQueryDataset(RetrievalQueryDataset):
def __init__(
@@ -124,6 +148,7 @@ def __init__(self, path: str, split: str, name: str | None = None, id_key: str =
self.dataset = datasets.load_dataset(path, split=split, name=name, trust_remote_code=True)
self.id_key = id_key
self.text_key = text_key
self._build_idx_docid_mapping()

def __len__(self):
return len(self.dataset)
@@ -150,6 +175,7 @@ def __init__(self, filename: str, id_key: str = "docid", text_key: str = "text")
self.dataset = corpus
self.id_key = id_key
self.text_key = text_key
self._build_idx_docid_mapping()

def __len__(self):
return len(self.dataset)
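The new _build_idx_docid_mapping helper gives each doc dataset an index-to-docid lookup in both directions; the evaluator's _format_predictions later uses docid_to_idx to map retrieved document ids back to full RetrievalDoc objects. A toy illustration of the resulting mappings for a list[RetrievalDoc]-backed dataset (not part of the commit):

from jmteb.evaluators.retrieval.data import RetrievalDoc

docs = [RetrievalDoc(id=str(i), text=f"dummy document {i}") for i in range(3)]

# Mirror of what _build_idx_docid_mapping stores on the dataset instance.
idx_to_docid = {idx: doc.id for idx, doc in enumerate(docs)}
docid_to_idx = {doc.id: idx for idx, doc in enumerate(docs)}

# Recover a full document from a retrieved id, as _format_predictions does.
retrieved_id = "2"
print(docs[docid_to_idx[retrieved_id]])  # -> the RetrievalDoc whose id is "2"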
64 changes: 51 additions & 13 deletions src/jmteb/evaluators/retrieval/evaluator.py
@@ -15,7 +15,13 @@
from jmteb.embedders.base import TextEmbedder
from jmteb.evaluators.base import EmbeddingEvaluator, EvaluationResults

from .data import RetrievalDocDataset, RetrievalQueryDataset
from .data import (
RetrievalDoc,
RetrievalDocDataset,
RetrievalPrediction,
RetrievalQuery,
RetrievalQueryDataset,
)

T = TypeVar("T")

@@ -33,6 +39,7 @@ class RetrievalEvaluator(EmbeddingEvaluator):
accuracy_at_k (list[int] | None): accuracy in top k hits.
query_prefix (str | None): prefix for queries. Defaults to None.
doc_prefix (str | None): prefix for documents. Defaults to None.
log_predictions (bool): whether to log predictions of each datapoint. Defaults to False.
"""

def __init__(
@@ -45,6 +52,7 @@ def __init__(
ndcg_at_k: list[int] | None = None,
query_prefix: str | None = None,
doc_prefix: str | None = None,
log_predictions: bool = False,
) -> None:
self.val_query_dataset = val_query_dataset
self.test_query_dataset = test_query_dataset
@@ -59,6 +67,7 @@ def __init__(

self.query_prefix = query_prefix
self.doc_prefix = doc_prefix
self.log_predictions = log_predictions

def __call__(
self,
@@ -103,7 +112,7 @@ def __call__(

val_results = {}
for dist_name, dist_func in dist_functions.items():
val_results[dist_name] = self._compute_metrics(
val_results[dist_name], _ = self._compute_metrics(
query_dataset=self.val_query_dataset,
query_embeddings=val_query_embeddings,
doc_embeddings=doc_embeddings,
@@ -112,14 +121,13 @@
sorted_val_results = sorted(val_results.items(), key=lambda res: res[1][self.main_metric], reverse=True)
optimal_dist_name = sorted_val_results[0][0]

test_results = {
optimal_dist_name: self._compute_metrics(
query_dataset=self.test_query_dataset,
query_embeddings=test_query_embeddings,
doc_embeddings=doc_embeddings,
dist_func=dist_functions[optimal_dist_name],
)
}
test_scores, test_predictions = self._compute_metrics(
query_dataset=self.test_query_dataset,
query_embeddings=test_query_embeddings,
doc_embeddings=doc_embeddings,
dist_func=dist_functions[optimal_dist_name],
)
test_results = {optimal_dist_name: test_scores}

return EvaluationResults(
metric_name=self.main_metric,
@@ -129,6 +137,7 @@ def __call__(
"val_scores": val_results,
"test_scores": test_results,
},
predictions=test_predictions,
)

def _compute_metrics(
@@ -137,9 +146,9 @@ def _compute_metrics(
query_embeddings: np.ndarray | Tensor,
doc_embeddings: np.ndarray | Tensor,
dist_func: Callable[[Tensor, Tensor], Tensor],
) -> dict[str, dict[str, float]]:
) -> tuple[dict[str, dict[str, float]], list[RetrievalPrediction]]:
results: dict[str, float] = {}

predictions: list[RetrievalPrediction] = [] if self.log_predictions else None
with tqdm.tqdm(total=len(doc_embeddings), desc="Retrieval doc chunks") as pbar:
top_k_indices_chunks: list[np.ndarray] = []
top_k_scores_chunks: list[np.ndarray] = []
@@ -173,13 +182,42 @@ def _compute_metrics(
golden_doc_ids = [item.relevant_docs for item in query_dataset]
retrieved_doc_ids = [[self.doc_dataset[i].id for i in indices] for indices in sorted_top_k_indices]

predictions = (
self._format_predictions(query_dataset, self.doc_dataset, retrieved_doc_ids)
if self.log_predictions
else None
)

for k in self.accuracy_at_k:
results[f"accuracy@{k}"] = accuracy_at_k(golden_doc_ids, retrieved_doc_ids, k)
for k in self.ndcg_at_k:
results[f"ndcg@{k}"] = ndcg_at_k(golden_doc_ids, retrieved_doc_ids, k)
results[f"mrr@{self.max_top_k}"] = mrr_at_k(golden_doc_ids, retrieved_doc_ids, self.max_top_k)

return results
return results, predictions

@staticmethod
def _format_predictions(
query_dataset: RetrievalQueryDataset,
doc_dataset: RetrievalDocDataset,
retrieved_doc_ids: list[list],
) -> list[RetrievalPrediction]:
predictions = []
for q, pred_docids in zip(query_dataset, retrieved_doc_ids):
q: RetrievalQuery
golden_docs: list[RetrievalDoc] = [
doc_dataset[doc_dataset.docid_to_idx[docid]] for docid in q.relevant_docs
]
pred_docs: list[RetrievalDoc] = [
doc_dataset[doc_dataset.docid_to_idx[pred_docid]] for pred_docid in pred_docids
]
prediction = RetrievalPrediction(
query=q.query,
relevant_docs=golden_docs,
predicted_relevant_docs=pred_docs,
)
predictions.append(prediction)
return predictions


def accuracy_at_k(relevant_docs: list[list[T]], top_hits: list[list[T]], k: int) -> float:
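A corresponding sketch for retrieval (again with placeholder datasets and embedder; the module path follows the file layout): with log_predictions=True, each RetrievalPrediction pairs a query with its gold relevant_docs and the top-ranked predicted_relevant_docs, which makes it easy to spot ranking failures.

# Sketch only: `query_ds`, `doc_ds` and `embedder` are placeholders.
from jmteb.evaluators.retrieval.evaluator import RetrievalEvaluator

evaluator = RetrievalEvaluator(
    val_query_dataset=query_ds,
    test_query_dataset=query_ds,
    doc_dataset=doc_ds,
    log_predictions=True,
)
results = evaluator(model=embedder)
for pred in results.predictions:
    gold_ids = {doc.id for doc in pred.relevant_docs}
    top_ids = [doc.id for doc in pred.predicted_relevant_docs]
    if not gold_ids.intersection(top_ids):
        print(pred.query, top_ids)  # queries whose gold docs were not retrieved at all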
19 changes: 19 additions & 0 deletions tests/evaluator/test_classification_evaluator.py
@@ -90,3 +90,22 @@ def test_classification_jsonl_dataset_equal():
assert dummy_jsonl_dataset_1 == dummy_jsonl_dataset_2
dummy_jsonl_dataset_2.label_key = "LABEL"
assert dummy_jsonl_dataset_1 != dummy_jsonl_dataset_2


def test_classification_prediction_logging(embedder):
dataset = DummyClassificationDataset()
evaluator = ClassificationEvaluator(
train_dataset=dataset,
val_dataset=dataset,
test_dataset=dataset,
classifiers={
"logreg": LogRegClassifier(),
"knn": KnnClassifier(k=2, distance_metric="cosine"),
},
log_predictions=True,
)
results = evaluator(model=embedder)
assert isinstance(results.predictions, list)
assert [p.text for p in results.predictions] == [d.text for d in dataset]
assert [p.label for p in results.predictions] == [d.label for d in dataset]
assert all([isinstance(p.prediction, int) for p in results.predictions])
19 changes: 19 additions & 0 deletions tests/evaluator/test_retrieval_evaluator.py
@@ -8,6 +8,7 @@
from jmteb.evaluators.retrieval.data import (
JsonlRetrievalDocDataset,
JsonlRetrievalQueryDataset,
RetrievalPrediction,
)

EXPECTED_OUTPUT_DICT_KEYS = {"val_scores", "test_scores", "optimal_distance_metric"}
@@ -19,6 +20,7 @@
class DummyDocDataset(RetrievalDocDataset):
def __init__(self, prefix: str = ""):
self._items = [RetrievalDoc(id=str(i), text=f"{prefix}dummy document {i}") for i in range(30)]
self._build_idx_docid_mapping("_items")

def __len__(self):
return len(self._items)
@@ -60,6 +62,23 @@ def test_retrieval_evaluator(embedder):
assert any(score.startswith(metric) for metric in ["accuracy", "mrr", "ndcg"])


def test_retrieval_evaluator_with_predictions(embedder):
dummy_query_dataset = DummyQueryDataset()
dummy_doc_dataset = DummyDocDataset()
evaluator = RetrievalEvaluator(
val_query_dataset=dummy_query_dataset,
test_query_dataset=dummy_query_dataset,
doc_dataset=dummy_doc_dataset,
accuracy_at_k=[1, 3, 5, 10],
ndcg_at_k=[1, 3, 5],
doc_chunk_size=3,
log_predictions=True,
)
results = evaluator(model=embedder)
assert [p.query for p in results.predictions] == [q.query for q in dummy_query_dataset]
assert all([isinstance(p, RetrievalPrediction) for p in results.predictions])


def test_retrieval_evaluator_with_prefix(embedder):
evaluator_with_prefix = RetrievalEvaluator(
val_query_dataset=DummyQueryDataset(),
