-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add Azure embedders support (#6676)
* Add Azure embedders --------- Co-authored-by: Stefano Fiorucci <[email protected]>
- Loading branch information
Showing
6 changed files
with
440 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
178 changes: 178 additions & 0 deletions
178
haystack/components/embedders/azure_document_embedder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
import os | ||
from typing import List, Optional, Dict, Any, Tuple | ||
|
||
from openai.lib.azure import AzureADTokenProvider, AzureOpenAI | ||
from tqdm import tqdm | ||
|
||
from haystack import component, Document, default_to_dict | ||
|
||
|
||
@component | ||
class AzureOpenAIDocumentEmbedder: | ||
""" | ||
A component for computing Document embeddings using OpenAI models. | ||
The embedding of each Document is stored in the `embedding` field of the Document. | ||
Usage example: | ||
```python | ||
from haystack import Document | ||
from haystack.components.embedders import AzureOpenAIDocumentEmbedder | ||
doc = Document(content="I love pizza!") | ||
document_embedder = AzureOpenAIDocumentEmbedder() | ||
result = document_embedder.run([doc]) | ||
print(result['documents'][0].embedding) | ||
# [0.017020374536514282, -0.023255806416273117, ...] | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
azure_endpoint: Optional[str] = None, | ||
api_version: Optional[str] = "2023-05-15", | ||
azure_deployment: str = "text-embedding-ada-002", | ||
api_key: Optional[str] = None, | ||
azure_ad_token: Optional[str] = None, | ||
azure_ad_token_provider: Optional[AzureADTokenProvider] = None, | ||
organization: Optional[str] = None, | ||
prefix: str = "", | ||
suffix: str = "", | ||
batch_size: int = 32, | ||
progress_bar: bool = True, | ||
meta_fields_to_embed: Optional[List[str]] = None, | ||
embedding_separator: str = "\n", | ||
): | ||
""" | ||
Create an AzureOpenAITextEmbedder component. | ||
:param azure_endpoint: The endpoint of the deployed model, e.g. `https://example-resource.azure.openai.com/` | ||
:param api_version: The version of the API to use. Defaults to 2023-05-15 | ||
:param azure_deployment: The deployment of the model, usually the model name. | ||
:param api_key: The API key to use for authentication. | ||
:param azure_ad_token: Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id | ||
:param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked | ||
on every request. | ||
:param organization: The Organization ID, defaults to `None`. See | ||
[production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization). | ||
:param prefix: A string to add to the beginning of each text. | ||
:param suffix: A string to add to the end of each text. | ||
:param batch_size: Number of Documents to encode at once. | ||
:param progress_bar: Whether to show a progress bar or not. Can be helpful to disable in production deployments | ||
to keep the logs clean. | ||
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document text. | ||
:param embedding_separator: Separator used to concatenate the meta fields to the Document text. | ||
""" | ||
# if not provided as a parameter, azure_endpoint is read from the env var AZURE_OPENAI_ENDPOINT | ||
azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT") | ||
if not azure_endpoint: | ||
raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.") | ||
|
||
self.api_version = api_version | ||
self.azure_endpoint = azure_endpoint | ||
self.azure_deployment = azure_deployment | ||
self.organization = organization | ||
self.prefix = prefix | ||
self.suffix = suffix | ||
self.batch_size = batch_size | ||
self.progress_bar = progress_bar | ||
self.meta_fields_to_embed = meta_fields_to_embed or [] | ||
self.embedding_separator = embedding_separator | ||
|
||
self._client = AzureOpenAI( | ||
api_version=api_version, | ||
azure_endpoint=azure_endpoint, | ||
azure_deployment=azure_deployment, | ||
api_key=api_key, | ||
azure_ad_token=azure_ad_token, | ||
azure_ad_token_provider=azure_ad_token_provider, | ||
organization=organization, | ||
) | ||
|
||
def _get_telemetry_data(self) -> Dict[str, Any]: | ||
""" | ||
Data that is sent to Posthog for usage analytics. | ||
""" | ||
return {"model": self.azure_deployment} | ||
|
||
def to_dict(self) -> Dict[str, Any]: | ||
""" | ||
This method overrides the default serializer in order to avoid leaking the `api_key` value passed | ||
to the constructor. | ||
""" | ||
return default_to_dict( | ||
self, | ||
azure_endpoint=self.azure_endpoint, | ||
azure_deployment=self.azure_deployment, | ||
organization=self.organization, | ||
api_version=self.api_version, | ||
prefix=self.prefix, | ||
suffix=self.suffix, | ||
batch_size=self.batch_size, | ||
progress_bar=self.progress_bar, | ||
meta_fields_to_embed=self.meta_fields_to_embed, | ||
embedding_separator=self.embedding_separator, | ||
) | ||
|
||
def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: | ||
""" | ||
Prepare the texts to embed by concatenating the Document text with the metadata fields to embed. | ||
""" | ||
texts_to_embed = [] | ||
for doc in documents: | ||
meta_values_to_embed = [ | ||
str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None | ||
] | ||
|
||
text_to_embed = ( | ||
self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + self.suffix | ||
).replace("\n", " ") | ||
|
||
texts_to_embed.append(text_to_embed) | ||
return texts_to_embed | ||
|
||
def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]: | ||
""" | ||
Embed a list of texts in batches. | ||
""" | ||
|
||
all_embeddings: List[List[float]] = [] | ||
meta: Dict[str, Any] = {"model": "", "usage": {"prompt_tokens": 0, "total_tokens": 0}} | ||
for i in tqdm(range(0, len(texts_to_embed), batch_size), desc="Embedding Texts"): | ||
batch = texts_to_embed[i : i + batch_size] | ||
response = self._client.embeddings.create(model=self.azure_deployment, input=batch) | ||
|
||
# Append embeddings to the list | ||
all_embeddings.extend(el.embedding for el in response.data) | ||
|
||
# Update the meta information only once if it's empty | ||
if not meta["model"]: | ||
meta["model"] = response.model | ||
meta["usage"] = dict(response.usage) | ||
else: | ||
# Update the usage tokens | ||
meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens | ||
meta["usage"]["total_tokens"] += response.usage.total_tokens | ||
|
||
return all_embeddings, meta | ||
|
||
@component.output_types(documents=List[Document], meta=Dict[str, Any]) | ||
def run(self, documents: List[Document]): | ||
""" | ||
Embed a list of Documents. The embedding of each Document is stored in the `embedding` field of the Document. | ||
:param documents: A list of Documents to embed. | ||
""" | ||
if not (isinstance(documents, list) and all(isinstance(doc, Document) for doc in documents)): | ||
raise TypeError("Input must be a list of Document instances. For strings, use AzureOpenAITextEmbedder.") | ||
|
||
texts_to_embed = self._prepare_texts_to_embed(documents=documents) | ||
embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size) | ||
|
||
# Assign the corresponding embeddings to each document | ||
for doc, emb in zip(documents, embeddings): | ||
doc.embedding = emb | ||
|
||
return {"documents": documents, "meta": meta} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
import os | ||
from typing import List, Optional, Dict, Any | ||
|
||
from openai.lib.azure import AzureADTokenProvider, AzureOpenAI | ||
|
||
from haystack import component, default_to_dict, Document | ||
|
||
|
||
@component | ||
class AzureOpenAITextEmbedder: | ||
""" | ||
A component for embedding strings using OpenAI models. | ||
Usage example: | ||
```python | ||
from haystack.components.embedders import AzureOpenAITextEmbedder | ||
text_to_embed = "I love pizza!" | ||
text_embedder = AzureOpenAITextEmbedder() | ||
print(text_embedder.run(text_to_embed)) | ||
# {'embedding': [0.017020374536514282, -0.023255806416273117, ...], | ||
# 'meta': {'model': 'text-embedding-ada-002-v2', | ||
# 'usage': {'prompt_tokens': 4, 'total_tokens': 4}}} | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
azure_endpoint: Optional[str] = None, | ||
api_version: Optional[str] = "2023-05-15", | ||
azure_deployment: str = "text-embedding-ada-002", | ||
api_key: Optional[str] = None, | ||
azure_ad_token: Optional[str] = None, | ||
azure_ad_token_provider: Optional[AzureADTokenProvider] = None, | ||
organization: Optional[str] = None, | ||
prefix: str = "", | ||
suffix: str = "", | ||
): | ||
""" | ||
Create an AzureOpenAITextEmbedder component. | ||
:param azure_endpoint: The endpoint of the deployed model, e.g. `https://example-resource.azure.openai.com/` | ||
:param api_version: The version of the API to use. Defaults to 2023-05-15 | ||
:param azure_deployment: The deployment of the model, usually the model name. | ||
:param api_key: The API key to use for authentication. | ||
:param azure_ad_token: Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id | ||
:param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked | ||
on every request. | ||
:param organization: The Organization ID, defaults to `None`. See | ||
[production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization). | ||
:param prefix: A string to add to the beginning of each text. | ||
:param suffix: A string to add to the end of each text. | ||
""" | ||
# Why is this here? | ||
# AzureOpenAI init is forcing us to use an init method that takes either base_url or azure_endpoint as not | ||
# None init parameters. This way we accommodate the use case where env var AZURE_OPENAI_ENDPOINT is set instead | ||
# of passing it as a parameter. | ||
azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT") | ||
if not azure_endpoint: | ||
raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.") | ||
|
||
self.api_version = api_version | ||
self.azure_endpoint = azure_endpoint | ||
self.azure_deployment = azure_deployment | ||
self.organization = organization | ||
self.prefix = prefix | ||
self.suffix = suffix | ||
|
||
self._client = AzureOpenAI( | ||
api_version=api_version, | ||
azure_endpoint=azure_endpoint, | ||
azure_deployment=azure_deployment, | ||
api_key=api_key, | ||
azure_ad_token=azure_ad_token, | ||
azure_ad_token_provider=azure_ad_token_provider, | ||
organization=organization, | ||
) | ||
|
||
def _get_telemetry_data(self) -> Dict[str, Any]: | ||
""" | ||
Data that is sent to Posthog for usage analytics. | ||
""" | ||
return {"model": self.azure_deployment} | ||
|
||
def to_dict(self) -> Dict[str, Any]: | ||
""" | ||
Serialize this component to a dictionary. | ||
:return: The serialized component as a dictionary. | ||
""" | ||
return default_to_dict( | ||
self, | ||
azure_endpoint=self.azure_endpoint, | ||
azure_deployment=self.azure_deployment, | ||
organization=self.organization, | ||
api_version=self.api_version, | ||
prefix=self.prefix, | ||
suffix=self.suffix, | ||
) | ||
|
||
@component.output_types(embedding=List[float], meta=Dict[str, Any]) | ||
def run(self, text: str): | ||
"""Embed a string using AzureOpenAITextEmbedder.""" | ||
if not isinstance(text, str): | ||
# Check if input is a list and all elements are instances of Document | ||
if isinstance(text, list) and all(isinstance(elem, Document) for elem in text): | ||
error_message = "Input must be a string. Use AzureOpenAIDocumentEmbedder for a list of Documents." | ||
else: | ||
error_message = "Input must be a string." | ||
raise TypeError(error_message) | ||
|
||
# Preprocess the text by adding prefixes/suffixes | ||
# finally, replace newlines as recommended by OpenAI docs | ||
processed_text = f"{self.prefix}{text}{self.suffix}".replace("\n", " ") | ||
|
||
response = self._client.embeddings.create(model=self.azure_deployment, input=processed_text) | ||
|
||
return { | ||
"embedding": response.data[0].embedding, | ||
"meta": {"model": response.model, "usage": dict(response.usage)}, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
--- | ||
features: | ||
- | | ||
Adds AzureOpenAIDocumentEmbedder and AzureOpenAITextEmbedder as new embedders. These embedders are very similar to | ||
their OpenAI counterparts, but they use the Azure API instead of the OpenAI API. |
Oops, something went wrong.