|
1 |
| -from abc import ABC |
| 1 | +from abc import ABC, abstractmethod |
2 | 2 | from typing import Any
|
3 | 3 |
|
4 |
| -from langchain.embeddings import HuggingFaceEmbeddings |
5 |
| - |
6 | 4 |
|
class Embedder(ABC):
    """Abstract interface for text-embedding backends.

    Subclasses must implement both :meth:`embed_documents` and
    :meth:`embed_query`; the base class itself cannot be instantiated.
    """

    @abstractmethod
    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed search docs.

        Args:
            texts: The texts to embed.

        Returns:
            A list of embeddings, one per input text.
        """

    @abstractmethod
    def embed_query(self, text: str) -> list[float]:
        """Embed query text.

        Args:
            text: The text to embed.

        Returns:
            The embedding for ``text``.
        """
| 14 | + |
class HuggingFaceEmbedder(Embedder):
    """HuggingFace sentence_transformers embedding models.

    To use, you should have the ``sentence_transformers`` python package installed.

    Configuration is passed to the constructor as keyword arguments, each of
    which must name one of the configuration attributes declared below
    (``model_name``, ``cache_folder``, ``model_kwargs``, ``encode_kwargs``,
    ``multi_process``).
    """

    client: Any  #: :meta private:
    model_name: str = "all-MiniLM-L6-v2"
    """Model name to use."""
    cache_folder: str | None = None
    """Path to store models.
    Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
    # The two dict defaults below are only templates: __init__ always gives
    # each instance its own copy, so instances never share or mutate them.
    model_kwargs: dict[str, Any] = {}
    """Keyword arguments to pass to the model."""
    encode_kwargs: dict[str, Any] = {}
    """Keyword arguments to pass when calling the `encode` method of the model."""
    multi_process: bool = False
    """Run encode() on multiple GPUs."""

    # Names accepted as constructor keywords.
    _CONFIG_FIELDS = frozenset(
        {
            "model_name",
            "cache_folder",
            "model_kwargs",
            "encode_kwargs",
            "multi_process",
        }
    )

    def __init__(self, **kwargs: Any):
        """Initialize the sentence_transformer.

        Args:
            **kwargs: Overrides for the configuration attributes declared on
                the class.

        Raises:
            TypeError: If a keyword does not name a configuration attribute.
            ImportError: If ``sentence_transformers`` is not installed.
        """
        unknown = sorted(set(kwargs) - self._CONFIG_FIELDS)
        if unknown:
            raise TypeError(
                f"{type(self).__name__}() got unexpected keyword argument(s): "
                f"{', '.join(unknown)}"
            )

        # The base class is a plain ABC, so keywords must not be forwarded:
        # object.__init__ rejects any extra arguments.
        super().__init__()
        for name, value in kwargs.items():
            setattr(self, name, value)

        # Per-instance copies so that neither the class-level defaults nor the
        # caller's dicts are shared with (or mutated through) this instance.
        self.model_kwargs = dict(self.model_kwargs or {})
        self.encode_kwargs = dict(self.encode_kwargs or {})

        try:
            import sentence_transformers

        except ImportError as exc:
            raise ImportError(
                "Could not import sentence_transformers python package. "
                "Please install it with `pip install sentence-transformers`."
            ) from exc

        self.client = sentence_transformers.SentenceTransformer(
            self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
        )

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Compute doc embeddings using a HuggingFace transformer model.

        Newlines in each text are replaced with spaces before encoding.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        cleaned = [text.replace("\n", " ") for text in texts]
        if not cleaned:
            # Nothing to encode; avoid starting a worker pool for no work.
            return []

        if self.multi_process:
            # NOTE: encode_kwargs are not forwarded on this path.
            pool = self.client.start_multi_process_pool()
            try:
                embeddings = self.client.encode_multi_process(cleaned, pool)
            finally:
                # Always shut the pool down, even when encoding fails.
                # stop_multi_process_pool is a static method, so it is
                # reachable through the client instance.
                self.client.stop_multi_process_pool(pool)
        else:
            embeddings = self.client.encode(cleaned, **self.encode_kwargs)

        return embeddings.tolist()

    def embed_query(self, text: str) -> list[float]:
        """Compute query embeddings using a HuggingFace transformer model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]
0 commit comments