feat: supported multi-modal embeddings from Azure AI Vision service #162

Merged · 7 commits · Oct 31, 2024
48 changes: 29 additions & 19 deletions README.md
@@ -32,12 +32,18 @@ install PyCharm>=2023.2 with [built-in Black support](https://blog.jetbrains.com

## Run

Run the development server:
Run the development server locally:

```sh
make serve
```

Run the server from Docker container:

```sh
make docker_serve
```

### Make on Windows

As of now, Windows distributions do not include the make tool. To run make commands, the tool can be installed using
@@ -52,36 +58,40 @@ The command definitions inside Makefile should be cross-platform to keep the dev

## Environment Variables

Copy `.env.example` to `.env` and customize it for your environment:
Copy `.env.example` to `.env` and customize it for your environment.

### Categories of deployments

The following variables group all deployments into categories of deployments that share the same API.

|Variable|Default|Description|
|---|---|---|
|LOG_LEVEL|INFO|Log level. Use DEBUG for dev purposes and INFO in prod|
|WEB_CONCURRENCY|1|Number of workers for the server|
|MODEL_ALIASES|`{}`|Mapping of the request's deployment_id to a [tiktoken model name](https://github.com/openai/tiktoken/blob/main/tiktoken/model.py) for correct token calculation. Example: `{"gpt-35-turbo":"gpt-3.5-turbo-0301"}`|
|DIAL_USE_FILE_STORAGE|False|Save image model artifacts to the DIAL File storage (DALL-E images are uploaded to the file storage and their base64 encodings are replaced with links to the storage)|
|DIAL_URL||URL of the core DIAL server (required when DIAL_USE_FILE_STORAGE=True)|
|DALLE3_DEPLOYMENTS|``|Comma-separated list of deployments that support DALL-E 3 API. Example: `dall-e-3,dalle3,dall-e`|
|DALLE3_AZURE_API_VERSION|2024-02-01|The API version for requests to the Azure DALL-E 3 API|
|GPT4_VISION_DEPLOYMENTS|``|Comma-separated list of deployments that support GPT-4V API. Example: `gpt-4-vision-preview,gpt-4-vision`|
|GPT4_VISION_MAX_TOKENS|1024|Default value of `max_tokens` parameter for GPT-4V when it wasn't provided in the request|
|ACCESS_TOKEN_EXPIRATION_WINDOW|10|The Azure access token is renewed this many seconds before its actual expiration time. The buffer ensures that the token does not expire in the middle of an operation due to processing time and potential network delays.|
|AZURE_OPEN_AI_SCOPE|https://cognitiveservices.azure.com/.default|The scope used to request the access token for Azure OpenAI services|
|API_VERSIONS_MAPPING|`{}`|The mapping of API versions for requests to the Azure OpenAI API. Example: `{"2023-03-15-preview": "2023-05-15", "": "2024-02-15-preview"}`. An empty key sets the default API version used when the user didn't pass one in the request|
|DALLE3_AZURE_API_VERSION|2024-02-01|The API version for requests to the Azure DALL-E 3 API|
|ELIMINATE_EMPTY_CHOICES|False|When enabled, the response stream is guaranteed to exclude chunks with an empty list of choices. This is useful when a DIAL client doesn't support such chunks. An empty list of choices can be generated by Azure OpenAI in at least two cases: (1) when the **Content filter** is not disabled, Azure includes [prompt filter results](https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/content-filter?tabs=warning%2Cuser-prompt%2Cpython-new#prompt-annotation-message) in the first chunk with an empty list of choices; (2) when `stream_options.include_usage` is enabled, the last chunk contains usage data and an empty list of choices. This variable replaces the deprecated `FIX_STREAMING_ISSUES_IN_NEW_API_VERSIONS` which served the same function.|
|CORE_API_VERSION||Supported value `0.6` to work with the old version of the DIAL File API|
|MISTRAL_DEPLOYMENTS|``|Comma-separated list of deployments that support Mistral Large Azure API. Example: `mistral-large-azure,mistral-large`|
|DATABRICKS_DEPLOYMENTS|``|Comma-separated list of Databricks chat completion deployments. Example: `databricks-dbrx-instruct,databricks-mixtral-8x7b-instruct,databricks-llama-2-70b-chat`|
|GPT4O_DEPLOYMENTS|``|Comma-separated list of GPT-4o chat completion deployments. Example: `gpt-4o-2024-05-13`|
|NON_STREAMING_DEPLOYMENTS|``|Comma-separated list of deployments which do not support streaming. The adapter emulates streaming by calling the model and converting its response into a single-chunk stream. Example: `o1-mini`, `o1-preview`|
|AZURE_AI_VISION_DEPLOYMENTS|``|Comma-separated list of Azure AI Vision embedding deployments. The endpoint of the deployment is expected to point to the Azure AI Vision service: `https://<service-name>.cognitiveservices.azure.com/`|

### Docker
Deployments that do not fall into any of these categories are assumed to support the OpenAI text-to-text chat completion API or the OpenAI text embeddings API. An illustrative configuration is sketched below.
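
For illustration, a hypothetical `.env` fragment assigning deployment names to categories might look like this (the deployment names are placeholders and must match the deployments configured in your DIAL installation):

```sh
# Illustrative values only: use the deployment names from your DIAL config
DALLE3_DEPLOYMENTS=dall-e-3
GPT4_VISION_DEPLOYMENTS=gpt-4-vision-preview
GPT4O_DEPLOYMENTS=gpt-4o-2024-05-13
AZURE_AI_VISION_DEPLOYMENTS=azure-ai-vision
```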

Run the server in Docker:
### Other variables

```sh
make docker_serve
```
|Variable|Default|Description|
|---|---|---|
|LOG_LEVEL|INFO|Log level. Use DEBUG for dev purposes and INFO in prod|
|WEB_CONCURRENCY|1|Number of workers for the server|
|MODEL_ALIASES|`{}`|Mapping from the request's deployment id to a [tiktoken model name](https://github.com/openai/tiktoken/blob/main/tiktoken/model.py). Required for correct token calculation on the adapter side. Example: `{"my-gpt-deployment":"gpt-3.5-turbo-0301"}`|
|DIAL_USE_FILE_STORAGE|False|Save image model artifacts to the DIAL File storage (DALL-E images are uploaded to the DIAL file storage and their base64 encodings are replaced with links to the storage)|
|DIAL_URL||URL of the core DIAL server (required when DIAL_USE_FILE_STORAGE=True)|
|NON_STREAMING_DEPLOYMENTS|``|Comma-separated list of deployments which do not support streaming. The adapter emulates streaming by calling the model and converting its response into a single-chunk stream. Example: `o1-mini`, `o1-preview`|
|ACCESS_TOKEN_EXPIRATION_WINDOW|10|The Azure access token is renewed this many seconds before its actual expiration time. The buffer ensures that the token does not expire in the middle of an operation due to processing time and potential network delays.|
|AZURE_OPEN_AI_SCOPE|https://cognitiveservices.azure.com/.default|The scope used to request the access token for Azure OpenAI services|
|API_VERSIONS_MAPPING|`{}`|The mapping of API versions for requests to the Azure OpenAI API. Example: `{"2023-03-15-preview": "2023-05-15", "": "2024-02-15-preview"}`. An empty key sets the default API version used when the user didn't pass one in the request|
|ELIMINATE_EMPTY_CHOICES|False|When enabled, the response stream is guaranteed to exclude chunks with an empty list of choices. This is useful when a DIAL client doesn't support such chunks. An empty list of choices can be generated by Azure OpenAI in at least two cases: (1) when the **Content filter** is not disabled, Azure includes [prompt filter results](https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/content-filter?tabs=warning%2Cuser-prompt%2Cpython-new#prompt-annotation-message) in the first chunk with an empty list of choices; (2) when `stream_options.include_usage` is enabled, the last chunk contains usage data and an empty list of choices. This variable replaces the deprecated `FIX_STREAMING_ISSUES_IN_NEW_API_VERSIONS` which served the same function.|
|CORE_API_VERSION||Supported value `0.6` to work with the old version of the DIAL File API|
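
For illustration, a hypothetical `.env` fragment overriding a few of these defaults (all values are placeholders):

```sh
# Illustrative values only
LOG_LEVEL=DEBUG
WEB_CONCURRENCY=2
DIAL_USE_FILE_STORAGE=True
DIAL_URL=http://dial-core:8080
MODEL_ALIASES={"my-gpt-deployment":"gpt-3.5-turbo-0301"}
```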

## Lint

31 changes: 21 additions & 10 deletions aidial_adapter_openai/app.py
@@ -1,5 +1,7 @@
from contextlib import asynccontextmanager

import pydantic
from aidial_sdk._errors import pydantic_validation_exception_handler
from aidial_sdk.exceptions import HTTPException as DialException
from aidial_sdk.exceptions import InvalidRequestError
from aidial_sdk.telemetry.init import init_telemetry
@@ -22,8 +24,15 @@
chat_completion as databricks_chat_completion,
)
from aidial_adapter_openai.dial_api.storage import create_file_storage
from aidial_adapter_openai.embeddings.azure_ai_vision import (
embeddings as azure_ai_vision_embeddings,
)
from aidial_adapter_openai.embeddings.openai import (
embeddings as openai_embeddings,
)
from aidial_adapter_openai.env import (
API_VERSIONS_MAPPING,
AZURE_AI_VISION_DEPLOYMENTS,
DALLE3_AZURE_API_VERSION,
DALLE3_DEPLOYMENTS,
DATABRICKS_DEPLOYMENTS,
@@ -44,12 +53,7 @@
from aidial_adapter_openai.utils.auth import get_credentials
from aidial_adapter_openai.utils.http_client import get_http_client
from aidial_adapter_openai.utils.log_config import configure_loggers, logger
from aidial_adapter_openai.utils.parsers import (
completions_parser,
embeddings_parser,
parse_body,
)
from aidial_adapter_openai.utils.reflection import call_with_extra_body
from aidial_adapter_openai.utils.parsers import completions_parser, parse_body
from aidial_adapter_openai.utils.streaming import create_server_response
from aidial_adapter_openai.utils.tokenizer import (
MultiModalTokenizer,
@@ -193,11 +197,13 @@ async def embedding(deployment_id: str, request: Request):
api_version = get_api_version(request)
upstream_endpoint = request.headers["X-UPSTREAM-ENDPOINT"]

client = embeddings_parser.parse(upstream_endpoint).get_client(
{**creds, "api_version": api_version}
)
if deployment_id in AZURE_AI_VISION_DEPLOYMENTS:
storage = create_file_storage("images", request.headers)
return await azure_ai_vision_embeddings(
creds, deployment_id, upstream_endpoint, storage, data
)

return await call_with_extra_body(client.embeddings.create, data)
return await openai_embeddings(creds, upstream_endpoint, api_version, data)


@app.exception_handler(OpenAIError)
Expand Down Expand Up @@ -237,6 +243,11 @@ def openai_exception_handler(request: Request, e: DialException):
)


@app.exception_handler(pydantic.ValidationError)
def pydantic_exception_handler(request: Request, exc: pydantic.ValidationError):
return pydantic_validation_exception_handler(request, exc)


@app.exception_handler(DialException)
def dial_exception_handler(request: Request, exc: DialException):
return exc.to_fastapi_response()
86 changes: 86 additions & 0 deletions aidial_adapter_openai/dial_api/embedding_inputs.py
@@ -0,0 +1,86 @@
from typing import (
Any,
AsyncIterator,
Callable,
Coroutine,
List,
TypeVar,
assert_never,
cast,
)

from aidial_sdk.chat_completion.request import Attachment
from aidial_sdk.embeddings.request import EmbeddingsRequest
from aidial_sdk.exceptions import RequestValidationError

_T = TypeVar("_T")

_Coro = Coroutine[Any, Any, _T]
_Tokens = List[int]


async def reject_tokens(tokens: _Tokens):
raise RequestValidationError(
"Tokens in an embedding input are not supported. Provide text instead. "
"When Langchain AzureOpenAIEmbeddings class is used, set 'check_embedding_ctx_length=False' to disable tokenization."
)


async def reject_mixed(input: List[str | Attachment]):
raise RequestValidationError(
"Embedding inputs composed of multiple texts and/or attachments aren't supported"
)


async def collect_embedding_inputs(
request: EmbeddingsRequest,
*,
on_text: Callable[[str], _Coro[_T]],
on_attachment: Callable[[Attachment], _Coro[_T]],
on_tokens: Callable[[_Tokens], _Coro[_T]] = reject_tokens,
on_mixed: Callable[[List[str | Attachment]], _Coro[_T]] = reject_mixed,
) -> AsyncIterator[_T]:

async def _on_str_or_attachment(input: str | Attachment) -> _T:
if isinstance(input, str):
return await on_text(input)
elif isinstance(input, Attachment):
return await on_attachment(input)
else:
assert_never(input)

if isinstance(request.input, str):
yield await on_text(request.input)
elif isinstance(request.input, list):

is_list_of_tokens = False
for input in request.input:
if isinstance(input, str):
yield await on_text(input)
elif isinstance(input, list):
yield await on_tokens(input)
else:
is_list_of_tokens = True
break

if is_list_of_tokens:
yield await on_tokens(cast(_Tokens, request.input))

else:
assert_never(request.input)

if request.custom_input is None:
return

for input in request.custom_input:
if isinstance(input, (str, Attachment)):
yield await _on_str_or_attachment(input)
elif isinstance(input, list):
if len(input) == 0:
pass
elif len(input) == 1:
yield await _on_str_or_attachment(input[0])
else:
yield await on_mixed(input)
else:
assert_never(input)
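
To illustrate how `collect_embedding_inputs` dispatches inputs, here is a hypothetical usage sketch. The request payload and handler bodies are made up for illustration and assume the `aidial_sdk` request model accepts this shape; they are not part of the adapter.

```python
# Illustrative sketch only: handler bodies and the request payload are made up.
import asyncio

from aidial_sdk.chat_completion.request import Attachment
from aidial_sdk.embeddings.request import EmbeddingsRequest

from aidial_adapter_openai.dial_api.embedding_inputs import (
    collect_embedding_inputs,
)


async def main() -> None:
    # A request with one plain-text input and one attachment in custom_input
    request = EmbeddingsRequest.parse_obj(
        {
            "input": "a cat sitting on a windowsill",
            "custom_input": [
                {"type": "image/png", "url": "files/bucket/cat.png"}
            ],
        }
    )

    async def on_text(text: str) -> str:
        return f"text:{text}"

    async def on_attachment(attachment: Attachment) -> str:
        return f"attachment:{attachment.url}"

    # Token inputs and mixed inputs fall back to the default reject_* handlers
    async for item in collect_embedding_inputs(
        request, on_text=on_text, on_attachment=on_attachment
    ):
        print(item)


asyncio.run(main())
```
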
153 changes: 153 additions & 0 deletions aidial_adapter_openai/embeddings/azure_ai_vision.py
@@ -0,0 +1,153 @@
"""
Adapter for multi-modal embeddings provided by Azure AI Vision service.

1. Conceptual overview: https://aka.ms/image-retrieval
2. How-to article: https://learn.microsoft.com/en-us/azure/ai-services/computer-vision/how-to/image-retrieval?tabs=python
3. REST API (image url, binary image, text): https://learn.microsoft.com/en-gb/rest/api/computervision/image-retrieval?view=rest-computervision-v4.0-preview%20(2023-04-01)
4. A plug-in for Azure Search service: https://learn.microsoft.com/en-gb/azure/search/vector-search-vectorizer-ai-services-vision
5. Example of usage in a RAG: https://github.com/Azure-Samples/azure-search-openai-demo/blob/0946893fe904cab1e89de2a38c4421e38d508608/app/backend/prepdocslib/embeddings.py#L226-L260

Note that currently there is no Python SDK for this API.
There is an SDK for the Image Analysis 4.0 API, but it doesn't cover the multi-modal embeddings API: https://learn.microsoft.com/en-us/azure/ai-services/computer-vision/how-to/call-analyze-image-40?pivots=programming-language-python

Input requirements:

1. The file size of the image must be less than 20 megabytes (MB).
2. The dimensions of the image must be greater than 10 x 10 pixels and less than 16,000 x 16,000 pixels.
3. The text string must be between one and 70 words (inclusive).
4. Supported media types: "application/octet-stream", "image/jpeg", "image/gif", "image/tiff", "image/bmp", "image/png"

Output characteristics:

1. The vector embeddings are normalized.
2. Image and text vector embeddings have 1024 dimensions.

Limitations:

1. Batching isn't supported.

Note that when both "url" and "text" fields are sent in a request,
the "text" field is ignored.
"""

import asyncio
from typing import AsyncIterator, List, assert_never

from aidial_sdk.chat_completion.request import Attachment
from aidial_sdk.embeddings.request import EmbeddingsRequest
from aidial_sdk.embeddings.response import Embedding, EmbeddingResponse, Usage
from pydantic import BaseModel

from aidial_adapter_openai.dial_api.embedding_inputs import (
collect_embedding_inputs,
)
from aidial_adapter_openai.dial_api.resource import AttachmentResource
from aidial_adapter_openai.dial_api.storage import FileStorage
from aidial_adapter_openai.utils.auth import OpenAICreds
from aidial_adapter_openai.utils.http_client import get_http_client
from aidial_adapter_openai.utils.resource import Resource

# The latest Image Analysis API offers two models:
# * version 2023-04-15 which supports text search in many languages,
# * the legacy 2022-04-11 model which supports only English.
_VERSION_PARAMS = {
"api-version": "2024-02-01",
"model-version": "2023-04-15",
}


def _get_auth_headers(creds: OpenAICreds) -> dict[str, str]:
if "api_key" in creds:
return {"Ocp-Apim-Subscription-Key": creds["api_key"]}

if "azure_ad_token" in creds:
return {"Authorization": f"Bearer {creds['azure_ad_token']}"}

raise ValueError("Invalid credentials")


class VectorizeResponse(BaseModel):
class Config:
extra = "allow"

vector: List[float]


async def embeddings(
creds: OpenAICreds,
deployment: str,
endpoint: str,
file_storage: FileStorage | None,
data: dict,
) -> EmbeddingResponse:
input = EmbeddingsRequest.parse_obj(data)

async def on_text(text: str) -> str:
return text

async def on_attachment(attachment: Attachment) -> Resource:
return await AttachmentResource(attachment=attachment).download(
file_storage
)

inputs_iter: AsyncIterator[str | Resource] = collect_embedding_inputs(
input,
on_text=on_text,
on_attachment=on_attachment,
)

inputs: List[str | Resource] = [input async for input in inputs_iter]

async def _get_embedding(input: str | Resource) -> VectorizeResponse:
if isinstance(input, str):
return await _get_text_embedding(creds, endpoint, input)
elif isinstance(input, Resource):
return await _get_image_embedding(creds, endpoint, input)
else:
assert_never(input)

tasks: List[asyncio.Task[VectorizeResponse]] = [
asyncio.create_task(_get_embedding(input)) for input in inputs
]

responses = await asyncio.gather(*tasks)
vectors = [
Embedding(embedding=r.vector, index=idx)
for idx, r in enumerate(responses)
]

n = len(vectors)
usage = Usage(prompt_tokens=n, total_tokens=n)

return EmbeddingResponse(model=deployment, data=vectors, usage=usage)


async def _get_image_embedding(
creds: OpenAICreds, endpoint: str, resource: Resource
) -> VectorizeResponse:
resp = await get_http_client().post(
url=endpoint.rstrip("/") + "/computervision/retrieval:vectorizeImage",
content=resource.data,
headers={
**_get_auth_headers(creds),
"content-type": resource.type,
},
params=_VERSION_PARAMS,
)

resp.raise_for_status()
return VectorizeResponse.parse_obj(resp.json())


async def _get_text_embedding(
creds: OpenAICreds, endpoint: str, text: str
) -> VectorizeResponse:
resp = await get_http_client().post(
url=endpoint.rstrip("/") + "/computervision/retrieval:vectorizeText",
json={"text": text},
headers=_get_auth_headers(creds),
params=_VERSION_PARAMS,
)

resp.raise_for_status()
return VectorizeResponse.parse_obj(resp.json())
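
For reference, a minimal standalone sketch of the raw `vectorizeText` REST call that `_get_text_embedding` performs, using `httpx` directly. The service name and API key are placeholders, and the query parameters mirror `_VERSION_PARAMS` above.

```python
# Standalone sketch of the raw REST call; service name and key are placeholders.
import httpx

ENDPOINT = "https://<service-name>.cognitiveservices.azure.com"
API_KEY = "<azure-ai-vision-key>"

resp = httpx.post(
    f"{ENDPOINT}/computervision/retrieval:vectorizeText",
    json={"text": "a cat sitting on a windowsill"},
    headers={"Ocp-Apim-Subscription-Key": API_KEY},
    params={"api-version": "2024-02-01", "model-version": "2023-04-15"},
)
resp.raise_for_status()
vector = resp.json()["vector"]  # normalized embedding with 1024 dimensions
print(len(vector))
```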