From e68a27fe519a07054c0e68a747f30f322a1b12cb Mon Sep 17 00:00:00 2001
From: "shengzhe.li"
Date: Fri, 27 Sep 2024 17:16:19 +0900
Subject: [PATCH 1/3] Allow setting encoding kwargs in SentenceBert embedder

---
 src/jmteb/embedders/sbert_embedder.py                 | 3 ++-
 src/jmteb/evaluators/classification/evaluator.py      | 5 +++++
 src/jmteb/evaluators/clustering/evaluator.py          | 4 ++++
 src/jmteb/evaluators/pair_classification/evaluator.py | 4 ++++
 src/jmteb/evaluators/reranking/evaluator.py           | 7 +++++++
 src/jmteb/evaluators/retrieval/evaluator.py           | 7 +++++++
 src/jmteb/evaluators/sts/evaluator.py                 | 4 ++++
 7 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/src/jmteb/embedders/sbert_embedder.py b/src/jmteb/embedders/sbert_embedder.py
index 0188e7d..ba33a36 100644
--- a/src/jmteb/embedders/sbert_embedder.py
+++ b/src/jmteb/embedders/sbert_embedder.py
@@ -43,7 +43,7 @@ def __init__(
         else:
             self.set_output_numpy()
 
-    def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray:
+    def encode(self, text: str | list[str], prefix: str | None = None, **kwargs) -> np.ndarray:
         if self.add_eos:
             text = self._add_eos_func(text)
         return self.model.encode(
@@ -54,6 +54,7 @@ def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray
             batch_size=self.batch_size,
             device=self.device,
             normalize_embeddings=self.normalize_embeddings,
+            **kwargs,
         )
 
     def _add_eos_func(self, text: str | list[str]) -> str | list[str]:
diff --git a/src/jmteb/evaluators/classification/evaluator.py b/src/jmteb/evaluators/classification/evaluator.py
index 457d949..dd922d4 100644
--- a/src/jmteb/evaluators/classification/evaluator.py
+++ b/src/jmteb/evaluators/classification/evaluator.py
@@ -40,6 +40,7 @@ def __init__(
         classifiers: dict[str, Classifier] | None = None,
         prefix: str | None = None,
         log_predictions: bool = False,
+        encode_kwargs: dict = {},
     ) -> None:
         self.train_dataset = train_dataset
         self.val_dataset = val_dataset
@@ -55,6 +56,7 @@ def __init__(
         ] or ["macro"]
         self.prefix = prefix
         self.log_predictions = log_predictions
+        self.encode_kwargs = encode_kwargs
         self.main_metric = f"{self.average[0]}_f1"
 
     def __call__(
@@ -69,6 +71,7 @@ def __call__(
             prefix=self.prefix,
             cache_path=Path(cache_dir) / "train_embeddings.bin" if cache_dir is not None else None,
             overwrite_cache=overwrite_cache,
+            **self.encode_kwargs,
         )
         y_train = [item.label for item in self.train_dataset]
 
@@ -77,6 +80,7 @@ def __call__(
             prefix=self.prefix,
             cache_path=Path(cache_dir) / "val_embeddings.bin" if cache_dir is not None else None,
             overwrite_cache=overwrite_cache,
+            **self.encode_kwargs,
         )
         y_val = [item.label for item in self.val_dataset]
 
@@ -90,6 +94,7 @@ def __call__(
             prefix=self.prefix,
             cache_path=Path(cache_dir) / "test_embeddings.bin" if cache_dir is not None else None,
             overwrite_cache=overwrite_cache,
+            **self.encode_kwargs,
         )
         y_test = [item.label for item in self.test_dataset]
 
diff --git a/src/jmteb/evaluators/clustering/evaluator.py b/src/jmteb/evaluators/clustering/evaluator.py
index 4f3cd3c..a9da9c9 100644
--- a/src/jmteb/evaluators/clustering/evaluator.py
+++ b/src/jmteb/evaluators/clustering/evaluator.py
@@ -33,12 +33,14 @@ def __init__(
         prefix: str | None = None,
         random_seed: int | None = None,
         log_predictions: bool = False,
+        encode_kwargs: dict = {},
     ) -> None:
         self.val_dataset = val_dataset
         self.test_dataset = test_dataset
         self.prefix = prefix
         self.random_seed = random_seed
         self.log_predictions = log_predictions
+        self.encode_kwargs = encode_kwargs
         self.main_metric = "v_measure_score"
 
     def __call__(
@@ -53,6 +55,7 @@ def __call__(
             prefix=self.prefix,
             cache_path=Path(cache_dir) / "val_embeddings.bin" if cache_dir is not None else None,
             overwrite_cache=overwrite_cache,
+            **self.encode_kwargs,
         )
         val_labels = [item.label for item in self.val_dataset]
 
@@ -66,6 +69,7 @@ def __call__(
             prefix=self.prefix,
             cache_path=Path(cache_dir) / "test_embeddings.bin" if cache_dir is not None else None,
             overwrite_cache=overwrite_cache,
+            **self.encode_kwargs,
         )
         test_labels = [item.label for item in self.test_dataset]
 
diff --git a/src/jmteb/evaluators/pair_classification/evaluator.py b/src/jmteb/evaluators/pair_classification/evaluator.py
index 280bbfb..0de1faa 100644
--- a/src/jmteb/evaluators/pair_classification/evaluator.py
+++ b/src/jmteb/evaluators/pair_classification/evaluator.py
@@ -32,11 +32,13 @@ def __init__(
         test_dataset: PairClassificationDataset,
         sentence1_prefix: str | None = None,
         sentence2_prefix: str | None = None,
+        encode_kwargs: dict = {},
     ) -> None:
         self.test_dataset = test_dataset
         self.val_dataset = val_dataset
         self.sentence1_prefix = sentence1_prefix
         self.sentence2_prefix = sentence2_prefix
+        self.encode_kwargs = encode_kwargs
         self.metrics = [ThresholdAccuracyMetric(), ThresholdF1Metric()]
         self.main_metric = "binary_f1"
 
@@ -122,12 +124,14 @@ def _convert_to_embeddings(
             prefix=self.sentence1_prefix,
             cache_path=Path(cache_dir) / f"{split}_embeddings1.bin" if cache_dir is not None else None,
             overwrite_cache=overwrite_cache,
+            **self.encode_kwargs,
         )
         embeddings2 = model.batch_encode_with_cache(
             [item.sentence2 for item in dataset],
             prefix=self.sentence2_prefix,
             cache_path=Path(cache_dir) / f"{split}_embeddings2.bin" if cache_dir is not None else None,
             overwrite_cache=overwrite_cache,
+            **self.encode_kwargs,
         )
         golden_labels = [item.label for item in dataset]
         return embeddings1, embeddings2, golden_labels
diff --git a/src/jmteb/evaluators/reranking/evaluator.py b/src/jmteb/evaluators/reranking/evaluator.py
index 5c4ba34..7b58bda 100644
--- a/src/jmteb/evaluators/reranking/evaluator.py
+++ b/src/jmteb/evaluators/reranking/evaluator.py
@@ -51,6 +51,8 @@ def __init__(
         doc_prefix: str | None = None,
         log_predictions: bool = False,
         top_n_docs_to_log: int = 5,
+        query_encode_kwargs: dict = {},
+        doc_encode_kwargs: dict = {},
     ) -> None:
         self.test_query_dataset = test_query_dataset
         self.val_query_dataset = val_query_dataset
@@ -61,6 +63,8 @@ def __init__(
         self.doc_prefix = doc_prefix
         self.log_predictions = log_predictions
         self.top_n_docs_to_log = top_n_docs_to_log
+        self.query_encode_kwargs = query_encode_kwargs
+        self.doc_encode_kwargs = doc_encode_kwargs
 
     def __call__(
         self,
@@ -77,6 +81,7 @@ def __call__(
             prefix=self.query_prefix,
             cache_path=Path(cache_dir) / "val_query.bin" if cache_dir is not None else None,
             overwrite_cache=overwrite_cache,
+            **self.query_encode_kwargs,
         )
         if self.val_query_dataset == self.test_query_dataset:
             test_query_embeddings = val_query_embeddings
@@ -86,12 +91,14 @@ def __call__(
             prefix=self.query_prefix,
             cache_path=Path(cache_dir) / "test_query.bin" if cache_dir is not None else None,
             overwrite_cache=overwrite_cache,
+            **self.query_encode_kwargs,
         )
         doc_embeddings = model.batch_encode_with_cache(
             text_list=[item.text for item in self.doc_dataset],
             prefix=self.doc_prefix,
             cache_path=Path(cache_dir) / "corpus.bin" if cache_dir is not None else None,
             overwrite_cache=overwrite_cache,
+            **self.doc_encode_kwargs,
         )
 
         logger.info("Start reranking")
diff --git a/src/jmteb/evaluators/retrieval/evaluator.py b/src/jmteb/evaluators/retrieval/evaluator.py
index 73c0981..89a21cd 100644
--- a/src/jmteb/evaluators/retrieval/evaluator.py
+++ b/src/jmteb/evaluators/retrieval/evaluator.py
@@ -56,6 +56,8 @@ def __init__(
         doc_prefix: str | None = None,
         log_predictions: bool = False,
         top_n_docs_to_log: int = 5,
+        query_encode_kwargs: dict = {},
+        doc_encode_kwargs: dict = {},
     ) -> None:
         self.val_query_dataset = val_query_dataset
         self.test_query_dataset = test_query_dataset
@@ -72,6 +74,8 @@ def __init__(
         self.doc_prefix = doc_prefix
         self.log_predictions = log_predictions
         self.top_n_docs_to_log = top_n_docs_to_log
+        self.query_encode_kwargs = query_encode_kwargs
+        self.doc_encode_kwargs = doc_encode_kwargs
 
     def __call__(
         self,
@@ -88,6 +92,7 @@ def __call__(
             prefix=self.query_prefix,
             cache_path=Path(cache_dir) / "val_query.bin" if cache_dir is not None else None,
             overwrite_cache=overwrite_cache,
+            **self.query_encode_kwargs,
         )
         if self.val_query_dataset == self.test_query_dataset:
             test_query_embeddings = val_query_embeddings
@@ -97,6 +102,7 @@ def __call__(
             prefix=self.query_prefix,
             cache_path=Path(cache_dir) / "test_query.bin" if cache_dir is not None else None,
             overwrite_cache=overwrite_cache,
+            **self.query_encode_kwargs,
         )
 
         doc_embeddings = model.batch_encode_with_cache(
@@ -104,6 +110,7 @@ def __call__(
             prefix=self.doc_prefix,
             cache_path=Path(cache_dir) / "corpus.bin" if cache_dir is not None else None,
             overwrite_cache=overwrite_cache,
+            **self.doc_encode_kwargs,
         )
 
         logger.info("Start retrieval")
diff --git a/src/jmteb/evaluators/sts/evaluator.py b/src/jmteb/evaluators/sts/evaluator.py
index b7b8eb8..dcf6b02 100644
--- a/src/jmteb/evaluators/sts/evaluator.py
+++ b/src/jmteb/evaluators/sts/evaluator.py
@@ -35,6 +35,7 @@ def __init__(
         sentence1_prefix: str | None = None,
         sentence2_prefix: str | None = None,
         log_predictions: bool = False,
+        encode_kwargs: dict = {},
     ) -> None:
         self.val_dataset = val_dataset
         self.test_dataset = test_dataset
@@ -42,6 +43,7 @@ def __init__(
         self.sentence2_prefix = sentence2_prefix
         self.main_metric = "spearman"
         self.log_predictions = log_predictions
+        self.encode_kwargs = encode_kwargs
 
     def __call__(
         self, model: TextEmbedder, cache_dir: str | PathLike[str] | None = None, overwrite_cache: bool = False
@@ -149,12 +151,14 @@ def _convert_to_embeddings(
             prefix=self.sentence1_prefix,
             cache_path=Path(cache_dir) / f"{split}_embeddings1.bin" if cache_dir is not None else None,
             overwrite_cache=overwrite_cache,
+            **self.encode_kwargs,
         )
         embeddings2 = model.batch_encode_with_cache(
             [item.sentence2 for item in dataset],
             prefix=self.sentence2_prefix,
             cache_path=Path(cache_dir) / f"{split}_embeddings2.bin" if cache_dir is not None else None,
             overwrite_cache=overwrite_cache,
+            **self.encode_kwargs,
         )
         device = "cuda" if torch.cuda.is_available() else "cpu"
         embeddings1 = convert_to_tensor(embeddings1, device)

From 82d37a32ca307b654f717953a32c4aaf9fab0880 Mon Sep 17 00:00:00 2001
From: "shengzhe.li"
Date: Fri, 27 Sep 2024 17:18:04 +0900
Subject: [PATCH 2/3] Pass encoding kwargs through the base embedder

---
 src/jmteb/embedders/base.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/src/jmteb/embedders/base.py b/src/jmteb/embedders/base.py
index afefec1..2c93a06 100644
--- a/src/jmteb/embedders/base.py
+++ b/src/jmteb/embedders/base.py
@@ -19,7 +19,7 @@ class TextEmbedder(ABC):
     convert_to_numpy: bool
     _chunk_size: int = 262144  # 2^18
 
-    def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray | torch.Tensor:
+    def encode(self, text: str | list[str], prefix: str | None = None, **kwargs) -> np.ndarray | torch.Tensor:
         """Convert a text string or a list of texts to embedding.
 
         Args:
@@ -43,6 +43,7 @@ def _batch_encode_and_save_on_disk(
         prefix: str | None = None,
         batch_size: int = 262144,
         dtype: str = "float32",
+        **kwargs,
     ) -> np.memmap | torch.Tensor:
         """
         Encode a list of texts and save the embeddings on disk using memmap.
@@ -65,7 +66,7 @@ def _batch_encode_and_save_on_disk(
         with tqdm.tqdm(total=num_samples, desc="Encoding") as pbar:
             for i in range(0, num_samples, batch_size):
                 batch = text_list[i : i + batch_size]
-                batch_embeddings: np.ndarray | torch.Tensor = self.encode(batch, prefix=prefix)
+                batch_embeddings: np.ndarray | torch.Tensor = self.encode(batch, prefix=prefix, **kwargs)
                 embeddings[i : i + batch_size] = batch_embeddings
                 pbar.update(len(batch))
 
@@ -83,6 +84,7 @@ def batch_encode_with_cache(
         cache_path: str | PathLike[str] | None = None,
         overwrite_cache: bool = False,
         dtype: str = "float32",
+        **kwargs,
     ) -> np.ndarray | torch.Tensor:
         """
         Encode a list of texts and save the embeddings on disk using memmap if cache_path is provided.
@@ -95,9 +97,10 @@ def batch_encode_with_cache(
             dtype (str, optional): data type. Defaults to "float32".
         """
+        logger.debug(f"{kwargs=}")
         if cache_path is None:
             logger.info("Encoding embeddings")
-            return self.encode(text_list, prefix=prefix)
+            return self.encode(text_list, prefix=prefix, **kwargs)
 
         if Path(cache_path).exists() and not overwrite_cache:
             logger.info(f"Loading embeddings from {cache_path}")
@@ -105,7 +108,7 @@ def batch_encode_with_cache(
         logger.info(f"Encoding and saving embeddings to {cache_path}")
         embeddings = self._batch_encode_and_save_on_disk(
-            text_list, cache_path, prefix=prefix, batch_size=self._chunk_size, dtype=dtype
+            text_list, cache_path, prefix=prefix, batch_size=self._chunk_size, dtype=dtype, **kwargs
         )
         return embeddings

From 94010c76e1a4567b06ae9a3acb46cdf551d9603c Mon Sep 17 00:00:00 2001
From: "shengzhe.li"
Date: Wed, 27 Nov 2024 15:59:30 +0900
Subject: [PATCH 3/3] Add docstring for encode_kwargs

---
 src/jmteb/embedders/base.py                           | 1 +
 src/jmteb/evaluators/classification/evaluator.py      | 1 +
 src/jmteb/evaluators/clustering/evaluator.py          | 8 ++++++++
 src/jmteb/evaluators/pair_classification/evaluator.py | 1 +
 src/jmteb/evaluators/reranking/evaluator.py           | 2 ++
 src/jmteb/evaluators/retrieval/evaluator.py           | 2 ++
 src/jmteb/evaluators/sts/evaluator.py                 | 1 +
 7 files changed, 16 insertions(+)

diff --git a/src/jmteb/embedders/base.py b/src/jmteb/embedders/base.py
index 2c93a06..ea078f1 100644
--- a/src/jmteb/embedders/base.py
+++ b/src/jmteb/embedders/base.py
@@ -25,6 +25,7 @@ def encode(self, text: str | list[str], prefix: str | None = None, **kwargs) ->
         Args:
             text (str | list[str]): text string, or a list of texts.
             prefix (str, optional): the prefix to use for encoding. Default to None.
+            **kwargs: additional model-specific settings passed through to the encoder.
         """
         raise NotImplementedError
 
diff --git a/src/jmteb/evaluators/classification/evaluator.py b/src/jmteb/evaluators/classification/evaluator.py
index dd922d4..c2b8836 100644
--- a/src/jmteb/evaluators/classification/evaluator.py
+++ b/src/jmteb/evaluators/classification/evaluator.py
@@ -29,6 +29,7 @@ class ClassificationEvaluator(EmbeddingEvaluator):
         classifiers (dict[str, Classifier]): classifiers to be evaluated.
         prefix (str | None): prefix for sentences. Defaults to None.
         log_predictions (bool): whether to log predictions of each datapoint.
+        encode_kwargs (dict): kwargs passed to embedder's encode function. Defaults to {}.
""" def __init__( diff --git a/src/jmteb/evaluators/clustering/evaluator.py b/src/jmteb/evaluators/clustering/evaluator.py index a9da9c9..2b8cdf2 100644 --- a/src/jmteb/evaluators/clustering/evaluator.py +++ b/src/jmteb/evaluators/clustering/evaluator.py @@ -24,6 +24,14 @@ class ClusteringEvaluator(EmbeddingEvaluator): """ ClusteringEvaluator is a class for evaluating clustering models. + + Args: + val_dataset (ClusteringDataset): validation dataset + test_dataset (ClusteringDataset): evaluation dataset + prefix (str | None): prefix for sentences. Defaults to None. + random_seed (int | None): random seed used in clustering models. Defaults to None. + log_predictions (bool): whether to log predictions of each datapoint. + encode_kwargs (dict): kwargs passed to embedder's encode function. Defaults to {}. """ def __init__( diff --git a/src/jmteb/evaluators/pair_classification/evaluator.py b/src/jmteb/evaluators/pair_classification/evaluator.py index 0de1faa..ef466bf 100644 --- a/src/jmteb/evaluators/pair_classification/evaluator.py +++ b/src/jmteb/evaluators/pair_classification/evaluator.py @@ -22,6 +22,7 @@ class PairClassificationEvaluator(EmbeddingEvaluator): test_dataset (PairClassificationDataset): test dataset sentence1_prefix (str | None): prefix for sentence1. Defaults to None. sentence2_prefix (str | None): prefix for sentence2. Defaults to None. + encode_kwargs (dict): kwargs passed to embedder's encode function. Default to {}. # NOTE: Don't log predictions, as predictions by different metrics could be different. """ diff --git a/src/jmteb/evaluators/reranking/evaluator.py b/src/jmteb/evaluators/reranking/evaluator.py index 7b58bda..144ed36 100644 --- a/src/jmteb/evaluators/reranking/evaluator.py +++ b/src/jmteb/evaluators/reranking/evaluator.py @@ -39,6 +39,8 @@ class RerankingEvaluator(EmbeddingEvaluator): doc_prefix (str | None): prefix for documents. Defaults to None. log_predictions (bool): whether to log predictions of each datapoint. Defaults to False. top_n_docs_to_log (int): log only top n documents. Defaults to 5. + query_encode_kwargs (dict): kwargs passed to embedder's encode function when encoding queries. Defaults to {}. + doc_encode_kwargs (dict): kwargs passed to embedder's encode function when encoding documents. Defaults to {}. """ def __init__( diff --git a/src/jmteb/evaluators/retrieval/evaluator.py b/src/jmteb/evaluators/retrieval/evaluator.py index 89a21cd..2fd6a21 100644 --- a/src/jmteb/evaluators/retrieval/evaluator.py +++ b/src/jmteb/evaluators/retrieval/evaluator.py @@ -42,6 +42,8 @@ class RetrievalEvaluator(EmbeddingEvaluator): doc_prefix (str | None): prefix for documents. Defaults to None. log_predictions (bool): whether to log predictions of each datapoint. Defaults to False. top_n_docs_to_log (int): log only top n documents that are predicted as relevant. Defaults to 5. + query_encode_kwargs (dict): kwargs passed to embedder's encode function when encoding queries. Defaults to {}. + doc_encode_kwargs (dict): kwargs passed to embedder's encode function when encoding documents. Defaults to {}. """ def __init__( diff --git a/src/jmteb/evaluators/sts/evaluator.py b/src/jmteb/evaluators/sts/evaluator.py index dcf6b02..380ceea 100644 --- a/src/jmteb/evaluators/sts/evaluator.py +++ b/src/jmteb/evaluators/sts/evaluator.py @@ -26,6 +26,7 @@ class STSEvaluator(EmbeddingEvaluator): test_dataset (STSDataset): test dataset sentence1_prefix (str | None): prefix for sentence1. Defaults to None. sentence2_prefix (str | None): prefix for sentence2. Defaults to None. 
+        encode_kwargs (dict): kwargs passed to embedder's encode function. Defaults to {}.
     """
 
     def __init__(
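
Usage note (illustrative, not part of the patches): with the series applied, keyword
arguments accepted by SentenceTransformer.encode() can be threaded from a caller
down to the model, both on the direct path (patch 1) and on the cached batch path
(patch 2). The sketch below is a minimal example under some assumptions: the model
name is a placeholder, the embedder is assumed to take its usual model_name_or_path
constructor argument, and show_progress_bar stands in for any encode() option the
embedder does not already set itself. Options the embedder fixes internally
(batch_size, device, normalize_embeddings, the prompt derived from prefix) must not
be repeated in the kwargs, or encode() would receive a duplicate keyword argument.

    # A minimal sketch, assuming SentenceBertEmbedder(model_name_or_path=...) and
    # that show_progress_bar is not already set inside the embedder's encode().
    from jmteb.embedders.sbert_embedder import SentenceBertEmbedder

    embedder = SentenceBertEmbedder(model_name_or_path="cl-nagoya/sup-simcse-ja-base")

    # Direct path (patch 1): kwargs are forwarded verbatim to SentenceTransformer.encode().
    embeddings = embedder.encode(["これはテストです。"], show_progress_bar=False)

    # Cached batch path (patch 2): the same kwargs ride through
    # batch_encode_with_cache() into each chunked encode() call.
    embeddings = embedder.batch_encode_with_cache(
        ["これはテストです。"],
        cache_path=None,  # no cache file: falls back to a single encode() call
        show_progress_bar=False,
    )

At the evaluator level, the dicts added in patch 1 are supplied at construction
time, e.g. RetrievalEvaluator(..., query_encode_kwargs={"show_progress_bar": False}),
and are splatted into every batch_encode_with_cache() call for that text type.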