Modify base embedder
lsz05 committed Sep 27, 2024
1 parent e68a27f commit 82d37a3
Showing 1 changed file with 7 additions and 4 deletions.
src/jmteb/embedders/base.py
@@ -19,7 +19,7 @@ class TextEmbedder(ABC):
     convert_to_numpy: bool
     _chunk_size: int = 262144  # 2^18
 
-    def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray | torch.Tensor:
+    def encode(self, text: str | list[str], prefix: str | None = None, **kwargs) -> np.ndarray | torch.Tensor:
         """Convert a text string or a list of texts to embedding.
 
         Args:
@@ -43,6 +43,7 @@ def _batch_encode_and_save_on_disk(
         prefix: str | None = None,
         batch_size: int = 262144,
         dtype: str = "float32",
+        **kwargs,
     ) -> np.memmap | torch.Tensor:
         """
         Encode a list of texts and save the embeddings on disk using memmap.
@@ -65,7 +66,7 @@ def _batch_encode_and_save_on_disk(
         with tqdm.tqdm(total=num_samples, desc="Encoding") as pbar:
             for i in range(0, num_samples, batch_size):
                 batch = text_list[i : i + batch_size]
-                batch_embeddings: np.ndarray | torch.Tensor = self.encode(batch, prefix=prefix)
+                batch_embeddings: np.ndarray | torch.Tensor = self.encode(batch, prefix=prefix, **kwargs)
                 embeddings[i : i + batch_size] = batch_embeddings
                 pbar.update(len(batch))
@@ -83,6 +84,7 @@ def batch_encode_with_cache(
         cache_path: str | PathLike[str] | None = None,
         overwrite_cache: bool = False,
         dtype: str = "float32",
+        **kwargs,
     ) -> np.ndarray | torch.Tensor:
         """
         Encode a list of texts and save the embeddings on disk using memmap if cache_path is provided.
@@ -95,17 +97,18 @@ def batch_encode_with_cache(
             dtype (str, optional): data type. Defaults to "float32".
         """
 
+        logger.warning(f"{kwargs=}")
         if cache_path is None:
             logger.info("Encoding embeddings")
-            return self.encode(text_list, prefix=prefix)
+            return self.encode(text_list, prefix=prefix, **kwargs)
 
         if Path(cache_path).exists() and not overwrite_cache:
             logger.info(f"Loading embeddings from {cache_path}")
             return np.memmap(cache_path, dtype=dtype, mode="r", shape=(len(text_list), self.get_output_dim()))
 
         logger.info(f"Encoding and saving embeddings to {cache_path}")
         embeddings = self._batch_encode_and_save_on_disk(
-            text_list, cache_path, prefix=prefix, batch_size=self._chunk_size, dtype=dtype
+            text_list, cache_path, prefix=prefix, batch_size=self._chunk_size, dtype=dtype, **kwargs
         )
         return embeddings
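For context, a minimal sketch of what this change enables, assuming a hypothetical TextEmbedder subclass: extra keyword arguments passed to batch_encode_with_cache are now forwarded unchanged through _batch_encode_and_save_on_disk to the subclass's encode. The ToyEmbedder class and its normalize flag below are illustrative assumptions, not part of this commit or of the JMTEB codebase.

# Illustrative sketch only; ToyEmbedder and the `normalize` flag are
# assumptions for demonstration, not part of this commit.
import numpy as np

from jmteb.embedders.base import TextEmbedder


class ToyEmbedder(TextEmbedder):
    convert_to_numpy = True

    def encode(self, text, prefix=None, **kwargs):
        # Keyword arguments given to batch_encode_with_cache(...) now
        # arrive here unchanged, e.g. an assumed `normalize` flag.
        normalize = kwargs.get("normalize", False)
        texts = [text] if isinstance(text, str) else text
        vecs = np.ones((len(texts), self.get_output_dim()), dtype="float32")
        if normalize:
            vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)
        return vecs

    def get_output_dim(self):
        return 4


# `normalize=True` travels batch_encode_with_cache -> encode:
# embeddings = ToyEmbedder().batch_encode_with_cache(["hello"], normalize=True)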

