diff --git a/aidial_adapter_openai/app.py b/aidial_adapter_openai/app.py
index e6cceb8..cd348c8 100644
--- a/aidial_adapter_openai/app.py
+++ b/aidial_adapter_openai/app.py
@@ -1,13 +1,11 @@
 from contextlib import asynccontextmanager
-from typing import Annotated
 
 import pydantic
 from aidial_sdk._errors import pydantic_validation_exception_handler
 from aidial_sdk.exceptions import HTTPException as DialException
-from aidial_sdk.exceptions import InvalidRequestError
 from aidial_sdk.telemetry.init import init_telemetry
 from aidial_sdk.telemetry.types import TelemetryConfig
-from fastapi import Depends, FastAPI, Request
+from fastapi import FastAPI, Request
 from fastapi.responses import Response
 from openai import (
     APIConnectionError,
@@ -18,41 +16,10 @@
 )
 
 from aidial_adapter_openai.app_config import ApplicationConfig
-from aidial_adapter_openai.completions import chat_completion as completion
-from aidial_adapter_openai.dalle3 import (
-    chat_completion as dalle3_chat_completion,
-)
-from aidial_adapter_openai.databricks import (
-    chat_completion as databricks_chat_completion,
-)
-from aidial_adapter_openai.dial_api.storage import create_file_storage
-from aidial_adapter_openai.embeddings.azure_ai_vision import (
-    embeddings as azure_ai_vision_embeddings,
-)
-from aidial_adapter_openai.embeddings.openai import (
-    embeddings as openai_embeddings,
-)
-from aidial_adapter_openai.gpt import gpt_chat_completion
-from aidial_adapter_openai.gpt4_multi_modal.chat_completion import (
-    gpt4_vision_chat_completion,
-    gpt4o_chat_completion,
-)
-from aidial_adapter_openai.mistral import (
-    chat_completion as mistral_chat_completion,
-)
-from aidial_adapter_openai.utils.auth import get_credentials
+from aidial_adapter_openai.routers.chat_completion import chat_completion
+from aidial_adapter_openai.routers.embeddings import embedding
 from aidial_adapter_openai.utils.http_client import get_http_client
-from aidial_adapter_openai.utils.image_tokenizer import (
-    ImageTokenizer,
-    get_image_tokenizer,
-)
 from aidial_adapter_openai.utils.log_config import configure_loggers, logger
-from aidial_adapter_openai.utils.parsers import completions_parser, parse_body
-from aidial_adapter_openai.utils.streaming import create_server_response
-from aidial_adapter_openai.utils.tokenizer import (
-    MultiModalTokenizer,
-    PlainTextTokenizer,
-)
 
 
 @asynccontextmanager
@@ -62,205 +29,6 @@ async def lifespan(app: FastAPI):
     await get_http_client().aclose()
 
 
-def create_app(
-    app_config: ApplicationConfig | None = None,
-    to_init_telemetry: bool = True,
-    to_configure_loggers: bool = True,
-) -> FastAPI:
-    app = FastAPI(lifespan=lifespan)
-
-    if app_config is None:
-        app_config = ApplicationConfig.from_env()
-
-    app.state.app_config = app_config
-
-    if to_init_telemetry:
-        init_telemetry(app, TelemetryConfig())
-
-    if to_configure_loggers:
-        configure_loggers()
-
-    return app
-
-
-def get_app_config(request: Request) -> ApplicationConfig:
-    return request.app.state.app_config
-
-
-def get_api_version(request: Request) -> str:
-    api_version = request.query_params.get("api-version", "")
-    app_config = get_app_config(request)
-    api_version = app_config.API_VERSIONS_MAPPING.get(api_version, api_version)
-
-    if api_version == "":
-        raise InvalidRequestError("api-version is a required query parameter")
-
-    return api_version
-
-
-def _get_image_tokenizer(
-    deployment_id: str, app_config: ApplicationConfig
-) -> ImageTokenizer:
-    image_tokenizer = get_image_tokenizer(deployment_id, app_config)
-    if not image_tokenizer:
-        raise RuntimeError(
-            f"No image tokenizer found for deployment {deployment_id}"
-        )
-    return image_tokenizer
-
-
-app = create_app()
-
-
-@app.post("/openai/deployments/{deployment_id:path}/chat/completions")
-async def chat_completion(
-    deployment_id: str,
-    request: Request,
-    app_config: Annotated[ApplicationConfig, Depends(get_app_config)],
-):
-
-    data = await parse_body(request)
-
-    is_stream = bool(data.get("stream"))
-
-    emulate_streaming = (
-        deployment_id in app_config.NON_STREAMING_DEPLOYMENTS and is_stream
-    )
-
-    if emulate_streaming:
-        data["stream"] = False
-
-    return create_server_response(
-        emulate_streaming,
-        await call_chat_completion(
-            deployment_id, data, is_stream, request, app_config
-        ),
-    )
-
-
-async def call_chat_completion(
-    deployment_id: str,
-    data: dict,
-    is_stream: bool,
-    request: Request,
-    app_config: ApplicationConfig,
-):
-
-    # Azure OpenAI deployments ignore "model" request field,
-    # since the deployment id is already encoded in the endpoint path.
-    # This is not the case for non-Azure OpenAI deployments, so
-    # they require the "model" field to be set.
-    # However, openai==1.33.0 requires the "model" field for **both**
-    # Azure and non-Azure deployments.
-    # Therefore, we provide the "model" field for all deployments here.
-    # The same goes for /embeddings endpoint.
-    data["model"] = deployment_id
-
-    creds = await get_credentials(request)
-    api_version = get_api_version(request)
-
-    upstream_endpoint = request.headers["X-UPSTREAM-ENDPOINT"]
-
-    if completions_endpoint := completions_parser.parse(upstream_endpoint):
-        return await completion(
-            data,
-            completions_endpoint,
-            creds,
-            api_version,
-            deployment_id,
-            app_config,
-        )
-    if deployment_id in app_config.DALLE3_DEPLOYMENTS:
-        storage = create_file_storage("images", request.headers)
-        return await dalle3_chat_completion(
-            data,
-            upstream_endpoint,
-            creds,
-            is_stream,
-            storage,
-            app_config.DALLE3_AZURE_API_VERSION,
-        )
-
-    if deployment_id in app_config.MISTRAL_DEPLOYMENTS:
-        return await mistral_chat_completion(data, upstream_endpoint, creds)
-
-    if deployment_id in app_config.DATABRICKS_DEPLOYMENTS:
-        return await databricks_chat_completion(data, upstream_endpoint, creds)
-
-    text_tokenizer_model = app_config.MODEL_ALIASES.get(
-        deployment_id, deployment_id
-    )
-
-    if deployment_id in app_config.GPT4_VISION_DEPLOYMENTS:
-        tokenizer = MultiModalTokenizer(
-            "gpt-4", _get_image_tokenizer(deployment_id, app_config)
-        )
-        return await gpt4_vision_chat_completion(
-            data,
-            deployment_id,
-            upstream_endpoint,
-            creds,
-            is_stream,
-            create_file_storage("images", request.headers),
-            api_version,
-            tokenizer,
-        )
-
-    if deployment_id in (
-        *app_config.GPT4O_DEPLOYMENTS,
-        *app_config.GPT4O_MINI_DEPLOYMENTS,
-    ):
-        tokenizer = MultiModalTokenizer(
-            text_tokenizer_model,
-            _get_image_tokenizer(deployment_id, app_config),
-        )
-        return await gpt4o_chat_completion(
-            data,
-            deployment_id,
-            upstream_endpoint,
-            creds,
-            is_stream,
-            create_file_storage("images", request.headers),
-            api_version,
-            tokenizer,
-        )
-
-    tokenizer = PlainTextTokenizer(model=text_tokenizer_model)
-    return await gpt_chat_completion(
-        data,
-        deployment_id,
-        upstream_endpoint,
-        creds,
-        api_version,
-        tokenizer,
-    )
-
-
-@app.post("/openai/deployments/{deployment_id:path}/embeddings")
-async def embedding(
-    deployment_id: str,
-    request: Request,
-    app_config: Annotated[ApplicationConfig, Depends(get_app_config)],
-):
-    data = await parse_body(request)
-
-    # See note for /chat/completions endpoint
-    data["model"] = deployment_id
-
-    creds = await get_credentials(request)
-    api_version = get_api_version(request)
-    upstream_endpoint = request.headers["X-UPSTREAM-ENDPOINT"]
-
-    if deployment_id in app_config.AZURE_AI_VISION_DEPLOYMENTS:
-        storage = create_file_storage("images", request.headers)
-        return await azure_ai_vision_embeddings(
-            creds, deployment_id, upstream_endpoint, storage, data
-        )
-
-    return await openai_embeddings(creds, upstream_endpoint, api_version, data)
-
-
-@app.exception_handler(OpenAIError)
 def openai_exception_handler(request: Request, e: DialException):
     if isinstance(e, APIStatusError):
         r = e.response
@@ -303,16 +71,43 @@ def openai_exception_handler(request: Request, e: DialException):
     )
 
 
-@app.exception_handler(pydantic.ValidationError)
 def pydantic_exception_handler(request: Request, exc: pydantic.ValidationError):
     return pydantic_validation_exception_handler(request, exc)
 
 
-@app.exception_handler(DialException)
 def dial_exception_handler(request: Request, exc: DialException):
     return exc.to_fastapi_response()
 
 
-@app.get("/health")
-def health():
-    return {"status": "ok"}
+def create_app(
+    app_config: ApplicationConfig | None = None,
+    to_init_telemetry: bool = True,
+) -> FastAPI:
+    app = FastAPI(lifespan=lifespan)
+
+    if app_config is None:
+        app_config = ApplicationConfig.from_env()
+
+    app.state.app_config = app_config
+
+    if to_init_telemetry:
+        init_telemetry(app, TelemetryConfig())
+
+    configure_loggers()
+
+    @app.get("/health")
+    def health():
+        return {"status": "ok"}
+
+    app.post("/openai/deployments/{deployment_id:path}/embeddings")(embedding)
+    app.post("/openai/deployments/{deployment_id:path}/chat/completions")(
+        chat_completion
+    )
+    app.exception_handler(OpenAIError)(openai_exception_handler)
+    app.exception_handler(pydantic.ValidationError)(pydantic_exception_handler)
+    app.exception_handler(DialException)(dial_exception_handler)
+
+    return app
+
+
+app = create_app()
diff --git a/aidial_adapter_openai/app_config.py b/aidial_adapter_openai/app_config.py
index 45bdd94..d024e4c 100644
--- a/aidial_adapter_openai/app_config.py
+++ b/aidial_adapter_openai/app_config.py
@@ -1,39 +1,78 @@
+import json
+import os
 from typing import Dict, List
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
 
-import aidial_adapter_openai.env as env
+from aidial_adapter_openai.utils.env import get_env_bool
+from aidial_adapter_openai.utils.log_config import logger
+from aidial_adapter_openai.utils.parsers import parse_deployment_list
+
+
+def _get_eliminate_empty_choices() -> bool:
+    old_name = "FIX_STREAMING_ISSUES_IN_NEW_API_VERSIONS"
+    new_name = "ELIMINATE_EMPTY_CHOICES"
+
+    if old_name in os.environ:
+        logger.warning(
+            f"{old_name} environment variable is deprecated. Use {new_name} instead."
+        )
+        return get_env_bool(old_name, False)
+
+    return get_env_bool(new_name, False)
 
 
 class ApplicationConfig(BaseModel):
-    MODEL_ALIASES: Dict[str, str] = Field(default_factory=dict)
-    DALLE3_DEPLOYMENTS: List[str] = Field(default_factory=list)
-    GPT4_VISION_DEPLOYMENTS: List[str] = Field(default_factory=list)
-    MISTRAL_DEPLOYMENTS: List[str] = Field(default_factory=list)
-    DATABRICKS_DEPLOYMENTS: List[str] = Field(default_factory=list)
-    GPT4O_DEPLOYMENTS: List[str] = Field(default_factory=list)
-    GPT4O_MINI_DEPLOYMENTS: List[str] = Field(default_factory=list)
-    AZURE_AI_VISION_DEPLOYMENTS: List[str] = Field(default_factory=list)
-    API_VERSIONS_MAPPING: Dict[str, str] = Field(default_factory=dict)
-    COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES: Dict[str, str] = Field(
-        default_factory=dict
-    )
-    DALLE3_AZURE_API_VERSION: str = Field(default="2024-02-01")
-    NON_STREAMING_DEPLOYMENTS: List[str] = Field(default_factory=list)
+    MODEL_ALIASES: Dict[str, str] = {}
+    DALLE3_DEPLOYMENTS: List[str] = []
+    GPT4_VISION_DEPLOYMENTS: List[str] = []
+    MISTRAL_DEPLOYMENTS: List[str] = []
+    DATABRICKS_DEPLOYMENTS: List[str] = []
+    GPT4O_DEPLOYMENTS: List[str] = []
+    GPT4O_MINI_DEPLOYMENTS: List[str] = []
+    AZURE_AI_VISION_DEPLOYMENTS: List[str] = []
+    API_VERSIONS_MAPPING: Dict[str, str] = {}
+    COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES: Dict[str, str] = {}
+    DALLE3_AZURE_API_VERSION: str = "2024-02-01"
+    NON_STREAMING_DEPLOYMENTS: List[str] = []
+    ELIMINATE_EMPTY_CHOICES: bool = False
 
     @classmethod
     def from_env(cls) -> "ApplicationConfig":
         return cls(
-            MODEL_ALIASES=env.MODEL_ALIASES,
-            DALLE3_DEPLOYMENTS=env.DALLE3_DEPLOYMENTS,
-            GPT4_VISION_DEPLOYMENTS=env.GPT4_VISION_DEPLOYMENTS,
-            MISTRAL_DEPLOYMENTS=env.MISTRAL_DEPLOYMENTS,
-            DATABRICKS_DEPLOYMENTS=env.DATABRICKS_DEPLOYMENTS,
-            GPT4O_DEPLOYMENTS=env.GPT4O_DEPLOYMENTS,
-            GPT4O_MINI_DEPLOYMENTS=env.GPT4O_MINI_DEPLOYMENTS,
-            AZURE_AI_VISION_DEPLOYMENTS=env.AZURE_AI_VISION_DEPLOYMENTS,
-            API_VERSIONS_MAPPING=env.API_VERSIONS_MAPPING,
-            COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES=env.COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES,
-            DALLE3_AZURE_API_VERSION=env.DALLE3_AZURE_API_VERSION,
-            NON_STREAMING_DEPLOYMENTS=env.NON_STREAMING_DEPLOYMENTS,
+            MODEL_ALIASES=json.loads(os.getenv("MODEL_ALIASES", "{}")),
+            DALLE3_DEPLOYMENTS=parse_deployment_list(
+                os.getenv("DALLE3_DEPLOYMENTS")
+            ),
+            GPT4_VISION_DEPLOYMENTS=parse_deployment_list(
+                os.getenv("GPT4_VISION_DEPLOYMENTS")
+            ),
+            MISTRAL_DEPLOYMENTS=parse_deployment_list(
+                os.getenv("MISTRAL_DEPLOYMENTS")
+            ),
+            DATABRICKS_DEPLOYMENTS=parse_deployment_list(
+                os.getenv("DATABRICKS_DEPLOYMENTS")
+            ),
+            GPT4O_DEPLOYMENTS=parse_deployment_list(
+                os.getenv("GPT4O_DEPLOYMENTS")
+            ),
+            GPT4O_MINI_DEPLOYMENTS=parse_deployment_list(
+                os.getenv("GPT4O_MINI_DEPLOYMENTS")
+            ),
+            AZURE_AI_VISION_DEPLOYMENTS=parse_deployment_list(
+                os.getenv("AZURE_AI_VISION_DEPLOYMENTS")
+            ),
+            API_VERSIONS_MAPPING=json.loads(
+                os.getenv("API_VERSIONS_MAPPING", "{}")
+            ),
+            COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES=json.loads(
+                os.getenv("COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES") or "{}"
+            ),
+            DALLE3_AZURE_API_VERSION=os.getenv(
+                "DALLE3_AZURE_API_VERSION", "2024-02-01"
+            ),
+            NON_STREAMING_DEPLOYMENTS=parse_deployment_list(
+                os.getenv("NON_STREAMING_DEPLOYMENTS")
+            ),
+            ELIMINATE_EMPTY_CHOICES=_get_eliminate_empty_choices(),
         )
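`ApplicationConfig.from_env` now owns the environment parsing that used to live in `env.py`. A sketch of the expected behavior, assuming `parse_deployment_list` splits a comma-separated string (its exact semantics live in `utils/parsers`) and `get_env_bool` accepts values like `"true"`:

```python
import os

from aidial_adapter_openai.app_config import ApplicationConfig

# Hypothetical values: JSON for the dict-valued settings,
# comma-separated names for the deployment lists (assumed format).
os.environ["MODEL_ALIASES"] = '{"gpt-4-turbo": "gpt-4"}'
os.environ["GPT4O_DEPLOYMENTS"] = "gpt-4o,gpt-4o-2024-05-13"
os.environ["ELIMINATE_EMPTY_CHOICES"] = "true"

config = ApplicationConfig.from_env()
assert config.MODEL_ALIASES == {"gpt-4-turbo": "gpt-4"}
```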
diff --git a/aidial_adapter_openai/env.py b/aidial_adapter_openai/env.py
deleted file mode 100644
index e55ddea..0000000
--- a/aidial_adapter_openai/env.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import json
-import os
-from typing import Dict
-
-from aidial_adapter_openai.utils.env import get_env_bool
-from aidial_adapter_openai.utils.log_config import logger
-from aidial_adapter_openai.utils.parsers import parse_deployment_list
-
-MODEL_ALIASES: Dict[str, str] = json.loads(os.getenv("MODEL_ALIASES", "{}"))
-DALLE3_DEPLOYMENTS = parse_deployment_list(os.getenv("DALLE3_DEPLOYMENTS"))
-GPT4_VISION_DEPLOYMENTS = parse_deployment_list(
-    os.getenv("GPT4_VISION_DEPLOYMENTS")
-)
-MISTRAL_DEPLOYMENTS = parse_deployment_list(os.getenv("MISTRAL_DEPLOYMENTS"))
-DATABRICKS_DEPLOYMENTS = parse_deployment_list(
-    os.getenv("DATABRICKS_DEPLOYMENTS")
-)
-GPT4O_DEPLOYMENTS = parse_deployment_list(os.getenv("GPT4O_DEPLOYMENTS"))
-GPT4O_MINI_DEPLOYMENTS = parse_deployment_list(
-    os.getenv("GPT4O_MINI_DEPLOYMENTS")
-)
-API_VERSIONS_MAPPING: Dict[str, str] = json.loads(
-    os.getenv("API_VERSIONS_MAPPING", "{}")
-)
-COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES: Dict[str, str] = json.loads(
-    os.getenv("COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES") or "{}"
-)
-DALLE3_AZURE_API_VERSION = os.getenv("DALLE3_AZURE_API_VERSION", "2024-02-01")
-NON_STREAMING_DEPLOYMENTS = parse_deployment_list(
-    os.getenv("NON_STREAMING_DEPLOYMENTS")
-)
-AZURE_AI_VISION_DEPLOYMENTS = parse_deployment_list(
-    os.getenv("AZURE_AI_VISION_DEPLOYMENTS")
-)
-
-
-def get_eliminate_empty_choices() -> bool:
-    old_name = "FIX_STREAMING_ISSUES_IN_NEW_API_VERSIONS"
-    new_name = "ELIMINATE_EMPTY_CHOICES"
-
-    if old_name in os.environ:
-        logger.warning(
-            f"{old_name} environment variable is deprecated. Use {new_name} instead."
-        )
-        return get_env_bool(old_name, False)
-
-    return get_env_bool(new_name, False)
diff --git a/aidial_adapter_openai/gpt.py b/aidial_adapter_openai/gpt.py
index 5d6610d..d4c6cde 100644
--- a/aidial_adapter_openai/gpt.py
+++ b/aidial_adapter_openai/gpt.py
@@ -44,6 +44,7 @@ async def gpt_chat_completion(
     creds: OpenAICreds,
     api_version: str,
     tokenizer: PlainTextTokenizer,
+    eliminate_empty_choices: bool,
 ):
     discarded_messages = None
     estimated_prompt_tokens = None
@@ -83,6 +84,7 @@ async def gpt_chat_completion(
             deployment=deployment_id,
             discarded_messages=discarded_messages,
             stream=map_stream(chunk_to_dict, response),
+            eliminate_empty_choices=eliminate_empty_choices,
         )
     else:
         rest = response.to_dict()
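Deleting `env.py` removes the last import-time read of the environment; `gpt.py` (and the multi-modal backends below) instead receive `eliminate_empty_choices` as an explicit argument. A toy contrast of the two styles, to show why the old one forced tests to patch module globals:

```python
import os

# Old style (removed by this PR): the environment is read once, when
# the module is imported, so tests had to patch the module attribute.
EAGER_FLAG = os.getenv("ELIMINATE_EMPTY_CHOICES", "false").lower() == "true"


def old_behavior() -> bool:
    return EAGER_FLAG  # fixed for the whole process


# New style: the caller passes the flag, so two app instances in the
# same process can disagree, and tests just pass a literal.
def new_behavior(eliminate_empty_choices: bool) -> bool:
    return eliminate_empty_choices


assert new_behavior(True) is True
assert new_behavior(False) is False
```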
doesn't support n > 1") @@ -265,6 +270,7 @@ def debug_print(chunk: T) -> T: response_transformer, parse_openai_sse_stream(response), ), + eliminate_empty_choices=eliminate_empty_choices, ), ) else: diff --git a/aidial_adapter_openai/routers/__init__.py b/aidial_adapter_openai/routers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/aidial_adapter_openai/routers/chat_completion.py b/aidial_adapter_openai/routers/chat_completion.py new file mode 100644 index 0000000..d6f86ac --- /dev/null +++ b/aidial_adapter_openai/routers/chat_completion.py @@ -0,0 +1,159 @@ +from typing import Annotated + +from fastapi import Depends, Request + +from aidial_adapter_openai.app_config import ApplicationConfig +from aidial_adapter_openai.completions import chat_completion as completion +from aidial_adapter_openai.dalle3 import ( + chat_completion as dalle3_chat_completion, +) +from aidial_adapter_openai.databricks import ( + chat_completion as databricks_chat_completion, +) +from aidial_adapter_openai.dial_api.storage import create_file_storage +from aidial_adapter_openai.gpt import gpt_chat_completion +from aidial_adapter_openai.gpt4_multi_modal.chat_completion import ( + gpt4_vision_chat_completion, + gpt4o_chat_completion, +) +from aidial_adapter_openai.mistral import ( + chat_completion as mistral_chat_completion, +) +from aidial_adapter_openai.utils.auth import get_credentials +from aidial_adapter_openai.utils.image_tokenizer import get_image_tokenizer +from aidial_adapter_openai.utils.parsers import completions_parser, parse_body +from aidial_adapter_openai.utils.request import ( + get_api_version, + get_request_app_config, +) +from aidial_adapter_openai.utils.streaming import create_server_response +from aidial_adapter_openai.utils.tokenizer import ( + MultiModalTokenizer, + PlainTextTokenizer, +) + + +async def call_chat_completion( + deployment_id: str, + data: dict, + is_stream: bool, + request: Request, + app_config: ApplicationConfig, +): + + # Azure OpenAI deployments ignore "model" request field, + # since the deployment id is already encoded in the endpoint path. + # This is not the case for non-Azure OpenAI deployments, so + # they require the "model" field to be set. + # However, openai==1.33.0 requires the "model" field for **both** + # Azure and non-Azure deployments. + # Therefore, we provide the "model" field for all deployments here. + # The same goes for /embeddings endpoint. 
+ data["model"] = deployment_id + + creds = await get_credentials(request) + api_version = get_api_version(request) + + upstream_endpoint = request.headers["X-UPSTREAM-ENDPOINT"] + + if completions_endpoint := completions_parser.parse(upstream_endpoint): + return await completion( + data, + completions_endpoint, + creds, + api_version, + deployment_id, + app_config, + ) + if deployment_id in app_config.DALLE3_DEPLOYMENTS: + storage = create_file_storage("images", request.headers) + return await dalle3_chat_completion( + data, + upstream_endpoint, + creds, + is_stream, + storage, + app_config.DALLE3_AZURE_API_VERSION, + ) + + if deployment_id in app_config.MISTRAL_DEPLOYMENTS: + return await mistral_chat_completion(data, upstream_endpoint, creds) + + if deployment_id in app_config.DATABRICKS_DEPLOYMENTS: + return await databricks_chat_completion(data, upstream_endpoint, creds) + + text_tokenizer_model = app_config.MODEL_ALIASES.get( + deployment_id, deployment_id + ) + + if deployment_id in app_config.GPT4_VISION_DEPLOYMENTS: + tokenizer = MultiModalTokenizer( + "gpt-4", get_image_tokenizer(deployment_id, app_config) + ) + return await gpt4_vision_chat_completion( + data, + deployment_id, + upstream_endpoint, + creds, + is_stream, + create_file_storage("images", request.headers), + api_version, + tokenizer, + app_config.ELIMINATE_EMPTY_CHOICES, + ) + + if deployment_id in ( + *app_config.GPT4O_DEPLOYMENTS, + *app_config.GPT4O_MINI_DEPLOYMENTS, + ): + tokenizer = MultiModalTokenizer( + text_tokenizer_model, + get_image_tokenizer(deployment_id, app_config), + ) + return await gpt4o_chat_completion( + data, + deployment_id, + upstream_endpoint, + creds, + is_stream, + create_file_storage("images", request.headers), + api_version, + tokenizer, + app_config.ELIMINATE_EMPTY_CHOICES, + ) + + tokenizer = PlainTextTokenizer(model=text_tokenizer_model) + return await gpt_chat_completion( + data, + deployment_id, + upstream_endpoint, + creds, + api_version, + tokenizer, + app_config.ELIMINATE_EMPTY_CHOICES, + ) + + +async def chat_completion( + deployment_id: str, + request: Request, + app_config: Annotated[ApplicationConfig, Depends(get_request_app_config)], +): + + data = await parse_body(request) + + is_stream = bool(data.get("stream")) + + emulate_streaming = ( + deployment_id in app_config.NON_STREAMING_DEPLOYMENTS and is_stream + ) + + if emulate_streaming: + data["stream"] = False + + return create_server_response( + emulate_streaming, + await call_chat_completion( + deployment_id, data, is_stream, request, app_config + ), + ) diff --git a/aidial_adapter_openai/routers/embeddings.py b/aidial_adapter_openai/routers/embeddings.py new file mode 100644 index 0000000..126d3ff --- /dev/null +++ b/aidial_adapter_openai/routers/embeddings.py @@ -0,0 +1,41 @@ +from typing import Annotated + +from fastapi import Depends, Request + +from aidial_adapter_openai.app_config import ApplicationConfig +from aidial_adapter_openai.dial_api.storage import create_file_storage +from aidial_adapter_openai.embeddings.azure_ai_vision import ( + embeddings as azure_ai_vision_embeddings, +) +from aidial_adapter_openai.embeddings.openai import ( + embeddings as openai_embeddings, +) +from aidial_adapter_openai.utils.auth import get_credentials +from aidial_adapter_openai.utils.parsers import parse_body +from aidial_adapter_openai.utils.request import ( + get_api_version, + get_request_app_config, +) + + +async def embedding( + deployment_id: str, + request: Request, + app_config: Annotated[ApplicationConfig, 
diff --git a/aidial_adapter_openai/utils/image_tokenizer.py b/aidial_adapter_openai/utils/image_tokenizer.py
index b6d328d..01eb467 100644
--- a/aidial_adapter_openai/utils/image_tokenizer.py
+++ b/aidial_adapter_openai/utils/image_tokenizer.py
@@ -4,7 +4,7 @@
 """
 
 import math
-from typing import List, Tuple, assert_never
+from typing import assert_never
 
 from pydantic import BaseModel
 
@@ -57,19 +57,17 @@ def _compute_high_detail_tokens(self, width: int, height: int) -> int:
 
 def get_image_tokenizer(
     deployment_id: str, app_config: ApplicationConfig
-) -> ImageTokenizer | None:
-    _TOKENIZERS: List[Tuple[ImageTokenizer, List[str]]] = [
-        (GPT4O_IMAGE_TOKENIZER, app_config.GPT4O_DEPLOYMENTS),
-        (GPT4O_MINI_IMAGE_TOKENIZER, app_config.GPT4O_MINI_DEPLOYMENTS),
-        (
-            GPT4_VISION_IMAGE_TOKENIZER,
-            app_config.GPT4_VISION_DEPLOYMENTS,
-        ),
-    ]
-    for tokenizer, ids in _TOKENIZERS:
-        if deployment_id in ids:
-            return tokenizer
-    return None
+) -> ImageTokenizer:
+    if deployment_id in app_config.GPT4O_DEPLOYMENTS:
+        return GPT4O_IMAGE_TOKENIZER
+    elif deployment_id in app_config.GPT4O_MINI_DEPLOYMENTS:
+        return GPT4O_MINI_IMAGE_TOKENIZER
+    elif deployment_id in app_config.GPT4_VISION_DEPLOYMENTS:
+        return GPT4_VISION_IMAGE_TOKENIZER
+    else:
+        raise RuntimeError(
+            f"No image tokenizer found for deployment {deployment_id}"
+        )
 
 
 def _fit_longest(width: int, height: int, size: int) -> tuple[int, int]:
diff --git a/aidial_adapter_openai/utils/request.py b/aidial_adapter_openai/utils/request.py
new file mode 100644
index 0000000..c8d6457
--- /dev/null
+++ b/aidial_adapter_openai/utils/request.py
@@ -0,0 +1,27 @@
+from aidial_sdk.exceptions import InvalidRequestError
+from fastapi import FastAPI, Request
+
+from aidial_adapter_openai.app_config import ApplicationConfig
+
+
+def set_app_config(app: FastAPI, app_config: ApplicationConfig):
+    app.state.app_config = app_config
+
+
+def get_app_config(app: FastAPI) -> ApplicationConfig:
+    return app.state.app_config
+
+
+def get_request_app_config(request: Request) -> ApplicationConfig:
+    return get_app_config(request.app)
+
+
+def get_api_version(request: Request) -> str:
+    api_version = request.query_params.get("api-version", "")
+    app_config = get_request_app_config(request)
+    api_version = app_config.API_VERSIONS_MAPPING.get(api_version, api_version)
+
+    if api_version == "":
+        raise InvalidRequestError("api-version is a required query parameter")
+
+    return api_version
diff --git a/aidial_adapter_openai/utils/streaming.py b/aidial_adapter_openai/utils/streaming.py
index d677ef7..724ae00 100644
--- a/aidial_adapter_openai/utils/streaming.py
+++ b/aidial_adapter_openai/utils/streaming.py
@@ -10,7 +10,6 @@
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
 from pydantic import BaseModel
 
-from aidial_adapter_openai.env import get_eliminate_empty_choices
 from aidial_adapter_openai.utils.chat_completion_response import (
     ChatCompletionResponse,
     ChatCompletionStreamingChunk,
@@ -18,8 +17,6 @@
 from aidial_adapter_openai.utils.log_config import logger
 from aidial_adapter_openai.utils.sse_stream import to_openai_sse_stream
 
-ELIMINATE_EMPTY_CHOICES = get_eliminate_empty_choices()
-
 
 def generate_id() -> str:
     return "chatcmpl-" + str(uuid4())
@@ -62,6 +59,7 @@ async def generate_stream(
     deployment: str,
     discarded_messages: Optional[list[int]],
     stream: AsyncIterator[dict],
+    eliminate_empty_choices: bool,
 ) -> AsyncIterator[dict]:
 
     empty_chunk = build_chunk(
@@ -116,7 +114,7 @@ def set_discarded_messages(chunk: dict | None, indices: list[int]) -> dict:
     # when content filtering is enabled for a corresponding deployment.
     # The safety rating of the request is reported in this first chunk.
     # Here we withhold such a chunk and merge it later with a follow-up chunk.
-    if len(choices) == 0 and ELIMINATE_EMPTY_CHOICES:
+    if len(choices) == 0 and eliminate_empty_choices:
         buffer_chunk = chunk
     else:
         if last_chunk is not None:
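One consequence of `get_api_version` in `utils/request.py` above worth noting: the `API_VERSIONS_MAPPING` lookup happens before the empty-value check, so mapping the empty string to a default effectively makes the `api-version` query parameter optional. A small illustration (the mapping values are hypothetical):

```python
from aidial_adapter_openai.app_config import ApplicationConfig

config = ApplicationConfig(
    API_VERSIONS_MAPPING={
        # Applied when the client omits api-version entirely.
        "": "2024-02-01",
        # Transparently upgrade a legacy version string.
        "2023-03-15-preview": "2024-02-01",
    }
)
```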
diff --git a/tests/conftest.py b/tests/conftest.py
index 811cab1..a8efc95 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,25 +1,29 @@
-from unittest.mock import patch
-
 import httpx
 import pytest
 import pytest_asyncio
 from httpx import ASGITransport
 
-from aidial_adapter_openai.app import app
+from aidial_adapter_openai.app import create_app
+from aidial_adapter_openai.utils.request import get_app_config
 
 
-@pytest.fixture
-def eliminate_empty_choices():
-    with patch(
-        "aidial_adapter_openai.utils.streaming.ELIMINATE_EMPTY_CHOICES", True
-    ):
-        yield
+@pytest_asyncio.fixture
+def _app_instance():
+    return create_app()
 
 
 @pytest_asyncio.fixture
-async def test_app():
+async def test_app(_app_instance):
     async with httpx.AsyncClient(
-        transport=ASGITransport(app=app),  # type: ignore
+        transport=ASGITransport(app=_app_instance),
         base_url="http://test-app.com",
     ) as client:
         yield client
+
+
+@pytest.fixture
+def eliminate_empty_choices(_app_instance):
+    app_config = get_app_config(_app_instance)
+    app_config.ELIMINATE_EMPTY_CHOICES = True
+    yield
+    app_config.ELIMINATE_EMPTY_CHOICES = False
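With per-instance config, a test can flip `ELIMINATE_EMPTY_CHOICES` by simply requesting the fixture, instead of patching a module attribute. A sketch of such a test; the deployment name and upstream URL are placeholders, and a real test would mock the upstream HTTP call (e.g. with `respx`):

```python
import pytest


@pytest.mark.asyncio
async def test_empty_choices_are_buffered(test_app, eliminate_empty_choices):
    # `test_app` is the httpx client from conftest.py; both fixtures
    # depend on `_app_instance`, so they share the same app and config.
    response = await test_app.post(
        "/openai/deployments/gpt-4/chat/completions",
        params={"api-version": "2024-02-01"},
        headers={"X-UPSTREAM-ENDPOINT": "http://upstream/v1/chat/completions"},
        json={"messages": [{"role": "user", "content": "Hi"}], "stream": True},
    )
    assert response.status_code == 200
```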