diff --git a/docs/results/jinaai/jina-embeddings-v3/summary.json b/docs/results/jinaai/jina-embeddings-v3/summary.json new file mode 100644 index 0000000..8524862 --- /dev/null +++ b/docs/results/jinaai/jina-embeddings-v3/summary.json @@ -0,0 +1,62 @@ +{ + "Classification": { + "amazon_counterfactual_classification": { + "macro_f1": 0.7882733929438857 + }, + "amazon_review_classification": { + "macro_f1": 0.5933239824757218 + }, + "massive_intent_classification": { + "macro_f1": 0.7765343277120157 + }, + "massive_scenario_classification": { + "macro_f1": 0.8974174944345525 + } + }, + "Reranking": { + "esci": { + "ndcg@10": 0.9271464336251287 + } + }, + "Retrieval": { + "jagovfaqs_22k": { + "ndcg@10": 0.7149884473155108 + }, + "jaqket": { + "ndcg@10": 0.46484206025698144 + }, + "mrtydi": { + "ndcg@10": 0.4544765083850943 + }, + "nlp_journal_abs_intro": { + "ndcg@10": 0.9843205562446103 + }, + "nlp_journal_title_abs": { + "ndcg@10": 0.9561509620323349 + }, + "nlp_journal_title_intro": { + "ndcg@10": 0.9385000684351988 + } + }, + "STS": { + "jsick": { + "spearman": 0.781637470000662 + }, + "jsts": { + "spearman": 0.8193234425217734 + } + }, + "Clustering": { + "livedoor_news": { + "v_measure_score": 0.5472248713636514 + }, + "mewsc16": { + "v_measure_score": 0.4818974386694296 + } + }, + "PairClassification": { + "paws_x_ja": { + "binary_f1": 0.623716814159292 + } + } +} diff --git a/docs/results/sbintuitions/sarashina-embedding-v1-1b/summary.json b/docs/results/sbintuitions/sarashina-embedding-v1-1b/summary.json new file mode 100644 index 0000000..30385ec --- /dev/null +++ b/docs/results/sbintuitions/sarashina-embedding-v1-1b/summary.json @@ -0,0 +1,62 @@ +{ + "Classification": { + "amazon_counterfactual_classification": { + "macro_f1": 0.7910202863961814 + }, + "amazon_review_classification": { + "macro_f1": 0.614759364446128 + }, + "massive_intent_classification": { + "macro_f1": 0.8225880728874561 + }, + "massive_scenario_classification": { + "macro_f1": 0.9065030576701741 + } + }, + "Reranking": { + "esci": { + "ndcg@10": 0.9374394712541568 + } + }, + "Retrieval": { + "jagovfaqs_22k": { + "ndcg@10": 0.7168374490004555 + }, + "jaqket": { + "ndcg@10": 0.7279485535689915 + }, + "mrtydi": { + "ndcg@10": 0.41952210141116814 + }, + "nlp_journal_abs_intro": { + "ndcg@10": 0.9394095717236127 + }, + "nlp_journal_title_abs": { + "ndcg@10": 0.9695624263086593 + }, + "nlp_journal_title_intro": { + "ndcg@10": 0.8832876426024624 + } + }, + "STS": { + "jsick": { + "spearman": 0.8022484725822061 + }, + "jsts": { + "spearman": 0.851980317221987 + } + }, + "Clustering": { + "livedoor_news": { + "v_measure_score": 0.5641831341687762 + }, + "mewsc16": { + "v_measure_score": 0.5129216698739159 + } + }, + "PairClassification": { + "paws_x_ja": { + "binary_f1": 0.62 + } + } +} \ No newline at end of file diff --git a/leaderboard.md b/leaderboard.md index 4b05e46..dd64309 100644 --- a/leaderboard.md +++ b/leaderboard.md @@ -7,7 +7,9 @@ The summary shows the average scores within each task. The average score is the | Model | Avg. | Retrieval | STS | Classification | Reranking | Clustering | PairClassification | |:----------------------------------------------|:----------|:------------|:----------|:-----------------|:------------|:-------------|:---------------------| -| OpenAI/text-embedding-3-large | **74.05** | **74.48** | 82.52 | **77.58** | **93.58** | 53.32 | 62.35 | +| sbintuitions/sarashina-embedding-v1-1b | **75.50** | **77.61** | 82.71 | **78.37** | **93.74** | 53.86 | 62.00 | +| OpenAI/text-embedding-3-large | 74.05 | 74.48 | 82.52 | 77.58 | 93.58 | 53.32 | 62.35 | +| jinaai/jina-embeddings-v3 | 73.44 | 75.22 | 80.05 | 76.39 | 92.71 | 51.46 | 62.37 | | cl-nagoya/ruri-large | 73.31 | 73.02 | 83.13 | 77.43 | 92.99 | 51.82 | 62.29 | | pkshatech/GLuCoSE-base-ja-v2 | 72.23 | 73.36 | 82.96 | 74.21 | 93.01 | 48.65 | 62.37 | | pkshatech/RoSEtta-base-ja | 72.04 | 73.21 | 81.39 | 72.41 | 92.69 | 53.23 | 61.74 | @@ -38,11 +40,13 @@ The summary shows the average scores within each task. The average score is the ## Retrieval | Model | Avg. | jagovfaqs_22k
(ndcg@10) | jaqket
(ndcg@10) | mrtydi
(ndcg@10) | nlp_journal_abs_intro
(ndcg@10) | nlp_journal_title_abs
(ndcg@10) | nlp_journal_title_intro
(ndcg@10) | |:----------------------------------------------|:----------|:-----------------------------|:----------------------|:----------------------|:-------------------------------------|:-------------------------------------|:---------------------------------------| -| OpenAI/text-embedding-3-large | **74.48** | 72.41 | 48.21 | 34.88 | **99.33** | 96.55 | **95.47** | -| pkshatech/GLuCoSE-base-ja-v2 | 73.36 | 69.79 | **67.29** | 41.86 | 90.29 | 95.11 | 75.80 | +| sbintuitions/sarashina-embedding-v1-1b | **77.61** | 71.68 | **72.79** | 41.95 | 93.94 | 96.96 | 88.33 | +| jinaai/jina-embeddings-v3 | 75.22 | 71.50 | 46.48 | **45.45** | 98.43 | 95.62 | 93.85 | +| OpenAI/text-embedding-3-large | 74.48 | 72.41 | 48.21 | 34.88 | **99.33** | 96.55 | **95.47** | +| pkshatech/GLuCoSE-base-ja-v2 | 73.36 | 69.79 | 67.29 | 41.86 | 90.29 | 95.11 | 75.80 | | pkshatech/RoSEtta-base-ja | 73.21 | 65.96 | 65.33 | 36.73 | 95.54 | 94.08 | 81.63 | | cl-nagoya/ruri-large | 73.02 | **76.68** | 61.74 | 38.03 | 87.12 | 96.58 | 77.97 | -| intfloat/multilingual-e5-large | 70.98 | 70.30 | 58.78 | **43.63** | 86.00 | 94.70 | 72.48 | +| intfloat/multilingual-e5-large | 70.98 | 70.30 | 58.78 | 43.63 | 86.00 | 94.70 | 72.48 | | cl-nagoya/ruri-base | 69.82 | 74.56 | 50.12 | 35.45 | 86.89 | 96.57 | 75.31 | | cl-nagoya/ruri-small | 69.41 | 73.65 | 48.44 | 33.43 | 87.69 | **97.17** | 76.09 | | intfloat/multilingual-e5-base | 68.21 | 65.34 | 50.67 | 38.38 | 87.10 | 94.73 | 73.05 | @@ -70,15 +74,17 @@ The summary shows the average scores within each task. The average score is the | Model | Avg. | jsick
(spearman) | jsts
(spearman) | |:----------------------------------------------|:----------|:----------------------|:---------------------| | cl-nagoya/sup-simcse-ja-large | **83.18** | 83.80 | 82.57 | -| cl-nagoya/ruri-large | 83.13 | 82.00 | **84.26** | +| cl-nagoya/ruri-large | 83.13 | 82.00 | 84.26 | | pkshatech/GLuCoSE-base-ja-v2 | 82.96 | **84.96** | 80.96 | | cl-nagoya/ruri-base | 82.87 | 82.32 | 83.43 | | cl-nagoya/ruri-small | 82.79 | 83.44 | 82.13 | +| sbintuitions/sarashina-embedding-v1-1b | 82.71 | 80.22 | **85.20** | | OpenAI/text-embedding-3-large | 82.52 | 81.27 | 83.77 | | cl-nagoya/sup-simcse-ja-base | 82.05 | 82.83 | 81.27 | | pkshatech/RoSEtta-base-ja | 81.39 | 83.83 | 78.95 | | cl-nagoya/unsup-simcse-ja-large | 80.56 | 80.15 | 80.98 | | intfloat/multilingual-e5-small | 80.07 | 81.50 | 78.65 | +| jinaai/jina-embeddings-v3 | 80.05 | 78.16 | 81.93 | | intfloat/multilingual-e5-base | 79.84 | 81.28 | 78.39 | | intfloat/multilingual-e5-large | 79.70 | 78.40 | 80.99 | | OpenAI/text-embedding-3-small | 79.46 | 80.83 | 78.08 | @@ -100,10 +106,12 @@ The summary shows the average scores within each task. The average score is the ## Classification | Model | Avg. | amazon_counterfactual
(macro_f1) | amazon_review
(macro_f1) | massive_intent
(macro_f1) | massive_scenario
(macro_f1) | |:----------------------------------------------|:----------|:--------------------------------------|:------------------------------|:-------------------------------|:---------------------------------| -| OpenAI/text-embedding-3-large | **77.58** | 77.90 | **60.44** | 80.91 | **91.08** | +| sbintuitions/sarashina-embedding-v1-1b | **78.37** | 79.10 | **61.48** | 82.26 | 90.65 | +| OpenAI/text-embedding-3-large | 77.58 | 77.90 | 60.44 | 80.91 | **91.08** | | cl-nagoya/ruri-large | 77.43 | 80.81 | 56.80 | **82.56** | 89.56 | | pkshatech/GLuCoSE-base-ja | 76.82 | **82.44** | 58.07 | 78.85 | 87.94 | | oshizo/sbert-jsnli-luke-japanese-base-lite | 76.61 | 79.95 | 57.48 | 80.26 | 88.75 | +| jinaai/jina-embeddings-v3 | 76.39 | 78.83 | 59.33 | 77.65 | 89.74 | | cl-nagoya/ruri-small | 76.22 | 79.92 | 55.61 | 81.49 | 87.88 | | cl-nagoya/ruri-base | 75.58 | 76.66 | 55.76 | 81.41 | 88.49 | | cl-nagoya/unsup-simcse-ja-large | 74.66 | 76.79 | 55.37 | 79.13 | 87.36 | @@ -131,7 +139,8 @@ The summary shows the average scores within each task. The average score is the ## Reranking | Model | Avg. | esci
(ndcg@10) | |:----------------------------------------------|:----------|:--------------------| -| OpenAI/text-embedding-3-large | **93.58** | **93.58** | +| sbintuitions/sarashina-embedding-v1-1b | **93.74** | **93.74** | +| OpenAI/text-embedding-3-large | 93.58 | 93.58 | | OpenAI/text-embedding-ada-002 | 93.04 | 93.04 | | intfloat/multilingual-e5-small | 93.03 | 93.03 | | pkshatech/GLuCoSE-base-ja-v2 | 93.01 | 93.01 | @@ -141,6 +150,7 @@ The summary shows the average scores within each task. The average score is the | OpenAI/text-embedding-3-small | 92.92 | 92.92 | | cl-nagoya/ruri-base | 92.91 | 92.91 | | intfloat/multilingual-e5-base | 92.85 | 92.85 | +| jinaai/jina-embeddings-v3 | 92.71 | 92.71 | | pkshatech/RoSEtta-base-ja | 92.69 | 92.69 | | pkshatech/GLuCoSE-base-ja | 91.90 | 91.90 | | cl-nagoya/sup-simcse-ja-base | 91.83 | 91.83 | @@ -163,10 +173,12 @@ The summary shows the average scores within each task. The average score is the | Model | Avg. | livedoor_news
(v_measure_score) | mewsc16
(v_measure_score) | |:----------------------------------------------|:----------|:-------------------------------------|:-------------------------------| | cl-nagoya/ruri-base | **54.16** | 54.27 | **54.04** | +| sbintuitions/sarashina-embedding-v1-1b | 53.86 | 56.42 | 51.29 | | OpenAI/text-embedding-3-large | 53.32 | 57.09 | 49.55 | | pkshatech/RoSEtta-base-ja | 53.23 | **58.62** | 47.85 | | cl-nagoya/ruri-large | 51.82 | 51.39 | 52.25 | | cl-nagoya/sup-simcse-ja-base | 51.79 | 52.67 | 50.91 | +| jinaai/jina-embeddings-v3 | 51.46 | 54.72 | 48.19 | | intfloat/multilingual-e5-large | 51.24 | 57.13 | 45.34 | | cl-nagoya/ruri-small | 51.19 | 50.96 | 51.41 | | OpenAI/text-embedding-3-small | 51.06 | 54.57 | 47.55 | @@ -204,6 +216,7 @@ The summary shows the average scores within each task. The average score is the | cl-nagoya/ruri-base | 62.38 | 62.38 | | oshizo/sbert-jsnli-luke-japanese-base-lite | 62.38 | 62.38 | | MU-Kindai/Japanese-DiffCSE-BERT-base | 62.38 | 62.38 | +| jinaai/jina-embeddings-v3 | 62.37 | 62.37 | | pkshatech/GLuCoSE-base-ja-v2 | 62.37 | 62.37 | | MU-Kindai/Japanese-SimCSE-BERT-base-sup | 62.37 | 62.37 | | MU-Kindai/Japanese-SimCSE-BERT-large-sup | 62.35 | 62.35 | @@ -219,5 +232,6 @@ The summary shows the average scores within each task. The average score is the | intfloat/multilingual-e5-small | 62.19 | 62.19 | | intfloat/multilingual-e5-large | 62.15 | 62.15 | | cl-nagoya/ruri-small | 62.11 | 62.11 | +| sbintuitions/sarashina-embedding-v1-1b | 62.00 | 62.00 | | pkshatech/RoSEtta-base-ja | 61.74 | 61.74 | diff --git a/pyproject.toml b/pyproject.toml index 743b6bd..b5d2296 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ description = "The evaluation scripts for JMTEB (Japanese Massive Text Embedding name = "JMTEB" packages = [{from = "src", include = "jmteb"}] readme = "README.md" -version = "1.3.2" +version = "1.3.3" [tool.poetry.dependencies] python = ">=3.10,<4.0" diff --git a/src/jmteb/__main__.py b/src/jmteb/__main__.py index 55e830a..ff10884 100644 --- a/src/jmteb/__main__.py +++ b/src/jmteb/__main__.py @@ -97,6 +97,18 @@ def main( if len(args.evaluators) == 0: raise ValueError("No evaluator is selected. Please check the config file or the command line arguments.") + # save config as yaml + if args.save_dir: + Path(args.save_dir).mkdir(parents=True, exist_ok=True) + parser.save( + args, + Path(args.save_dir) / "jmteb_config.yaml", + format="yaml", + overwrite=True, + multifile=False, + skip_check=True, + ) + args = parser.instantiate_classes(args) if isinstance(args.evaluators, str): raise ValueError( diff --git a/src/jmteb/embedders/base.py b/src/jmteb/embedders/base.py index afefec1..ea078f1 100644 --- a/src/jmteb/embedders/base.py +++ b/src/jmteb/embedders/base.py @@ -19,12 +19,13 @@ class TextEmbedder(ABC): convert_to_numpy: bool _chunk_size: int = 262144 # 2^18 - def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray | torch.Tensor: + def encode(self, text: str | list[str], prefix: str | None = None, **kwargs) -> np.ndarray | torch.Tensor: """Convert a text string or a list of texts to embedding. Args: text (str | list[str]): text string, or a list of texts. prefix (str, optional): the prefix to use for encoding. Default to None. + **kwargs: some more settings that may be necessary for specific models. """ raise NotImplementedError @@ -43,6 +44,7 @@ def _batch_encode_and_save_on_disk( prefix: str | None = None, batch_size: int = 262144, dtype: str = "float32", + **kwargs, ) -> np.memmap | torch.Tensor: """ Encode a list of texts and save the embeddings on disk using memmap. @@ -65,7 +67,7 @@ def _batch_encode_and_save_on_disk( with tqdm.tqdm(total=num_samples, desc="Encoding") as pbar: for i in range(0, num_samples, batch_size): batch = text_list[i : i + batch_size] - batch_embeddings: np.ndarray | torch.Tensor = self.encode(batch, prefix=prefix) + batch_embeddings: np.ndarray | torch.Tensor = self.encode(batch, prefix=prefix, **kwargs) embeddings[i : i + batch_size] = batch_embeddings pbar.update(len(batch)) @@ -83,6 +85,7 @@ def batch_encode_with_cache( cache_path: str | PathLike[str] | None = None, overwrite_cache: bool = False, dtype: str = "float32", + **kwargs, ) -> np.ndarray | torch.Tensor: """ Encode a list of texts and save the embeddings on disk using memmap if cache_path is provided. @@ -95,9 +98,10 @@ def batch_encode_with_cache( dtype (str, optional): data type. Defaults to "float32". """ + logger.warning(f"{kwargs=}") if cache_path is None: logger.info("Encoding embeddings") - return self.encode(text_list, prefix=prefix) + return self.encode(text_list, prefix=prefix, **kwargs) if Path(cache_path).exists() and not overwrite_cache: logger.info(f"Loading embeddings from {cache_path}") @@ -105,7 +109,7 @@ def batch_encode_with_cache( logger.info(f"Encoding and saving embeddings to {cache_path}") embeddings = self._batch_encode_and_save_on_disk( - text_list, cache_path, prefix=prefix, batch_size=self._chunk_size, dtype=dtype + text_list, cache_path, prefix=prefix, batch_size=self._chunk_size, dtype=dtype, **kwargs ) return embeddings diff --git a/src/jmteb/embedders/data_parallel_sbert_embedder.py b/src/jmteb/embedders/data_parallel_sbert_embedder.py index 6fb7e87..7416fe4 100644 --- a/src/jmteb/embedders/data_parallel_sbert_embedder.py +++ b/src/jmteb/embedders/data_parallel_sbert_embedder.py @@ -8,6 +8,7 @@ from accelerate.utils import find_executable_batch_size from loguru import logger from sentence_transformers import SentenceTransformer +from sentence_transformers.models import Pooling from sentence_transformers.quantization import quantize_embeddings from sentence_transformers.util import truncate_embeddings from torch import Tensor @@ -85,7 +86,10 @@ def encode( # Some models (e.g. INSTRUCTOR, GRIT) require removing the prompt before pooling # Tracking the prompt length allow us to remove the prompt during pooling tokenized_prompt = self.sbert.tokenize([prompt]) - if "input_ids" in tokenized_prompt: + + # When `include_prompt` is True in Pooling, prompt_length is unnecessary and should be removed. + # This prevents problems arising from DataParallel + if "input_ids" in tokenized_prompt and not self.include_prompt_for_pooling(): extra_features["prompt_length"] = tokenized_prompt["input_ids"].shape[-1] - 1 all_embeddings = [] @@ -97,6 +101,13 @@ def encode( features = self.sbert.tokenize(sentences_batch) features.update(extra_features) + # `.gather()` in `.forward()` does not support int type, so make it a type that can gather + # we cast it from an int type to a torch.Tensor type + if "prompt_length" in features and isinstance(features["prompt_length"], int): + batch_size = len(sentences_batch) + prompt_length = torch.Tensor([features["prompt_length"] for _ in range(batch_size)]) + features["prompt_length"] = prompt_length + with torch.no_grad(): out_features = self.forward(features) @@ -156,6 +167,14 @@ def encode( return all_embeddings + # Sentence Transformersの`include_prompt`判定メソッドを参考に実装 + # ref: https://github.com/UKPLab/sentence-transformers/blob/679ab5d38e4cf9cd73d4dcf1cda25ba2ef1ad837/sentence_transformers/trainer.py#L931 # noqa: E501 + def include_prompt_for_pooling(self) -> bool: + for module in self: + if isinstance(module, Pooling): + return module.include_prompt + return True + class DataParallelSentenceBertEmbedder(TextEmbedder): """SentenceBERT embedder with pytorch data parallel""" diff --git a/src/jmteb/embedders/sbert_embedder.py b/src/jmteb/embedders/sbert_embedder.py index 0188e7d..ba33a36 100644 --- a/src/jmteb/embedders/sbert_embedder.py +++ b/src/jmteb/embedders/sbert_embedder.py @@ -43,7 +43,7 @@ def __init__( else: self.set_output_numpy() - def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray: + def encode(self, text: str | list[str], prefix: str | None = None, **kwargs) -> np.ndarray: if self.add_eos: text = self._add_eos_func(text) return self.model.encode( @@ -54,6 +54,7 @@ def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray batch_size=self.batch_size, device=self.device, normalize_embeddings=self.normalize_embeddings, + **kwargs, ) def _add_eos_func(self, text: str | list[str]) -> str | list[str]: diff --git a/src/jmteb/evaluators/classification/evaluator.py b/src/jmteb/evaluators/classification/evaluator.py index 457d949..c2b8836 100644 --- a/src/jmteb/evaluators/classification/evaluator.py +++ b/src/jmteb/evaluators/classification/evaluator.py @@ -29,6 +29,7 @@ class ClassificationEvaluator(EmbeddingEvaluator): classifiers (dict[str, Classifier]): classifiers to be evaluated. prefix (str | None): prefix for sentences. Defaults to None. log_predictions (bool): whether to log predictions of each datapoint. + encode_kwargs (dict): kwargs passed to embedder's encode function. Defaults to {}. """ def __init__( @@ -40,6 +41,7 @@ def __init__( classifiers: dict[str, Classifier] | None = None, prefix: str | None = None, log_predictions: bool = False, + encode_kwargs: dict = {}, ) -> None: self.train_dataset = train_dataset self.val_dataset = val_dataset @@ -55,6 +57,7 @@ def __init__( ] or ["macro"] self.prefix = prefix self.log_predictions = log_predictions + self.encode_kwargs = encode_kwargs self.main_metric = f"{self.average[0]}_f1" def __call__( @@ -69,6 +72,7 @@ def __call__( prefix=self.prefix, cache_path=Path(cache_dir) / "train_embeddings.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, + **self.encode_kwargs, ) y_train = [item.label for item in self.train_dataset] @@ -77,6 +81,7 @@ def __call__( prefix=self.prefix, cache_path=Path(cache_dir) / "val_embeddings.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, + **self.encode_kwargs, ) y_val = [item.label for item in self.val_dataset] @@ -90,6 +95,7 @@ def __call__( prefix=self.prefix, cache_path=Path(cache_dir) / "test_embeddings.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, + **self.encode_kwargs, ) y_test = [item.label for item in self.test_dataset] diff --git a/src/jmteb/evaluators/clustering/evaluator.py b/src/jmteb/evaluators/clustering/evaluator.py index 4f3cd3c..2b8cdf2 100644 --- a/src/jmteb/evaluators/clustering/evaluator.py +++ b/src/jmteb/evaluators/clustering/evaluator.py @@ -24,6 +24,14 @@ class ClusteringEvaluator(EmbeddingEvaluator): """ ClusteringEvaluator is a class for evaluating clustering models. + + Args: + val_dataset (ClusteringDataset): validation dataset + test_dataset (ClusteringDataset): evaluation dataset + prefix (str | None): prefix for sentences. Defaults to None. + random_seed (int | None): random seed used in clustering models. Defaults to None. + log_predictions (bool): whether to log predictions of each datapoint. + encode_kwargs (dict): kwargs passed to embedder's encode function. Defaults to {}. """ def __init__( @@ -33,12 +41,14 @@ def __init__( prefix: str | None = None, random_seed: int | None = None, log_predictions: bool = False, + encode_kwargs: dict = {}, ) -> None: self.val_dataset = val_dataset self.test_dataset = test_dataset self.prefix = prefix self.random_seed = random_seed self.log_predictions = log_predictions + self.encode_kwargs = encode_kwargs self.main_metric = "v_measure_score" def __call__( @@ -53,6 +63,7 @@ def __call__( prefix=self.prefix, cache_path=Path(cache_dir) / "val_embeddings.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, + **self.encode_kwargs, ) val_labels = [item.label for item in self.val_dataset] @@ -66,6 +77,7 @@ def __call__( prefix=self.prefix, cache_path=Path(cache_dir) / "test_embeddings.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, + **self.encode_kwargs, ) test_labels = [item.label for item in self.test_dataset] diff --git a/src/jmteb/evaluators/pair_classification/evaluator.py b/src/jmteb/evaluators/pair_classification/evaluator.py index 280bbfb..ef466bf 100644 --- a/src/jmteb/evaluators/pair_classification/evaluator.py +++ b/src/jmteb/evaluators/pair_classification/evaluator.py @@ -22,6 +22,7 @@ class PairClassificationEvaluator(EmbeddingEvaluator): test_dataset (PairClassificationDataset): test dataset sentence1_prefix (str | None): prefix for sentence1. Defaults to None. sentence2_prefix (str | None): prefix for sentence2. Defaults to None. + encode_kwargs (dict): kwargs passed to embedder's encode function. Default to {}. # NOTE: Don't log predictions, as predictions by different metrics could be different. """ @@ -32,11 +33,13 @@ def __init__( test_dataset: PairClassificationDataset, sentence1_prefix: str | None = None, sentence2_prefix: str | None = None, + encode_kwargs: dict = {}, ) -> None: self.test_dataset = test_dataset self.val_dataset = val_dataset self.sentence1_prefix = sentence1_prefix self.sentence2_prefix = sentence2_prefix + self.encode_kwargs = encode_kwargs self.metrics = [ThresholdAccuracyMetric(), ThresholdF1Metric()] self.main_metric = "binary_f1" @@ -122,12 +125,14 @@ def _convert_to_embeddings( prefix=self.sentence1_prefix, cache_path=Path(cache_dir) / f"{split}_embeddings1.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, + **self.encode_kwargs, ) embeddings2 = model.batch_encode_with_cache( [item.sentence2 for item in dataset], prefix=self.sentence2_prefix, cache_path=Path(cache_dir) / f"{split}_embeddings2.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, + **self.encode_kwargs, ) golden_labels = [item.label for item in dataset] return embeddings1, embeddings2, golden_labels diff --git a/src/jmteb/evaluators/reranking/evaluator.py b/src/jmteb/evaluators/reranking/evaluator.py index 5c4ba34..144ed36 100644 --- a/src/jmteb/evaluators/reranking/evaluator.py +++ b/src/jmteb/evaluators/reranking/evaluator.py @@ -39,6 +39,8 @@ class RerankingEvaluator(EmbeddingEvaluator): doc_prefix (str | None): prefix for documents. Defaults to None. log_predictions (bool): whether to log predictions of each datapoint. Defaults to False. top_n_docs_to_log (int): log only top n documents. Defaults to 5. + query_encode_kwargs (dict): kwargs passed to embedder's encode function when encoding queries. Defaults to {}. + doc_encode_kwargs (dict): kwargs passed to embedder's encode function when encoding documents. Defaults to {}. """ def __init__( @@ -51,6 +53,8 @@ def __init__( doc_prefix: str | None = None, log_predictions: bool = False, top_n_docs_to_log: int = 5, + query_encode_kwargs: dict = {}, + doc_encode_kwargs: dict = {}, ) -> None: self.test_query_dataset = test_query_dataset self.val_query_dataset = val_query_dataset @@ -61,6 +65,8 @@ def __init__( self.doc_prefix = doc_prefix self.log_predictions = log_predictions self.top_n_docs_to_log = top_n_docs_to_log + self.query_encode_kwargs = query_encode_kwargs + self.doc_encode_kwargs = doc_encode_kwargs def __call__( self, @@ -77,6 +83,7 @@ def __call__( prefix=self.query_prefix, cache_path=Path(cache_dir) / "val_query.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, + **self.query_encode_kwargs, ) if self.val_query_dataset == self.test_query_dataset: test_query_embeddings = val_query_embeddings @@ -86,12 +93,14 @@ def __call__( prefix=self.query_prefix, cache_path=Path(cache_dir) / "test_query.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, + **self.query_encode_kwargs, ) doc_embeddings = model.batch_encode_with_cache( text_list=[item.text for item in self.doc_dataset], prefix=self.doc_prefix, cache_path=Path(cache_dir) / "corpus.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, + **self.doc_encode_kwargs, ) logger.info("Start reranking") diff --git a/src/jmteb/evaluators/retrieval/evaluator.py b/src/jmteb/evaluators/retrieval/evaluator.py index 73c0981..2fd6a21 100644 --- a/src/jmteb/evaluators/retrieval/evaluator.py +++ b/src/jmteb/evaluators/retrieval/evaluator.py @@ -42,6 +42,8 @@ class RetrievalEvaluator(EmbeddingEvaluator): doc_prefix (str | None): prefix for documents. Defaults to None. log_predictions (bool): whether to log predictions of each datapoint. Defaults to False. top_n_docs_to_log (int): log only top n documents that are predicted as relevant. Defaults to 5. + query_encode_kwargs (dict): kwargs passed to embedder's encode function when encoding queries. Defaults to {}. + doc_encode_kwargs (dict): kwargs passed to embedder's encode function when encoding documents. Defaults to {}. """ def __init__( @@ -56,6 +58,8 @@ def __init__( doc_prefix: str | None = None, log_predictions: bool = False, top_n_docs_to_log: int = 5, + query_encode_kwargs: dict = {}, + doc_encode_kwargs: dict = {}, ) -> None: self.val_query_dataset = val_query_dataset self.test_query_dataset = test_query_dataset @@ -72,6 +76,8 @@ def __init__( self.doc_prefix = doc_prefix self.log_predictions = log_predictions self.top_n_docs_to_log = top_n_docs_to_log + self.query_encode_kwargs = query_encode_kwargs + self.doc_encode_kwargs = doc_encode_kwargs def __call__( self, @@ -88,6 +94,7 @@ def __call__( prefix=self.query_prefix, cache_path=Path(cache_dir) / "val_query.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, + **self.query_encode_kwargs, ) if self.val_query_dataset == self.test_query_dataset: test_query_embeddings = val_query_embeddings @@ -97,6 +104,7 @@ def __call__( prefix=self.query_prefix, cache_path=Path(cache_dir) / "test_query.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, + **self.query_encode_kwargs, ) doc_embeddings = model.batch_encode_with_cache( @@ -104,6 +112,7 @@ def __call__( prefix=self.doc_prefix, cache_path=Path(cache_dir) / "corpus.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, + **self.doc_encode_kwargs, ) logger.info("Start retrieval") diff --git a/src/jmteb/evaluators/sts/evaluator.py b/src/jmteb/evaluators/sts/evaluator.py index b7b8eb8..380ceea 100644 --- a/src/jmteb/evaluators/sts/evaluator.py +++ b/src/jmteb/evaluators/sts/evaluator.py @@ -26,6 +26,7 @@ class STSEvaluator(EmbeddingEvaluator): test_dataset (STSDataset): test dataset sentence1_prefix (str | None): prefix for sentence1. Defaults to None. sentence2_prefix (str | None): prefix for sentence2. Defaults to None. + encode_kwargs (dict): kwargs passed to embedder's encode function. Defaults to {}. """ def __init__( @@ -35,6 +36,7 @@ def __init__( sentence1_prefix: str | None = None, sentence2_prefix: str | None = None, log_predictions: bool = False, + encode_kwargs: dict = {}, ) -> None: self.val_dataset = val_dataset self.test_dataset = test_dataset @@ -42,6 +44,7 @@ def __init__( self.sentence2_prefix = sentence2_prefix self.main_metric = "spearman" self.log_predictions = log_predictions + self.encode_kwargs = encode_kwargs def __call__( self, model: TextEmbedder, cache_dir: str | PathLike[str] | None = None, overwrite_cache: bool = False @@ -149,12 +152,14 @@ def _convert_to_embeddings( prefix=self.sentence1_prefix, cache_path=Path(cache_dir) / f"{split}_embeddings1.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, + **self.encode_kwargs, ) embeddings2 = model.batch_encode_with_cache( [item.sentence2 for item in dataset], prefix=self.sentence2_prefix, cache_path=Path(cache_dir) / f"{split}_embeddings2.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, + **self.encode_kwargs, ) device = "cuda" if torch.cuda.is_available() else "cpu" embeddings1 = convert_to_tensor(embeddings1, device) diff --git a/tests/embedders/test_dp_sbert.py b/tests/embedders/test_dp_sbert.py index 028e240..e1cd4d2 100644 --- a/tests/embedders/test_dp_sbert.py +++ b/tests/embedders/test_dp_sbert.py @@ -18,6 +18,11 @@ def test_encode(self): assert isinstance(embeddings, np.ndarray) assert embeddings.shape == (OUTPUT_DIM,) + def test_encode_with_prompt(self): + embeddings = self.model.encode("任意のテキスト", prefix="プロンプト") + assert isinstance(embeddings, np.ndarray) + assert embeddings.shape == (OUTPUT_DIM,) + def test_get_output_dim(self): assert self.model.get_output_dim() == OUTPUT_DIM diff --git a/tests/test_main.py b/tests/test_main.py index ee81fb5..40de1a3 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -31,5 +31,5 @@ def test_main_cli(): # fmt: on result = subprocess.run(command) assert result.returncode == 0 - + assert (Path(f) / "jmteb_config.yaml").exists() assert (Path(f) / "summary.json").exists()