Skip to content

Commit

Permalink
Merge pull request #80 from sbintuitions/feature/encode_args
Browse files Browse the repository at this point in the history
[Feature] Allow setting encode kwargs in SentenceBert embedder
  • Loading branch information
lsz05 authored Nov 27, 2024
2 parents b4d0df2 + 94010c7 commit 0ded6ad
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 5 deletions.
12 changes: 8 additions & 4 deletions src/jmteb/embedders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ class TextEmbedder(ABC):
convert_to_numpy: bool
_chunk_size: int = 262144 # 2^18

def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray | torch.Tensor:
def encode(self, text: str | list[str], prefix: str | None = None, **kwargs) -> np.ndarray | torch.Tensor:
"""Convert a text string or a list of texts to embedding.
Args:
text (str | list[str]): text string, or a list of texts.
prefix (str, optional): the prefix to use for encoding. Defaults to None.
**kwargs: additional keyword arguments required by specific models.
"""
raise NotImplementedError

Expand All @@ -43,6 +44,7 @@ def _batch_encode_and_save_on_disk(
prefix: str | None = None,
batch_size: int = 262144,
dtype: str = "float32",
**kwargs,
) -> np.memmap | torch.Tensor:
"""
Encode a list of texts and save the embeddings on disk using memmap.
Expand All @@ -65,7 +67,7 @@ def _batch_encode_and_save_on_disk(
with tqdm.tqdm(total=num_samples, desc="Encoding") as pbar:
for i in range(0, num_samples, batch_size):
batch = text_list[i : i + batch_size]
batch_embeddings: np.ndarray | torch.Tensor = self.encode(batch, prefix=prefix)
batch_embeddings: np.ndarray | torch.Tensor = self.encode(batch, prefix=prefix, **kwargs)
embeddings[i : i + batch_size] = batch_embeddings
pbar.update(len(batch))

Expand All @@ -83,6 +85,7 @@ def batch_encode_with_cache(
cache_path: str | PathLike[str] | None = None,
overwrite_cache: bool = False,
dtype: str = "float32",
**kwargs,
) -> np.ndarray | torch.Tensor:
"""
Encode a list of texts and save the embeddings on disk using memmap if cache_path is provided.
Expand All @@ -95,17 +98,18 @@ def batch_encode_with_cache(
dtype (str, optional): data type. Defaults to "float32".
"""

logger.warning(f"{kwargs=}")
if cache_path is None:
logger.info("Encoding embeddings")
return self.encode(text_list, prefix=prefix)
return self.encode(text_list, prefix=prefix, **kwargs)

if Path(cache_path).exists() and not overwrite_cache:
logger.info(f"Loading embeddings from {cache_path}")
return np.memmap(cache_path, dtype=dtype, mode="r", shape=(len(text_list), self.get_output_dim()))

logger.info(f"Encoding and saving embeddings to {cache_path}")
embeddings = self._batch_encode_and_save_on_disk(
text_list, cache_path, prefix=prefix, batch_size=self._chunk_size, dtype=dtype
text_list, cache_path, prefix=prefix, batch_size=self._chunk_size, dtype=dtype, **kwargs
)
return embeddings

Expand Down
3 changes: 2 additions & 1 deletion src/jmteb/embedders/sbert_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(
else:
self.set_output_numpy()

def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray:
def encode(self, text: str | list[str], prefix: str | None = None, **kwargs) -> np.ndarray:
if self.add_eos:
text = self._add_eos_func(text)
return self.model.encode(
Expand All @@ -54,6 +54,7 @@ def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray
batch_size=self.batch_size,
device=self.device,
normalize_embeddings=self.normalize_embeddings,
**kwargs,
)

def _add_eos_func(self, text: str | list[str]) -> str | list[str]:
Expand Down
6 changes: 6 additions & 0 deletions src/jmteb/evaluators/classification/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class ClassificationEvaluator(EmbeddingEvaluator):
classifiers (dict[str, Classifier]): classifiers to be evaluated.
prefix (str | None): prefix for sentences. Defaults to None.
log_predictions (bool): whether to log predictions of each datapoint.
encode_kwargs (dict): kwargs passed to embedder's encode function. Defaults to {}.
"""

def __init__(
Expand All @@ -40,6 +41,7 @@ def __init__(
classifiers: dict[str, Classifier] | None = None,
prefix: str | None = None,
log_predictions: bool = False,
encode_kwargs: dict = {},
) -> None:
self.train_dataset = train_dataset
self.val_dataset = val_dataset
Expand All @@ -55,6 +57,7 @@ def __init__(
] or ["macro"]
self.prefix = prefix
self.log_predictions = log_predictions
self.encode_kwargs = encode_kwargs
self.main_metric = f"{self.average[0]}_f1"

def __call__(
Expand All @@ -69,6 +72,7 @@ def __call__(
prefix=self.prefix,
cache_path=Path(cache_dir) / "train_embeddings.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
**self.encode_kwargs,
)
y_train = [item.label for item in self.train_dataset]

Expand All @@ -77,6 +81,7 @@ def __call__(
prefix=self.prefix,
cache_path=Path(cache_dir) / "val_embeddings.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
**self.encode_kwargs,
)
y_val = [item.label for item in self.val_dataset]

Expand All @@ -90,6 +95,7 @@ def __call__(
prefix=self.prefix,
cache_path=Path(cache_dir) / "test_embeddings.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
**self.encode_kwargs,
)
y_test = [item.label for item in self.test_dataset]

Expand Down
12 changes: 12 additions & 0 deletions src/jmteb/evaluators/clustering/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@
class ClusteringEvaluator(EmbeddingEvaluator):
"""
ClusteringEvaluator is a class for evaluating clustering models.
Args:
val_dataset (ClusteringDataset): validation dataset
test_dataset (ClusteringDataset): evaluation dataset
prefix (str | None): prefix for sentences. Defaults to None.
random_seed (int | None): random seed used in clustering models. Defaults to None.
log_predictions (bool): whether to log predictions of each datapoint.
encode_kwargs (dict): kwargs passed to embedder's encode function. Defaults to {}.
"""

def __init__(
Expand All @@ -33,12 +41,14 @@ def __init__(
prefix: str | None = None,
random_seed: int | None = None,
log_predictions: bool = False,
encode_kwargs: dict = {},
) -> None:
self.val_dataset = val_dataset
self.test_dataset = test_dataset
self.prefix = prefix
self.random_seed = random_seed
self.log_predictions = log_predictions
self.encode_kwargs = encode_kwargs
self.main_metric = "v_measure_score"

def __call__(
Expand All @@ -53,6 +63,7 @@ def __call__(
prefix=self.prefix,
cache_path=Path(cache_dir) / "val_embeddings.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
**self.encode_kwargs,
)
val_labels = [item.label for item in self.val_dataset]

Expand All @@ -66,6 +77,7 @@ def __call__(
prefix=self.prefix,
cache_path=Path(cache_dir) / "test_embeddings.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
**self.encode_kwargs,
)
test_labels = [item.label for item in self.test_dataset]

Expand Down
5 changes: 5 additions & 0 deletions src/jmteb/evaluators/pair_classification/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class PairClassificationEvaluator(EmbeddingEvaluator):
test_dataset (PairClassificationDataset): test dataset
sentence1_prefix (str | None): prefix for sentence1. Defaults to None.
sentence2_prefix (str | None): prefix for sentence2. Defaults to None.
encode_kwargs (dict): kwargs passed to embedder's encode function. Defaults to {}.
# NOTE: Don't log predictions, as predictions by different metrics could be different.
"""
Expand All @@ -32,11 +33,13 @@ def __init__(
test_dataset: PairClassificationDataset,
sentence1_prefix: str | None = None,
sentence2_prefix: str | None = None,
encode_kwargs: dict = {},
) -> None:
self.test_dataset = test_dataset
self.val_dataset = val_dataset
self.sentence1_prefix = sentence1_prefix
self.sentence2_prefix = sentence2_prefix
self.encode_kwargs = encode_kwargs
self.metrics = [ThresholdAccuracyMetric(), ThresholdF1Metric()]
self.main_metric = "binary_f1"

Expand Down Expand Up @@ -122,12 +125,14 @@ def _convert_to_embeddings(
prefix=self.sentence1_prefix,
cache_path=Path(cache_dir) / f"{split}_embeddings1.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
**self.encode_kwargs,
)
embeddings2 = model.batch_encode_with_cache(
[item.sentence2 for item in dataset],
prefix=self.sentence2_prefix,
cache_path=Path(cache_dir) / f"{split}_embeddings2.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
**self.encode_kwargs,
)
golden_labels = [item.label for item in dataset]
return embeddings1, embeddings2, golden_labels
9 changes: 9 additions & 0 deletions src/jmteb/evaluators/reranking/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class RerankingEvaluator(EmbeddingEvaluator):
doc_prefix (str | None): prefix for documents. Defaults to None.
log_predictions (bool): whether to log predictions of each datapoint. Defaults to False.
top_n_docs_to_log (int): log only top n documents. Defaults to 5.
query_encode_kwargs (dict): kwargs passed to embedder's encode function when encoding queries. Defaults to {}.
doc_encode_kwargs (dict): kwargs passed to embedder's encode function when encoding documents. Defaults to {}.
"""

def __init__(
Expand All @@ -51,6 +53,8 @@ def __init__(
doc_prefix: str | None = None,
log_predictions: bool = False,
top_n_docs_to_log: int = 5,
query_encode_kwargs: dict = {},
doc_encode_kwargs: dict = {},
) -> None:
self.test_query_dataset = test_query_dataset
self.val_query_dataset = val_query_dataset
Expand All @@ -61,6 +65,8 @@ def __init__(
self.doc_prefix = doc_prefix
self.log_predictions = log_predictions
self.top_n_docs_to_log = top_n_docs_to_log
self.query_encode_kwargs = query_encode_kwargs
self.doc_encode_kwargs = doc_encode_kwargs

def __call__(
self,
Expand All @@ -77,6 +83,7 @@ def __call__(
prefix=self.query_prefix,
cache_path=Path(cache_dir) / "val_query.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
**self.query_encode_kwargs,
)
if self.val_query_dataset == self.test_query_dataset:
test_query_embeddings = val_query_embeddings
Expand All @@ -86,12 +93,14 @@ def __call__(
prefix=self.query_prefix,
cache_path=Path(cache_dir) / "test_query.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
**self.query_encode_kwargs,
)
doc_embeddings = model.batch_encode_with_cache(
text_list=[item.text for item in self.doc_dataset],
prefix=self.doc_prefix,
cache_path=Path(cache_dir) / "corpus.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
**self.doc_encode_kwargs,
)

logger.info("Start reranking")
Expand Down
9 changes: 9 additions & 0 deletions src/jmteb/evaluators/retrieval/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class RetrievalEvaluator(EmbeddingEvaluator):
doc_prefix (str | None): prefix for documents. Defaults to None.
log_predictions (bool): whether to log predictions of each datapoint. Defaults to False.
top_n_docs_to_log (int): log only top n documents that are predicted as relevant. Defaults to 5.
query_encode_kwargs (dict): kwargs passed to embedder's encode function when encoding queries. Defaults to {}.
doc_encode_kwargs (dict): kwargs passed to embedder's encode function when encoding documents. Defaults to {}.
"""

def __init__(
Expand All @@ -56,6 +58,8 @@ def __init__(
doc_prefix: str | None = None,
log_predictions: bool = False,
top_n_docs_to_log: int = 5,
query_encode_kwargs: dict = {},
doc_encode_kwargs: dict = {},
) -> None:
self.val_query_dataset = val_query_dataset
self.test_query_dataset = test_query_dataset
Expand All @@ -72,6 +76,8 @@ def __init__(
self.doc_prefix = doc_prefix
self.log_predictions = log_predictions
self.top_n_docs_to_log = top_n_docs_to_log
self.query_encode_kwargs = query_encode_kwargs
self.doc_encode_kwargs = doc_encode_kwargs

def __call__(
self,
Expand All @@ -88,6 +94,7 @@ def __call__(
prefix=self.query_prefix,
cache_path=Path(cache_dir) / "val_query.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
**self.query_encode_kwargs,
)
if self.val_query_dataset == self.test_query_dataset:
test_query_embeddings = val_query_embeddings
Expand All @@ -97,13 +104,15 @@ def __call__(
prefix=self.query_prefix,
cache_path=Path(cache_dir) / "test_query.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
**self.query_encode_kwargs,
)

doc_embeddings = model.batch_encode_with_cache(
text_list=[item.text for item in self.doc_dataset],
prefix=self.doc_prefix,
cache_path=Path(cache_dir) / "corpus.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
**self.doc_encode_kwargs,
)

logger.info("Start retrieval")
Expand Down
5 changes: 5 additions & 0 deletions src/jmteb/evaluators/sts/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class STSEvaluator(EmbeddingEvaluator):
test_dataset (STSDataset): test dataset
sentence1_prefix (str | None): prefix for sentence1. Defaults to None.
sentence2_prefix (str | None): prefix for sentence2. Defaults to None.
encode_kwargs (dict): kwargs passed to embedder's encode function. Defaults to {}.
"""

def __init__(
Expand All @@ -35,13 +36,15 @@ def __init__(
sentence1_prefix: str | None = None,
sentence2_prefix: str | None = None,
log_predictions: bool = False,
encode_kwargs: dict = {},
) -> None:
self.val_dataset = val_dataset
self.test_dataset = test_dataset
self.sentence1_prefix = sentence1_prefix
self.sentence2_prefix = sentence2_prefix
self.main_metric = "spearman"
self.log_predictions = log_predictions
self.encode_kwargs = encode_kwargs

def __call__(
self, model: TextEmbedder, cache_dir: str | PathLike[str] | None = None, overwrite_cache: bool = False
Expand Down Expand Up @@ -149,12 +152,14 @@ def _convert_to_embeddings(
prefix=self.sentence1_prefix,
cache_path=Path(cache_dir) / f"{split}_embeddings1.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
**self.encode_kwargs,
)
embeddings2 = model.batch_encode_with_cache(
[item.sentence2 for item in dataset],
prefix=self.sentence2_prefix,
cache_path=Path(cache_dir) / f"{split}_embeddings2.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
**self.encode_kwargs,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
embeddings1 = convert_to_tensor(embeddings1, device)
Expand Down

0 comments on commit 0ded6ad

Please sign in to comment.