diff --git a/docs/results/jinaai/jina-embeddings-v3/summary.json b/docs/results/jinaai/jina-embeddings-v3/summary.json
new file mode 100644
index 0000000..8524862
--- /dev/null
+++ b/docs/results/jinaai/jina-embeddings-v3/summary.json
@@ -0,0 +1,62 @@
+{
+ "Classification": {
+ "amazon_counterfactual_classification": {
+ "macro_f1": 0.7882733929438857
+ },
+ "amazon_review_classification": {
+ "macro_f1": 0.5933239824757218
+ },
+ "massive_intent_classification": {
+ "macro_f1": 0.7765343277120157
+ },
+ "massive_scenario_classification": {
+ "macro_f1": 0.8974174944345525
+ }
+ },
+ "Reranking": {
+ "esci": {
+ "ndcg@10": 0.9271464336251287
+ }
+ },
+ "Retrieval": {
+ "jagovfaqs_22k": {
+ "ndcg@10": 0.7149884473155108
+ },
+ "jaqket": {
+ "ndcg@10": 0.46484206025698144
+ },
+ "mrtydi": {
+ "ndcg@10": 0.4544765083850943
+ },
+ "nlp_journal_abs_intro": {
+ "ndcg@10": 0.9843205562446103
+ },
+ "nlp_journal_title_abs": {
+ "ndcg@10": 0.9561509620323349
+ },
+ "nlp_journal_title_intro": {
+ "ndcg@10": 0.9385000684351988
+ }
+ },
+ "STS": {
+ "jsick": {
+ "spearman": 0.781637470000662
+ },
+ "jsts": {
+ "spearman": 0.8193234425217734
+ }
+ },
+ "Clustering": {
+ "livedoor_news": {
+ "v_measure_score": 0.5472248713636514
+ },
+ "mewsc16": {
+ "v_measure_score": 0.4818974386694296
+ }
+ },
+ "PairClassification": {
+ "paws_x_ja": {
+ "binary_f1": 0.623716814159292
+ }
+ }
+}
diff --git a/docs/results/sbintuitions/sarashina-embedding-v1-1b/summary.json b/docs/results/sbintuitions/sarashina-embedding-v1-1b/summary.json
new file mode 100644
index 0000000..30385ec
--- /dev/null
+++ b/docs/results/sbintuitions/sarashina-embedding-v1-1b/summary.json
@@ -0,0 +1,62 @@
+{
+ "Classification": {
+ "amazon_counterfactual_classification": {
+ "macro_f1": 0.7910202863961814
+ },
+ "amazon_review_classification": {
+ "macro_f1": 0.614759364446128
+ },
+ "massive_intent_classification": {
+ "macro_f1": 0.8225880728874561
+ },
+ "massive_scenario_classification": {
+ "macro_f1": 0.9065030576701741
+ }
+ },
+ "Reranking": {
+ "esci": {
+ "ndcg@10": 0.9374394712541568
+ }
+ },
+ "Retrieval": {
+ "jagovfaqs_22k": {
+ "ndcg@10": 0.7168374490004555
+ },
+ "jaqket": {
+ "ndcg@10": 0.7279485535689915
+ },
+ "mrtydi": {
+ "ndcg@10": 0.41952210141116814
+ },
+ "nlp_journal_abs_intro": {
+ "ndcg@10": 0.9394095717236127
+ },
+ "nlp_journal_title_abs": {
+ "ndcg@10": 0.9695624263086593
+ },
+ "nlp_journal_title_intro": {
+ "ndcg@10": 0.8832876426024624
+ }
+ },
+ "STS": {
+ "jsick": {
+ "spearman": 0.8022484725822061
+ },
+ "jsts": {
+ "spearman": 0.851980317221987
+ }
+ },
+ "Clustering": {
+ "livedoor_news": {
+ "v_measure_score": 0.5641831341687762
+ },
+ "mewsc16": {
+ "v_measure_score": 0.5129216698739159
+ }
+ },
+ "PairClassification": {
+ "paws_x_ja": {
+ "binary_f1": 0.62
+ }
+ }
+}
\ No newline at end of file
diff --git a/leaderboard.md b/leaderboard.md
index 4b05e46..dd64309 100644
--- a/leaderboard.md
+++ b/leaderboard.md
@@ -7,7 +7,9 @@ The summary shows the average scores within each task. The average score is the
| Model | Avg. | Retrieval | STS | Classification | Reranking | Clustering | PairClassification |
|:----------------------------------------------|:----------|:------------|:----------|:-----------------|:------------|:-------------|:---------------------|
-| OpenAI/text-embedding-3-large | **74.05** | **74.48** | 82.52 | **77.58** | **93.58** | 53.32 | 62.35 |
+| sbintuitions/sarashina-embedding-v1-1b | **75.50** | **77.61** | 82.71 | **78.37** | **93.74** | 53.86 | 62.00 |
+| OpenAI/text-embedding-3-large | 74.05 | 74.48 | 82.52 | 77.58 | 93.58 | 53.32 | 62.35 |
+| jinaai/jina-embeddings-v3 | 73.44 | 75.22 | 80.05 | 76.39 | 92.71 | 51.46 | 62.37 |
| cl-nagoya/ruri-large | 73.31 | 73.02 | 83.13 | 77.43 | 92.99 | 51.82 | 62.29 |
| pkshatech/GLuCoSE-base-ja-v2 | 72.23 | 73.36 | 82.96 | 74.21 | 93.01 | 48.65 | 62.37 |
| pkshatech/RoSEtta-base-ja | 72.04 | 73.21 | 81.39 | 72.41 | 92.69 | 53.23 | 61.74 |
@@ -38,11 +40,13 @@ The summary shows the average scores within each task. The average score is the
## Retrieval
| Model | Avg. | jagovfaqs_22k<br>(ndcg@10) | jaqket<br>(ndcg@10) | mrtydi<br>(ndcg@10) | nlp_journal_abs_intro<br>(ndcg@10) | nlp_journal_title_abs<br>(ndcg@10) | nlp_journal_title_intro<br>(ndcg@10) |
|:----------------------------------------------|:----------|:-----------------------------|:----------------------|:----------------------|:-------------------------------------|:-------------------------------------|:---------------------------------------|
-| OpenAI/text-embedding-3-large | **74.48** | 72.41 | 48.21 | 34.88 | **99.33** | 96.55 | **95.47** |
-| pkshatech/GLuCoSE-base-ja-v2 | 73.36 | 69.79 | **67.29** | 41.86 | 90.29 | 95.11 | 75.80 |
+| sbintuitions/sarashina-embedding-v1-1b | **77.61** | 71.68 | **72.79** | 41.95 | 93.94 | 96.96 | 88.33 |
+| jinaai/jina-embeddings-v3 | 75.22 | 71.50 | 46.48 | **45.45** | 98.43 | 95.62 | 93.85 |
+| OpenAI/text-embedding-3-large | 74.48 | 72.41 | 48.21 | 34.88 | **99.33** | 96.55 | **95.47** |
+| pkshatech/GLuCoSE-base-ja-v2 | 73.36 | 69.79 | 67.29 | 41.86 | 90.29 | 95.11 | 75.80 |
| pkshatech/RoSEtta-base-ja | 73.21 | 65.96 | 65.33 | 36.73 | 95.54 | 94.08 | 81.63 |
| cl-nagoya/ruri-large | 73.02 | **76.68** | 61.74 | 38.03 | 87.12 | 96.58 | 77.97 |
-| intfloat/multilingual-e5-large | 70.98 | 70.30 | 58.78 | **43.63** | 86.00 | 94.70 | 72.48 |
+| intfloat/multilingual-e5-large | 70.98 | 70.30 | 58.78 | 43.63 | 86.00 | 94.70 | 72.48 |
| cl-nagoya/ruri-base | 69.82 | 74.56 | 50.12 | 35.45 | 86.89 | 96.57 | 75.31 |
| cl-nagoya/ruri-small | 69.41 | 73.65 | 48.44 | 33.43 | 87.69 | **97.17** | 76.09 |
| intfloat/multilingual-e5-base | 68.21 | 65.34 | 50.67 | 38.38 | 87.10 | 94.73 | 73.05 |
@@ -70,15 +74,17 @@ The summary shows the average scores within each task. The average score is the
| Model | Avg. | jsick<br>(spearman) | jsts<br>(spearman) |
|:----------------------------------------------|:----------|:----------------------|:---------------------|
| cl-nagoya/sup-simcse-ja-large | **83.18** | 83.80 | 82.57 |
-| cl-nagoya/ruri-large | 83.13 | 82.00 | **84.26** |
+| cl-nagoya/ruri-large | 83.13 | 82.00 | 84.26 |
| pkshatech/GLuCoSE-base-ja-v2 | 82.96 | **84.96** | 80.96 |
| cl-nagoya/ruri-base | 82.87 | 82.32 | 83.43 |
| cl-nagoya/ruri-small | 82.79 | 83.44 | 82.13 |
+| sbintuitions/sarashina-embedding-v1-1b | 82.71 | 80.22 | **85.20** |
| OpenAI/text-embedding-3-large | 82.52 | 81.27 | 83.77 |
| cl-nagoya/sup-simcse-ja-base | 82.05 | 82.83 | 81.27 |
| pkshatech/RoSEtta-base-ja | 81.39 | 83.83 | 78.95 |
| cl-nagoya/unsup-simcse-ja-large | 80.56 | 80.15 | 80.98 |
| intfloat/multilingual-e5-small | 80.07 | 81.50 | 78.65 |
+| jinaai/jina-embeddings-v3 | 80.05 | 78.16 | 81.93 |
| intfloat/multilingual-e5-base | 79.84 | 81.28 | 78.39 |
| intfloat/multilingual-e5-large | 79.70 | 78.40 | 80.99 |
| OpenAI/text-embedding-3-small | 79.46 | 80.83 | 78.08 |
@@ -100,10 +106,12 @@ The summary shows the average scores within each task. The average score is the
## Classification
| Model | Avg. | amazon_counterfactual<br>(macro_f1) | amazon_review<br>(macro_f1) | massive_intent<br>(macro_f1) | massive_scenario<br>(macro_f1) |
|:----------------------------------------------|:----------|:--------------------------------------|:------------------------------|:-------------------------------|:---------------------------------|
-| OpenAI/text-embedding-3-large | **77.58** | 77.90 | **60.44** | 80.91 | **91.08** |
+| sbintuitions/sarashina-embedding-v1-1b | **78.37** | 79.10 | **61.48** | 82.26 | 90.65 |
+| OpenAI/text-embedding-3-large | 77.58 | 77.90 | 60.44 | 80.91 | **91.08** |
| cl-nagoya/ruri-large | 77.43 | 80.81 | 56.80 | **82.56** | 89.56 |
| pkshatech/GLuCoSE-base-ja | 76.82 | **82.44** | 58.07 | 78.85 | 87.94 |
| oshizo/sbert-jsnli-luke-japanese-base-lite | 76.61 | 79.95 | 57.48 | 80.26 | 88.75 |
+| jinaai/jina-embeddings-v3 | 76.39 | 78.83 | 59.33 | 77.65 | 89.74 |
| cl-nagoya/ruri-small | 76.22 | 79.92 | 55.61 | 81.49 | 87.88 |
| cl-nagoya/ruri-base | 75.58 | 76.66 | 55.76 | 81.41 | 88.49 |
| cl-nagoya/unsup-simcse-ja-large | 74.66 | 76.79 | 55.37 | 79.13 | 87.36 |
@@ -131,7 +139,8 @@ The summary shows the average scores within each task. The average score is the
## Reranking
| Model | Avg. | esci<br>(ndcg@10) |
|:----------------------------------------------|:----------|:--------------------|
-| OpenAI/text-embedding-3-large | **93.58** | **93.58** |
+| sbintuitions/sarashina-embedding-v1-1b | **93.74** | **93.74** |
+| OpenAI/text-embedding-3-large | 93.58 | 93.58 |
| OpenAI/text-embedding-ada-002 | 93.04 | 93.04 |
| intfloat/multilingual-e5-small | 93.03 | 93.03 |
| pkshatech/GLuCoSE-base-ja-v2 | 93.01 | 93.01 |
@@ -141,6 +150,7 @@ The summary shows the average scores within each task. The average score is the
| OpenAI/text-embedding-3-small | 92.92 | 92.92 |
| cl-nagoya/ruri-base | 92.91 | 92.91 |
| intfloat/multilingual-e5-base | 92.85 | 92.85 |
+| jinaai/jina-embeddings-v3 | 92.71 | 92.71 |
| pkshatech/RoSEtta-base-ja | 92.69 | 92.69 |
| pkshatech/GLuCoSE-base-ja | 91.90 | 91.90 |
| cl-nagoya/sup-simcse-ja-base | 91.83 | 91.83 |
@@ -163,10 +173,12 @@ The summary shows the average scores within each task. The average score is the
| Model | Avg. | livedoor_news<br>(v_measure_score) | mewsc16<br>(v_measure_score) |
|:----------------------------------------------|:----------|:-------------------------------------|:-------------------------------|
| cl-nagoya/ruri-base | **54.16** | 54.27 | **54.04** |
+| sbintuitions/sarashina-embedding-v1-1b | 53.86 | 56.42 | 51.29 |
| OpenAI/text-embedding-3-large | 53.32 | 57.09 | 49.55 |
| pkshatech/RoSEtta-base-ja | 53.23 | **58.62** | 47.85 |
| cl-nagoya/ruri-large | 51.82 | 51.39 | 52.25 |
| cl-nagoya/sup-simcse-ja-base | 51.79 | 52.67 | 50.91 |
+| jinaai/jina-embeddings-v3 | 51.46 | 54.72 | 48.19 |
| intfloat/multilingual-e5-large | 51.24 | 57.13 | 45.34 |
| cl-nagoya/ruri-small | 51.19 | 50.96 | 51.41 |
| OpenAI/text-embedding-3-small | 51.06 | 54.57 | 47.55 |
@@ -204,6 +216,7 @@ The summary shows the average scores within each task. The average score is the
| cl-nagoya/ruri-base | 62.38 | 62.38 |
| oshizo/sbert-jsnli-luke-japanese-base-lite | 62.38 | 62.38 |
| MU-Kindai/Japanese-DiffCSE-BERT-base | 62.38 | 62.38 |
+| jinaai/jina-embeddings-v3 | 62.37 | 62.37 |
| pkshatech/GLuCoSE-base-ja-v2 | 62.37 | 62.37 |
| MU-Kindai/Japanese-SimCSE-BERT-base-sup | 62.37 | 62.37 |
| MU-Kindai/Japanese-SimCSE-BERT-large-sup | 62.35 | 62.35 |
@@ -219,5 +232,6 @@ The summary shows the average scores within each task. The average score is the
| intfloat/multilingual-e5-small | 62.19 | 62.19 |
| intfloat/multilingual-e5-large | 62.15 | 62.15 |
| cl-nagoya/ruri-small | 62.11 | 62.11 |
+| sbintuitions/sarashina-embedding-v1-1b | 62.00 | 62.00 |
| pkshatech/RoSEtta-base-ja | 61.74 | 61.74 |
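
For reference, the leaderboard figures can be reproduced from the summary.json files added above. This is a minimal sketch, assuming (as the table values suggest) that each task column is the mean of that task's dataset metrics and that the overall "Avg." is the mean over all dataset metrics rather than over the task means:

    import json

    # Derive the leaderboard figures from a summary.json added in this patch.
    with open("docs/results/sbintuitions/sarashina-embedding-v1-1b/summary.json") as f:
        summary = json.load(f)

    # Each dataset reports a single main metric (ndcg@10, spearman, macro_f1, ...).
    dataset_scores = [
        next(iter(metrics.values()))
        for datasets in summary.values()
        for metrics in datasets.values()
    ]
    task_scores = {
        task: sum(next(iter(m.values())) for m in datasets.values()) / len(datasets)
        for task, datasets in summary.items()
    }

    print(round(100 * sum(dataset_scores) / len(dataset_scores), 2))     # 75.5 ("Avg.")
    print({task: round(100 * s, 2) for task, s in task_scores.items()})  # per-task columns
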
diff --git a/pyproject.toml b/pyproject.toml
index 743b6bd..b5d2296 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,7 +12,7 @@ description = "The evaluation scripts for JMTEB (Japanese Massive Text Embedding
name = "JMTEB"
packages = [{from = "src", include = "jmteb"}]
readme = "README.md"
-version = "1.3.2"
+version = "1.3.3"
[tool.poetry.dependencies]
python = ">=3.10,<4.0"
diff --git a/src/jmteb/__main__.py b/src/jmteb/__main__.py
index 55e830a..ff10884 100644
--- a/src/jmteb/__main__.py
+++ b/src/jmteb/__main__.py
@@ -97,6 +97,18 @@ def main(
if len(args.evaluators) == 0:
raise ValueError("No evaluator is selected. Please check the config file or the command line arguments.")
+ # save config as yaml
+ if args.save_dir:
+ Path(args.save_dir).mkdir(parents=True, exist_ok=True)
+ parser.save(
+ args,
+ Path(args.save_dir) / "jmteb_config.yaml",
+ format="yaml",
+ overwrite=True,
+ multifile=False,
+ skip_check=True,
+ )
+
args = parser.instantiate_classes(args)
if isinstance(args.evaluators, str):
raise ValueError(
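
As a standalone illustration of the parser.save() call added above, jsonargparse dumps the fully resolved arguments to YAML so each run leaves a machine-readable record of its configuration next to summary.json. The parser below is a toy stand-in, not JMTEB's actual argument set:

    from pathlib import Path
    from jsonargparse import ArgumentParser

    # Toy parser standing in for JMTEB's CLI; only the save() pattern matters here.
    parser = ArgumentParser()
    parser.add_argument("--save_dir", type=str)
    parser.add_argument("--eval_include", type=str, default=None)

    args = parser.parse_args(["--save_dir", "results/demo"])
    Path(args.save_dir).mkdir(parents=True, exist_ok=True)
    parser.save(
        args,
        Path(args.save_dir) / "jmteb_config.yaml",
        format="yaml",
        overwrite=True,
        multifile=False,
        skip_check=True,
    )
    print((Path(args.save_dir) / "jmteb_config.yaml").read_text())
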
diff --git a/src/jmteb/embedders/base.py b/src/jmteb/embedders/base.py
index afefec1..ea078f1 100644
--- a/src/jmteb/embedders/base.py
+++ b/src/jmteb/embedders/base.py
@@ -19,12 +19,13 @@ class TextEmbedder(ABC):
convert_to_numpy: bool
_chunk_size: int = 262144 # 2^18
- def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray | torch.Tensor:
+ def encode(self, text: str | list[str], prefix: str | None = None, **kwargs) -> np.ndarray | torch.Tensor:
"""Convert a text string or a list of texts to embedding.
Args:
text (str | list[str]): text string, or a list of texts.
prefix (str, optional): the prefix to use for encoding. Default to None.
+ **kwargs: additional keyword arguments that specific models may require for encoding.
"""
raise NotImplementedError
@@ -43,6 +44,7 @@ def _batch_encode_and_save_on_disk(
prefix: str | None = None,
batch_size: int = 262144,
dtype: str = "float32",
+ **kwargs,
) -> np.memmap | torch.Tensor:
"""
Encode a list of texts and save the embeddings on disk using memmap.
@@ -65,7 +67,7 @@ def _batch_encode_and_save_on_disk(
with tqdm.tqdm(total=num_samples, desc="Encoding") as pbar:
for i in range(0, num_samples, batch_size):
batch = text_list[i : i + batch_size]
- batch_embeddings: np.ndarray | torch.Tensor = self.encode(batch, prefix=prefix)
+ batch_embeddings: np.ndarray | torch.Tensor = self.encode(batch, prefix=prefix, **kwargs)
embeddings[i : i + batch_size] = batch_embeddings
pbar.update(len(batch))
@@ -83,6 +85,7 @@ def batch_encode_with_cache(
cache_path: str | PathLike[str] | None = None,
overwrite_cache: bool = False,
dtype: str = "float32",
+ **kwargs,
) -> np.ndarray | torch.Tensor:
"""
Encode a list of texts and save the embeddings on disk using memmap if cache_path is provided.
@@ -95,9 +98,10 @@ def batch_encode_with_cache(
dtype (str, optional): data type. Defaults to "float32".
"""
+ logger.debug(f"{kwargs=}")
if cache_path is None:
logger.info("Encoding embeddings")
- return self.encode(text_list, prefix=prefix)
+ return self.encode(text_list, prefix=prefix, **kwargs)
if Path(cache_path).exists() and not overwrite_cache:
logger.info(f"Loading embeddings from {cache_path}")
@@ -105,7 +109,7 @@ def batch_encode_with_cache(
logger.info(f"Encoding and saving embeddings to {cache_path}")
embeddings = self._batch_encode_and_save_on_disk(
- text_list, cache_path, prefix=prefix, batch_size=self._chunk_size, dtype=dtype
+ text_list, cache_path, prefix=prefix, batch_size=self._chunk_size, dtype=dtype, **kwargs
)
return embeddings
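
A self-contained sketch of the new **kwargs plumbing (not part of the patch): whatever is passed to batch_encode_with_cache() now reaches the concrete embedder's encode() unchanged, on both the direct path and the memmap-cache path. The "task" keyword below is a hypothetical model-specific option, not something TextEmbedder itself defines:

    import numpy as np

    from jmteb.embedders.base import TextEmbedder

    class ToyEmbedder(TextEmbedder):
        convert_to_numpy = True
        convert_to_tensor = False

        def encode(self, text, prefix=None, **kwargs):
            # kwargs given to batch_encode_with_cache() arrive here untouched.
            task = kwargs.get("task")  # hypothetical model-specific option
            texts = [text] if isinstance(text, str) else text
            return np.full((len(texts), 4), 1.0 if task == "retrieval.query" else 0.0, dtype="float32")

        def get_output_dim(self) -> int:
            return 4

    vecs = ToyEmbedder().batch_encode_with_cache(["クエリ1", "クエリ2"], task="retrieval.query")
    print(vecs.shape)  # (2, 4)
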
diff --git a/src/jmteb/embedders/data_parallel_sbert_embedder.py b/src/jmteb/embedders/data_parallel_sbert_embedder.py
index 6fb7e87..7416fe4 100644
--- a/src/jmteb/embedders/data_parallel_sbert_embedder.py
+++ b/src/jmteb/embedders/data_parallel_sbert_embedder.py
@@ -8,6 +8,7 @@
from accelerate.utils import find_executable_batch_size
from loguru import logger
from sentence_transformers import SentenceTransformer
+from sentence_transformers.models import Pooling
from sentence_transformers.quantization import quantize_embeddings
from sentence_transformers.util import truncate_embeddings
from torch import Tensor
@@ -85,7 +86,10 @@ def encode(
# Some models (e.g. INSTRUCTOR, GRIT) require removing the prompt before pooling
# Tracking the prompt length allow us to remove the prompt during pooling
tokenized_prompt = self.sbert.tokenize([prompt])
- if "input_ids" in tokenized_prompt:
+
+ # When the Pooling module is configured with include_prompt=True, prompt_length is unnecessary
+ # and must not be set; leaving it out avoids issues under DataParallel.
+ if "input_ids" in tokenized_prompt and not self.include_prompt_for_pooling():
extra_features["prompt_length"] = tokenized_prompt["input_ids"].shape[-1] - 1
all_embeddings = []
@@ -97,6 +101,13 @@ def encode(
features = self.sbert.tokenize(sentences_batch)
features.update(extra_features)
+ # `.gather()` in `.forward()` cannot handle plain int values, so cast prompt_length
+ # from an int to a torch.Tensor (one value per sentence in the batch) that can be gathered.
+ if "prompt_length" in features and isinstance(features["prompt_length"], int):
+ batch_size = len(sentences_batch)
+ prompt_length = torch.Tensor([features["prompt_length"] for _ in range(batch_size)])
+ features["prompt_length"] = prompt_length
+
with torch.no_grad():
out_features = self.forward(features)
@@ -156,6 +167,14 @@ def encode(
return all_embeddings
+ # Implemented with reference to the `include_prompt` check in Sentence Transformers
+ # ref: https://github.com/UKPLab/sentence-transformers/blob/679ab5d38e4cf9cd73d4dcf1cda25ba2ef1ad837/sentence_transformers/trainer.py#L931 # noqa: E501
+ def include_prompt_for_pooling(self) -> bool:
+ for module in self:
+ if isinstance(module, Pooling):
+ return module.include_prompt
+ return True
+
class DataParallelSentenceBertEmbedder(TextEmbedder):
"""SentenceBERT embedder with pytorch data parallel"""
diff --git a/src/jmteb/embedders/sbert_embedder.py b/src/jmteb/embedders/sbert_embedder.py
index 0188e7d..ba33a36 100644
--- a/src/jmteb/embedders/sbert_embedder.py
+++ b/src/jmteb/embedders/sbert_embedder.py
@@ -43,7 +43,7 @@ def __init__(
else:
self.set_output_numpy()
- def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray:
+ def encode(self, text: str | list[str], prefix: str | None = None, **kwargs) -> np.ndarray:
if self.add_eos:
text = self._add_eos_func(text)
return self.model.encode(
@@ -54,6 +54,7 @@ def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray
batch_size=self.batch_size,
device=self.device,
normalize_embeddings=self.normalize_embeddings,
+ **kwargs,
)
def _add_eos_func(self, text: str | list[str]) -> str | list[str]:
diff --git a/src/jmteb/evaluators/classification/evaluator.py b/src/jmteb/evaluators/classification/evaluator.py
index 457d949..c2b8836 100644
--- a/src/jmteb/evaluators/classification/evaluator.py
+++ b/src/jmteb/evaluators/classification/evaluator.py
@@ -29,6 +29,7 @@ class ClassificationEvaluator(EmbeddingEvaluator):
classifiers (dict[str, Classifier]): classifiers to be evaluated.
prefix (str | None): prefix for sentences. Defaults to None.
log_predictions (bool): whether to log predictions of each datapoint.
+ encode_kwargs (dict): kwargs passed to embedder's encode function. Defaults to {}.
"""
def __init__(
@@ -40,6 +41,7 @@ def __init__(
classifiers: dict[str, Classifier] | None = None,
prefix: str | None = None,
log_predictions: bool = False,
+ encode_kwargs: dict = {},
) -> None:
self.train_dataset = train_dataset
self.val_dataset = val_dataset
@@ -55,6 +57,7 @@ def __init__(
] or ["macro"]
self.prefix = prefix
self.log_predictions = log_predictions
+ self.encode_kwargs = encode_kwargs
self.main_metric = f"{self.average[0]}_f1"
def __call__(
@@ -69,6 +72,7 @@ def __call__(
prefix=self.prefix,
cache_path=Path(cache_dir) / "train_embeddings.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
+ **self.encode_kwargs,
)
y_train = [item.label for item in self.train_dataset]
@@ -77,6 +81,7 @@ def __call__(
prefix=self.prefix,
cache_path=Path(cache_dir) / "val_embeddings.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
+ **self.encode_kwargs,
)
y_val = [item.label for item in self.val_dataset]
@@ -90,6 +95,7 @@ def __call__(
prefix=self.prefix,
cache_path=Path(cache_dir) / "test_embeddings.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
+ **self.encode_kwargs,
)
y_test = [item.label for item in self.test_dataset]
diff --git a/src/jmteb/evaluators/clustering/evaluator.py b/src/jmteb/evaluators/clustering/evaluator.py
index 4f3cd3c..2b8cdf2 100644
--- a/src/jmteb/evaluators/clustering/evaluator.py
+++ b/src/jmteb/evaluators/clustering/evaluator.py
@@ -24,6 +24,14 @@
class ClusteringEvaluator(EmbeddingEvaluator):
"""
ClusteringEvaluator is a class for evaluating clustering models.
+
+ Args:
+ val_dataset (ClusteringDataset): validation dataset
+ test_dataset (ClusteringDataset): evaluation dataset
+ prefix (str | None): prefix for sentences. Defaults to None.
+ random_seed (int | None): random seed used in clustering models. Defaults to None.
+ log_predictions (bool): whether to log predictions of each datapoint.
+ encode_kwargs (dict): kwargs passed to embedder's encode function. Defaults to {}.
"""
def __init__(
@@ -33,12 +41,14 @@ def __init__(
prefix: str | None = None,
random_seed: int | None = None,
log_predictions: bool = False,
+ encode_kwargs: dict = {},
) -> None:
self.val_dataset = val_dataset
self.test_dataset = test_dataset
self.prefix = prefix
self.random_seed = random_seed
self.log_predictions = log_predictions
+ self.encode_kwargs = encode_kwargs
self.main_metric = "v_measure_score"
def __call__(
@@ -53,6 +63,7 @@ def __call__(
prefix=self.prefix,
cache_path=Path(cache_dir) / "val_embeddings.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
+ **self.encode_kwargs,
)
val_labels = [item.label for item in self.val_dataset]
@@ -66,6 +77,7 @@ def __call__(
prefix=self.prefix,
cache_path=Path(cache_dir) / "test_embeddings.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
+ **self.encode_kwargs,
)
test_labels = [item.label for item in self.test_dataset]
diff --git a/src/jmteb/evaluators/pair_classification/evaluator.py b/src/jmteb/evaluators/pair_classification/evaluator.py
index 280bbfb..ef466bf 100644
--- a/src/jmteb/evaluators/pair_classification/evaluator.py
+++ b/src/jmteb/evaluators/pair_classification/evaluator.py
@@ -22,6 +22,7 @@ class PairClassificationEvaluator(EmbeddingEvaluator):
test_dataset (PairClassificationDataset): test dataset
sentence1_prefix (str | None): prefix for sentence1. Defaults to None.
sentence2_prefix (str | None): prefix for sentence2. Defaults to None.
+ encode_kwargs (dict): kwargs passed to embedder's encode function. Defaults to {}.
# NOTE: Don't log predictions, as predictions by different metrics could be different.
"""
@@ -32,11 +33,13 @@ def __init__(
test_dataset: PairClassificationDataset,
sentence1_prefix: str | None = None,
sentence2_prefix: str | None = None,
+ encode_kwargs: dict = {},
) -> None:
self.test_dataset = test_dataset
self.val_dataset = val_dataset
self.sentence1_prefix = sentence1_prefix
self.sentence2_prefix = sentence2_prefix
+ self.encode_kwargs = encode_kwargs
self.metrics = [ThresholdAccuracyMetric(), ThresholdF1Metric()]
self.main_metric = "binary_f1"
@@ -122,12 +125,14 @@ def _convert_to_embeddings(
prefix=self.sentence1_prefix,
cache_path=Path(cache_dir) / f"{split}_embeddings1.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
+ **self.encode_kwargs,
)
embeddings2 = model.batch_encode_with_cache(
[item.sentence2 for item in dataset],
prefix=self.sentence2_prefix,
cache_path=Path(cache_dir) / f"{split}_embeddings2.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
+ **self.encode_kwargs,
)
golden_labels = [item.label for item in dataset]
return embeddings1, embeddings2, golden_labels
diff --git a/src/jmteb/evaluators/reranking/evaluator.py b/src/jmteb/evaluators/reranking/evaluator.py
index 5c4ba34..144ed36 100644
--- a/src/jmteb/evaluators/reranking/evaluator.py
+++ b/src/jmteb/evaluators/reranking/evaluator.py
@@ -39,6 +39,8 @@ class RerankingEvaluator(EmbeddingEvaluator):
doc_prefix (str | None): prefix for documents. Defaults to None.
log_predictions (bool): whether to log predictions of each datapoint. Defaults to False.
top_n_docs_to_log (int): log only top n documents. Defaults to 5.
+ query_encode_kwargs (dict): kwargs passed to embedder's encode function when encoding queries. Defaults to {}.
+ doc_encode_kwargs (dict): kwargs passed to embedder's encode function when encoding documents. Defaults to {}.
"""
def __init__(
@@ -51,6 +53,8 @@ def __init__(
doc_prefix: str | None = None,
log_predictions: bool = False,
top_n_docs_to_log: int = 5,
+ query_encode_kwargs: dict = {},
+ doc_encode_kwargs: dict = {},
) -> None:
self.test_query_dataset = test_query_dataset
self.val_query_dataset = val_query_dataset
@@ -61,6 +65,8 @@ def __init__(
self.doc_prefix = doc_prefix
self.log_predictions = log_predictions
self.top_n_docs_to_log = top_n_docs_to_log
+ self.query_encode_kwargs = query_encode_kwargs
+ self.doc_encode_kwargs = doc_encode_kwargs
def __call__(
self,
@@ -77,6 +83,7 @@ def __call__(
prefix=self.query_prefix,
cache_path=Path(cache_dir) / "val_query.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
+ **self.query_encode_kwargs,
)
if self.val_query_dataset == self.test_query_dataset:
test_query_embeddings = val_query_embeddings
@@ -86,12 +93,14 @@ def __call__(
prefix=self.query_prefix,
cache_path=Path(cache_dir) / "test_query.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
+ **self.query_encode_kwargs,
)
doc_embeddings = model.batch_encode_with_cache(
text_list=[item.text for item in self.doc_dataset],
prefix=self.doc_prefix,
cache_path=Path(cache_dir) / "corpus.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
+ **self.doc_encode_kwargs,
)
logger.info("Start reranking")
diff --git a/src/jmteb/evaluators/retrieval/evaluator.py b/src/jmteb/evaluators/retrieval/evaluator.py
index 73c0981..2fd6a21 100644
--- a/src/jmteb/evaluators/retrieval/evaluator.py
+++ b/src/jmteb/evaluators/retrieval/evaluator.py
@@ -42,6 +42,8 @@ class RetrievalEvaluator(EmbeddingEvaluator):
doc_prefix (str | None): prefix for documents. Defaults to None.
log_predictions (bool): whether to log predictions of each datapoint. Defaults to False.
top_n_docs_to_log (int): log only top n documents that are predicted as relevant. Defaults to 5.
+ query_encode_kwargs (dict): kwargs passed to embedder's encode function when encoding queries. Defaults to {}.
+ doc_encode_kwargs (dict): kwargs passed to embedder's encode function when encoding documents. Defaults to {}.
"""
def __init__(
@@ -56,6 +58,8 @@ def __init__(
doc_prefix: str | None = None,
log_predictions: bool = False,
top_n_docs_to_log: int = 5,
+ query_encode_kwargs: dict = {},
+ doc_encode_kwargs: dict = {},
) -> None:
self.val_query_dataset = val_query_dataset
self.test_query_dataset = test_query_dataset
@@ -72,6 +76,8 @@ def __init__(
self.doc_prefix = doc_prefix
self.log_predictions = log_predictions
self.top_n_docs_to_log = top_n_docs_to_log
+ self.query_encode_kwargs = query_encode_kwargs
+ self.doc_encode_kwargs = doc_encode_kwargs
def __call__(
self,
@@ -88,6 +94,7 @@ def __call__(
prefix=self.query_prefix,
cache_path=Path(cache_dir) / "val_query.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
+ **self.query_encode_kwargs,
)
if self.val_query_dataset == self.test_query_dataset:
test_query_embeddings = val_query_embeddings
@@ -97,6 +104,7 @@ def __call__(
prefix=self.query_prefix,
cache_path=Path(cache_dir) / "test_query.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
+ **self.query_encode_kwargs,
)
doc_embeddings = model.batch_encode_with_cache(
@@ -104,6 +112,7 @@ def __call__(
prefix=self.doc_prefix,
cache_path=Path(cache_dir) / "corpus.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
+ **self.doc_encode_kwargs,
)
logger.info("Start retrieval")
diff --git a/src/jmteb/evaluators/sts/evaluator.py b/src/jmteb/evaluators/sts/evaluator.py
index b7b8eb8..380ceea 100644
--- a/src/jmteb/evaluators/sts/evaluator.py
+++ b/src/jmteb/evaluators/sts/evaluator.py
@@ -26,6 +26,7 @@ class STSEvaluator(EmbeddingEvaluator):
test_dataset (STSDataset): test dataset
sentence1_prefix (str | None): prefix for sentence1. Defaults to None.
sentence2_prefix (str | None): prefix for sentence2. Defaults to None.
+ encode_kwargs (dict): kwargs passed to embedder's encode function. Defaults to {}.
"""
def __init__(
@@ -35,6 +36,7 @@ def __init__(
sentence1_prefix: str | None = None,
sentence2_prefix: str | None = None,
log_predictions: bool = False,
+ encode_kwargs: dict = {},
) -> None:
self.val_dataset = val_dataset
self.test_dataset = test_dataset
@@ -42,6 +44,7 @@ def __init__(
self.sentence2_prefix = sentence2_prefix
self.main_metric = "spearman"
self.log_predictions = log_predictions
+ self.encode_kwargs = encode_kwargs
def __call__(
self, model: TextEmbedder, cache_dir: str | PathLike[str] | None = None, overwrite_cache: bool = False
@@ -149,12 +152,14 @@ def _convert_to_embeddings(
prefix=self.sentence1_prefix,
cache_path=Path(cache_dir) / f"{split}_embeddings1.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
+ **self.encode_kwargs,
)
embeddings2 = model.batch_encode_with_cache(
[item.sentence2 for item in dataset],
prefix=self.sentence2_prefix,
cache_path=Path(cache_dir) / f"{split}_embeddings2.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
+ **self.encode_kwargs,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
embeddings1 = convert_to_tensor(embeddings1, device)
diff --git a/tests/embedders/test_dp_sbert.py b/tests/embedders/test_dp_sbert.py
index 028e240..e1cd4d2 100644
--- a/tests/embedders/test_dp_sbert.py
+++ b/tests/embedders/test_dp_sbert.py
@@ -18,6 +18,11 @@ def test_encode(self):
assert isinstance(embeddings, np.ndarray)
assert embeddings.shape == (OUTPUT_DIM,)
+ def test_encode_with_prompt(self):
+ embeddings = self.model.encode("任意のテキスト", prefix="プロンプト")
+ assert isinstance(embeddings, np.ndarray)
+ assert embeddings.shape == (OUTPUT_DIM,)
+
def test_get_output_dim(self):
assert self.model.get_output_dim() == OUTPUT_DIM
diff --git a/tests/test_main.py b/tests/test_main.py
index ee81fb5..40de1a3 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -31,5 +31,5 @@ def test_main_cli():
# fmt: on
result = subprocess.run(command)
assert result.returncode == 0
-
+ assert (Path(f) / "jmteb_config.yaml").exists()
assert (Path(f) / "summary.json").exists()