diff --git a/aidial_adapter_openai/app.py b/aidial_adapter_openai/app.py
index 11ce937..5d07629 100644
--- a/aidial_adapter_openai/app.py
+++ b/aidial_adapter_openai/app.py
@@ -1,64 +1,22 @@
 from contextlib import asynccontextmanager
 
 import pydantic
-from aidial_sdk._errors import pydantic_validation_exception_handler
 from aidial_sdk.exceptions import HTTPException as DialException
-from aidial_sdk.exceptions import InvalidRequestError
-from aidial_sdk.telemetry.init import init_telemetry
+from aidial_sdk.telemetry.init import init_telemetry as sdk_init_telemetry
 from aidial_sdk.telemetry.types import TelemetryConfig
-from fastapi import FastAPI, Request
-from fastapi.responses import Response
-from openai import (
-    APIConnectionError,
-    APIError,
-    APIStatusError,
-    APITimeoutError,
-    OpenAIError,
-)
+from fastapi import FastAPI
+from openai import OpenAIError
 
-from aidial_adapter_openai.completions import chat_completion as completion
-from aidial_adapter_openai.dalle3 import (
-    chat_completion as dalle3_chat_completion,
-)
-from aidial_adapter_openai.databricks import (
-    chat_completion as databricks_chat_completion,
-)
-from aidial_adapter_openai.dial_api.storage import create_file_storage
-from aidial_adapter_openai.embeddings.azure_ai_vision import (
-    embeddings as azure_ai_vision_embeddings,
-)
-from aidial_adapter_openai.embeddings.openai import (
-    embeddings as openai_embeddings,
-)
-from aidial_adapter_openai.env import (
-    API_VERSIONS_MAPPING,
-    AZURE_AI_VISION_DEPLOYMENTS,
-    DALLE3_AZURE_API_VERSION,
-    DALLE3_DEPLOYMENTS,
-    DATABRICKS_DEPLOYMENTS,
-    GPT4_VISION_DEPLOYMENTS,
-    MISTRAL_DEPLOYMENTS,
-    MODEL_ALIASES,
-    NON_STREAMING_DEPLOYMENTS,
-)
-from aidial_adapter_openai.gpt import gpt_chat_completion
-from aidial_adapter_openai.gpt4_multi_modal.chat_completion import (
-    gpt4_vision_chat_completion,
-    gpt4o_chat_completion,
+import aidial_adapter_openai.endpoints as endpoints
+from aidial_adapter_openai.app_config import ApplicationConfig
+from aidial_adapter_openai.exception_handlers import (
+    dial_exception_handler,
+    openai_exception_handler,
+    pydantic_exception_handler,
 )
-from aidial_adapter_openai.mistral import (
-    chat_completion as mistral_chat_completion,
-)
-from aidial_adapter_openai.utils.auth import get_credentials
 from aidial_adapter_openai.utils.http_client import get_http_client
-from aidial_adapter_openai.utils.image_tokenizer import get_image_tokenizer
 from aidial_adapter_openai.utils.log_config import configure_loggers, logger
-from aidial_adapter_openai.utils.parsers import completions_parser, parse_body
-from aidial_adapter_openai.utils.streaming import create_server_response
-from aidial_adapter_openai.utils.tokenizer import (
-    MultiModalTokenizer,
-    PlainTextTokenizer,
-)
+from aidial_adapter_openai.utils.request import set_app_config
 
 
 @asynccontextmanager
@@ -68,202 +26,30 @@ async def lifespan(app: FastAPI):
     await get_http_client().aclose()
 
 
-app = FastAPI(lifespan=lifespan)
-
-
-init_telemetry(app, TelemetryConfig())
-configure_loggers()
-
-
-def get_api_version(request: Request):
-    api_version = request.query_params.get("api-version", "")
-    api_version = API_VERSIONS_MAPPING.get(api_version, api_version)
+def create_app(
+    app_config: ApplicationConfig | None = None,
+    init_telemetry: bool = True,
+) -> FastAPI:
+    app = FastAPI(lifespan=lifespan)
+    set_app_config(app, app_config or ApplicationConfig.from_env())
 
-    if api_version == "":
-        raise InvalidRequestError("api-version is a required query parameter")
+    if init_telemetry:
+        sdk_init_telemetry(app, TelemetryConfig())
 
-    return api_version
+    configure_loggers()
 
-
-@app.post("/openai/deployments/{deployment_id:path}/chat/completions")
-async def chat_completion(deployment_id: str, request: Request):
-
-    data = await parse_body(request)
-
-    is_stream = bool(data.get("stream"))
-
-    emulate_streaming = deployment_id in NON_STREAMING_DEPLOYMENTS and is_stream
-
-    if emulate_streaming:
-        data["stream"] = False
-
-    return create_server_response(
-        emulate_streaming,
-        await call_chat_completion(deployment_id, data, is_stream, request),
+    app.get("/health")(endpoints.health)
+    app.post("/openai/deployments/{deployment_id:path}/embeddings")(
+        endpoints.embedding
     )
-
-
-async def call_chat_completion(
-    deployment_id: str, data: dict, is_stream: bool, request: Request
-):
-
-    # Azure OpenAI deployments ignore "model" request field,
-    # since the deployment id is already encoded in the endpoint path.
-    # This is not the case for non-Azure OpenAI deployments, so
-    # they require the "model" field to be set.
-    # However, openai==1.33.0 requires the "model" field for **both**
-    # Azure and non-Azure deployments.
-    # Therefore, we provide the "model" field for all deployments here.
-    # The same goes for /embeddings endpoint.
-    data["model"] = deployment_id
-
-    creds = await get_credentials(request)
-    api_version = get_api_version(request)
-
-    upstream_endpoint = request.headers["X-UPSTREAM-ENDPOINT"]
-
-    if completions_endpoint := completions_parser.parse(upstream_endpoint):
-        return await completion(
-            data,
-            completions_endpoint,
-            creds,
-            api_version,
-            deployment_id,
-        )
-
-    if deployment_id in DALLE3_DEPLOYMENTS:
-        storage = create_file_storage("images", request.headers)
-        return await dalle3_chat_completion(
-            data,
-            upstream_endpoint,
-            creds,
-            is_stream,
-            storage,
-            DALLE3_AZURE_API_VERSION,
-        )
-
-    if deployment_id in MISTRAL_DEPLOYMENTS:
-        return await mistral_chat_completion(data, upstream_endpoint, creds)
-
-    if deployment_id in DATABRICKS_DEPLOYMENTS:
-        return await databricks_chat_completion(data, upstream_endpoint, creds)
-
-    text_tokenizer_model = MODEL_ALIASES.get(deployment_id, deployment_id)
-
-    if image_tokenizer := get_image_tokenizer(deployment_id):
-        storage = create_file_storage("images", request.headers)
-
-        if deployment_id in GPT4_VISION_DEPLOYMENTS:
-            tokenizer = MultiModalTokenizer("gpt-4", image_tokenizer)
-            return await gpt4_vision_chat_completion(
-                data,
-                deployment_id,
-                upstream_endpoint,
-                creds,
-                is_stream,
-                storage,
-                api_version,
-                tokenizer,
-            )
-        else:
-            tokenizer = MultiModalTokenizer(
-                text_tokenizer_model, image_tokenizer
-            )
-            return await gpt4o_chat_completion(
-                data,
-                deployment_id,
-                upstream_endpoint,
-                creds,
-                is_stream,
-                storage,
-                api_version,
-                tokenizer,
-            )
-
-    tokenizer = PlainTextTokenizer(model=text_tokenizer_model)
-    return await gpt_chat_completion(
-        data,
-        deployment_id,
-        upstream_endpoint,
-        creds,
-        api_version,
-        tokenizer,
+    app.post("/openai/deployments/{deployment_id:path}/chat/completions")(
+        endpoints.chat_completion
     )
+    app.exception_handler(OpenAIError)(openai_exception_handler)
+    app.exception_handler(pydantic.ValidationError)(pydantic_exception_handler)
+    app.exception_handler(DialException)(dial_exception_handler)
 
-
-@app.post("/openai/deployments/{deployment_id:path}/embeddings")
-async def embedding(deployment_id: str, request: Request):
-    data = await parse_body(request)
-
-    # See note for /chat/completions endpoint
-    data["model"] = deployment_id
-
-    creds = await get_credentials(request)
-    api_version = get_api_version(request)
-    upstream_endpoint = request.headers["X-UPSTREAM-ENDPOINT"]
-
-    if deployment_id in AZURE_AI_VISION_DEPLOYMENTS:
-        storage = create_file_storage("images", request.headers)
-        return await azure_ai_vision_embeddings(
-            creds, deployment_id, upstream_endpoint, storage, data
-        )
-
-    return await openai_embeddings(creds, upstream_endpoint, api_version, data)
-
-
-@app.exception_handler(OpenAIError)
-def openai_exception_handler(request: Request, e: DialException):
-    if isinstance(e, APIStatusError):
-        r = e.response
-        headers = r.headers
-
-        # Avoid encoding the error message when the original response was encoded.
-        if "Content-Encoding" in headers:
-            del headers["Content-Encoding"]
-
-        return Response(
-            content=r.content,
-            status_code=r.status_code,
-            headers=headers,
-        )
-
-    if isinstance(e, APITimeoutError):
-        raise DialException(
-            status_code=504,
-            type="timeout",
-            message="Request timed out",
-            display_message="Request timed out. Please try again later.",
-        )
-
-    if isinstance(e, APIConnectionError):
-        raise DialException(
-            status_code=502,
-            type="connection",
-            message="Error communicating with OpenAI",
-            display_message="OpenAI server is not responsive. Please try again later.",
-        )
-
-    if isinstance(e, APIError):
-        raise DialException(
-            status_code=getattr(e, "status_code", None) or 500,
-            message=e.message,
-            type=e.type,
-            code=e.code,
-            param=e.param,
-            display_message=None,
-        )
-
-
-@app.exception_handler(pydantic.ValidationError)
-def pydantic_exception_handler(request: Request, exc: pydantic.ValidationError):
-    return pydantic_validation_exception_handler(request, exc)
-
-
-@app.exception_handler(DialException)
-def dial_exception_handler(request: Request, exc: DialException):
-    return exc.to_fastapi_response()
+    return app
 
 
-@app.get("/health")
-def health():
-    return {"status": "ok"}
+app = create_app()
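With the factory in place, callers can build an isolated, fully configured instance instead of importing the module-level `app`. A minimal sketch, assuming the modules introduced in this diff (the deployment name is hypothetical):

```python
from aidial_adapter_openai.app import create_app
from aidial_adapter_openai.app_config import ApplicationConfig

# Explicit config instead of reading the environment; "my-mistral" is an
# illustrative deployment id, not a value shipped with the adapter.
app = create_app(
    app_config=ApplicationConfig(MISTRAL_DEPLOYMENTS=["my-mistral"]),
    init_telemetry=False,
)
```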
diff --git a/aidial_adapter_openai/app_config.py b/aidial_adapter_openai/app_config.py
new file mode 100644
index 0000000..3180172
--- /dev/null
+++ b/aidial_adapter_openai/app_config.py
@@ -0,0 +1,105 @@
+import json
+import os
+from typing import Dict, List
+
+from pydantic import BaseModel
+
+from aidial_adapter_openai.constant import ChatCompletionDeploymentType
+from aidial_adapter_openai.utils.env import get_env_bool
+from aidial_adapter_openai.utils.json import remove_nones
+from aidial_adapter_openai.utils.log_config import logger
+
+
+class ApplicationConfig(BaseModel):
+    MODEL_ALIASES: Dict[str, str] = {}
+    DALLE3_DEPLOYMENTS: List[str] = []
+    GPT4_VISION_DEPLOYMENTS: List[str] = []
+    MISTRAL_DEPLOYMENTS: List[str] = []
+    DATABRICKS_DEPLOYMENTS: List[str] = []
+    GPT4O_DEPLOYMENTS: List[str] = []
+    GPT4O_MINI_DEPLOYMENTS: List[str] = []
+    AZURE_AI_VISION_DEPLOYMENTS: List[str] = []
+    API_VERSIONS_MAPPING: Dict[str, str] = {}
+    COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES: Dict[str, str] = {}
+    DALLE3_AZURE_API_VERSION: str = "2024-02-01"
+    NON_STREAMING_DEPLOYMENTS: List[str] = []
+    ELIMINATE_EMPTY_CHOICES: bool = False
+
+    def get_chat_completion_deployment_type(
+        self, deployment_id: str
+    ) -> ChatCompletionDeploymentType:
+        if deployment_id in self.DALLE3_DEPLOYMENTS:
+            return ChatCompletionDeploymentType.DALLE3
+        elif deployment_id in self.GPT4_VISION_DEPLOYMENTS:
+            return ChatCompletionDeploymentType.GPT4_VISION
+        elif deployment_id in self.MISTRAL_DEPLOYMENTS:
+            return ChatCompletionDeploymentType.MISTRAL
+        elif deployment_id in self.DATABRICKS_DEPLOYMENTS:
+            return ChatCompletionDeploymentType.DATABRICKS
+        elif deployment_id in self.GPT4O_DEPLOYMENTS:
+            return ChatCompletionDeploymentType.GPT4O
+        elif deployment_id in self.GPT4O_MINI_DEPLOYMENTS:
+            return ChatCompletionDeploymentType.GPT4O_MINI
+        else:
+            return ChatCompletionDeploymentType.GPT_TEXT_ONLY
+
+    @classmethod
+    def from_env(cls) -> "ApplicationConfig":
+        def _parse_env_deployments(deployments_key: str) -> List[str] | None:
+            deployments_value = os.getenv(deployments_key)
+            if deployments_value is None:
+                return None
+            return list(map(str.strip, deployments_value.split(",")))
+
+        def _parse_env_dict(key: str) -> Dict[str, str] | None:
+            value = os.getenv(key)
+            return json.loads(value) if value else None
+
+        def _parse_eliminate_empty_choices() -> bool | None:
+            old_name = "FIX_STREAMING_ISSUES_IN_NEW_API_VERSIONS"
+            new_name = "ELIMINATE_EMPTY_CHOICES"
+
+            if old_name in os.environ:
+                logger.warning(
+                    f"{old_name} environment variable is deprecated. Use {new_name} instead."
+                )
+                return get_env_bool(old_name)
+            elif new_name in os.environ:
+                return get_env_bool(new_name)
+
+            return None
+
+        deployment_fields = {
+            deployment_key: _parse_env_deployments(deployment_key)
+            for deployment_key in (
+                "DALLE3_DEPLOYMENTS",
+                "GPT4_VISION_DEPLOYMENTS",
+                "MISTRAL_DEPLOYMENTS",
+                "DATABRICKS_DEPLOYMENTS",
+                "GPT4O_DEPLOYMENTS",
+                "GPT4O_MINI_DEPLOYMENTS",
+                "AZURE_AI_VISION_DEPLOYMENTS",
+                "NON_STREAMING_DEPLOYMENTS",
+            )
+        }
+        dict_fields = {
+            key: _parse_env_dict(key)
+            for key in (
+                "MODEL_ALIASES",
+                "API_VERSIONS_MAPPING",
+                "COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES",
+            )
+        }
+
+        return cls(
+            **remove_nones(
+                {
+                    **deployment_fields,
+                    **dict_fields,
+                    "DALLE3_AZURE_API_VERSION": os.getenv(
+                        "DALLE3_AZURE_API_VERSION"
+                    ),
+                    "ELIMINATE_EMPTY_CHOICES": _parse_eliminate_empty_choices(),
+                }
+            ),
+        )
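`from_env` keeps the old environment contract: deployment lists are comma-separated strings, mapping-valued settings are JSON objects, and unset variables fall back to the pydantic defaults via `remove_nones`. A minimal sketch of the expected formats (the values are illustrative, and the asserts assume the remaining variables are unset):

```python
import os

from aidial_adapter_openai.app_config import ApplicationConfig

# Comma-separated list; entries are stripped of surrounding whitespace.
os.environ["GPT4O_DEPLOYMENTS"] = "gpt-4o, gpt-4o-2024-05-13"
# Mapping-valued settings are parsed as JSON.
os.environ["MODEL_ALIASES"] = '{"gpt-4o-2024-05-13": "gpt-4o"}'

config = ApplicationConfig.from_env()
assert config.GPT4O_DEPLOYMENTS == ["gpt-4o", "gpt-4o-2024-05-13"]
assert config.MODEL_ALIASES == {"gpt-4o-2024-05-13": "gpt-4o"}
assert config.DALLE3_AZURE_API_VERSION == "2024-02-01"  # pydantic default
```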
diff --git a/aidial_adapter_openai/completions.py b/aidial_adapter_openai/completions.py
index 90834b5..4f11027 100644
--- a/aidial_adapter_openai/completions.py
+++ b/aidial_adapter_openai/completions.py
@@ -4,7 +4,7 @@
 from openai import AsyncStream
 from openai.types import Completion
 
-from aidial_adapter_openai.env import COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES
+from aidial_adapter_openai.app_config import ApplicationConfig
 from aidial_adapter_openai.utils.auth import OpenAICreds
 from aidial_adapter_openai.utils.parsers import (
     AzureOpenAIEndpoint,
@@ -46,6 +46,7 @@ async def chat_completion(
     creds: OpenAICreds,
     api_version: str,
     deployment_id: str,
+    app_config: ApplicationConfig,
 ):
 
     if data.get("n") or 1 > 1:
@@ -60,7 +61,9 @@ async def chat_completion(
     prompt = messages[-1].get("content") or ""
 
     if (
-        template := COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES.get(deployment_id)
+        template := app_config.COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES.get(
+            deployment_id
+        )
     ) is not None:
         prompt = template.format(prompt=prompt)
diff --git a/aidial_adapter_openai/constant.py b/aidial_adapter_openai/constant.py
index e69de29..c24090f 100644
--- a/aidial_adapter_openai/constant.py
+++ b/aidial_adapter_openai/constant.py
@@ -0,0 +1,11 @@
+from enum import StrEnum, auto
+
+
+class ChatCompletionDeploymentType(StrEnum):
+    DALLE3 = auto()
+    MISTRAL = auto()
+    DATABRICKS = auto()
+    GPT4_VISION = auto()
+    GPT4O = auto()
+    GPT4O_MINI = auto()
+    GPT_TEXT_ONLY = auto()
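Since `ChatCompletionDeploymentType` derives from `StrEnum` and uses `auto()`, each member's value is its lower-cased name, so members compare equal to plain strings:

```python
from aidial_adapter_openai.constant import ChatCompletionDeploymentType

assert ChatCompletionDeploymentType.DALLE3 == "dalle3"
assert ChatCompletionDeploymentType.GPT_TEXT_ONLY.value == "gpt_text_only"
```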
diff --git a/aidial_adapter_openai/endpoints/__init__.py b/aidial_adapter_openai/endpoints/__init__.py
new file mode 100644
index 0000000..420e7fa
--- /dev/null
+++ b/aidial_adapter_openai/endpoints/__init__.py
@@ -0,0 +1,3 @@
+from .chat_completion import chat_completion
+from .embeddings import embedding
+from .health import health
diff --git a/aidial_adapter_openai/endpoints/chat_completion.py b/aidial_adapter_openai/endpoints/chat_completion.py
new file mode 100644
index 0000000..4682c42
--- /dev/null
+++ b/aidial_adapter_openai/endpoints/chat_completion.py
@@ -0,0 +1,160 @@
+from typing import assert_never
+
+from fastapi import Request
+
+from aidial_adapter_openai.app_config import ApplicationConfig
+from aidial_adapter_openai.completions import chat_completion as completion
+from aidial_adapter_openai.constant import ChatCompletionDeploymentType
+from aidial_adapter_openai.dalle3 import (
+    chat_completion as dalle3_chat_completion,
+)
+from aidial_adapter_openai.databricks import (
+    chat_completion as databricks_chat_completion,
+)
+from aidial_adapter_openai.dial_api.storage import create_file_storage
+from aidial_adapter_openai.gpt import gpt_chat_completion
+from aidial_adapter_openai.gpt4_multi_modal.chat_completion import (
+    gpt4_vision_chat_completion,
+    gpt4o_chat_completion,
+)
+from aidial_adapter_openai.mistral import (
+    chat_completion as mistral_chat_completion,
+)
+from aidial_adapter_openai.utils.auth import get_credentials
+from aidial_adapter_openai.utils.image_tokenizer import get_image_tokenizer
+from aidial_adapter_openai.utils.parsers import completions_parser, parse_body
+from aidial_adapter_openai.utils.request import (
+    get_api_version,
+    get_request_app_config,
+)
+from aidial_adapter_openai.utils.streaming import create_server_response
+from aidial_adapter_openai.utils.tokenizer import (
+    MultiModalTokenizer,
+    PlainTextTokenizer,
+)
+
+
+async def call_chat_completion(
+    deployment_id: str,
+    data: dict,
+    is_stream: bool,
+    request: Request,
+    app_config: ApplicationConfig,
+):
+
+    # Azure OpenAI deployments ignore the "model" request field,
+    # since the deployment id is already encoded in the endpoint path.
+    # This is not the case for non-Azure OpenAI deployments, so
+    # they require the "model" field to be set.
+    # However, openai==1.33.0 requires the "model" field for **both**
+    # Azure and non-Azure deployments.
+    # Therefore, we provide the "model" field for all deployments here.
+    # The same goes for the /embeddings endpoint.
+    data["model"] = deployment_id
+
+    creds = await get_credentials(request)
+    api_version = get_api_version(request)
+
+    upstream_endpoint = request.headers["X-UPSTREAM-ENDPOINT"]
+
+    if completions_endpoint := completions_parser.parse(upstream_endpoint):
+        return await completion(
+            data,
+            completions_endpoint,
+            creds,
+            api_version,
+            deployment_id,
+            app_config,
+        )
+
+    deployment_type = app_config.get_chat_completion_deployment_type(
+        deployment_id
+    )
+    match deployment_type:
+        case ChatCompletionDeploymentType.DALLE3:
+            storage = create_file_storage("images", request.headers)
+            return await dalle3_chat_completion(
+                data,
+                upstream_endpoint,
+                creds,
+                is_stream,
+                storage,
+                app_config.DALLE3_AZURE_API_VERSION,
+            )
+        case ChatCompletionDeploymentType.MISTRAL:
+            return await mistral_chat_completion(data, upstream_endpoint, creds)
+        case ChatCompletionDeploymentType.DATABRICKS:
+            return await databricks_chat_completion(
+                data, upstream_endpoint, creds
+            )
+        case ChatCompletionDeploymentType.GPT4_VISION:
+            tokenizer = MultiModalTokenizer(
+                "gpt-4", get_image_tokenizer(deployment_type)
+            )
+            return await gpt4_vision_chat_completion(
+                data,
+                deployment_id,
+                upstream_endpoint,
+                creds,
+                is_stream,
+                create_file_storage("images", request.headers),
+                api_version,
+                tokenizer,
+                app_config.ELIMINATE_EMPTY_CHOICES,
+            )
+        case (
+            ChatCompletionDeploymentType.GPT4O
+            | ChatCompletionDeploymentType.GPT4O_MINI
+        ):
+
+            tokenizer = MultiModalTokenizer(
+                app_config.MODEL_ALIASES.get(deployment_id, deployment_id),
+                get_image_tokenizer(deployment_type),
+            )
+            return await gpt4o_chat_completion(
+                data,
+                deployment_id,
+                upstream_endpoint,
+                creds,
+                is_stream,
+                create_file_storage("images", request.headers),
+                api_version,
+                tokenizer,
+                app_config.ELIMINATE_EMPTY_CHOICES,
+            )
+        case ChatCompletionDeploymentType.GPT_TEXT_ONLY:
+            tokenizer = PlainTextTokenizer(
+                model=app_config.MODEL_ALIASES.get(deployment_id, deployment_id)
+            )
+            return await gpt_chat_completion(
+                data,
+                deployment_id,
+                upstream_endpoint,
+                creds,
+                api_version,
+                tokenizer,
+                app_config.ELIMINATE_EMPTY_CHOICES,
+            )
+        case _:
+            assert_never(deployment_type)
+
+
+async def chat_completion(deployment_id: str, request: Request):
+    app_config = get_request_app_config(request)
+    data = await parse_body(request)
+
+    is_stream = bool(data.get("stream"))
+
+    emulate_streaming = (
+        deployment_id in app_config.NON_STREAMING_DEPLOYMENTS and is_stream
+    )
+
+    if emulate_streaming:
+        data["stream"] = False
+
+    return create_server_response(
+        emulate_streaming,
+        await call_chat_completion(
+            deployment_id, data, is_stream, request, app_config
+        ),
+    )
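The `match`/`assert_never` pairing buys static exhaustiveness checking: if a member is ever added to `ChatCompletionDeploymentType` without a matching `case`, type checkers flag the `assert_never` call, and it raises at runtime as a backstop. A self-contained illustration of the pattern (the enum here is a stand-in, not part of the adapter):

```python
from enum import StrEnum, auto
from typing import assert_never


class Color(StrEnum):
    RED = auto()
    BLUE = auto()


def describe(color: Color) -> str:
    match color:
        case Color.RED:
            return "warm"
        case Color.BLUE:
            return "cool"
        case _:
            # Unreachable while the match is exhaustive; if a member is
            # added above without a case, mypy/pyright report this call.
            assert_never(color)
```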
diff --git a/aidial_adapter_openai/endpoints/embeddings.py b/aidial_adapter_openai/endpoints/embeddings.py
new file mode 100644
index 0000000..9960036
--- /dev/null
+++ b/aidial_adapter_openai/endpoints/embeddings.py
@@ -0,0 +1,35 @@
+from fastapi import Request
+
+from aidial_adapter_openai.dial_api.storage import create_file_storage
+from aidial_adapter_openai.embeddings.azure_ai_vision import (
+    embeddings as azure_ai_vision_embeddings,
+)
+from aidial_adapter_openai.embeddings.openai import (
+    embeddings as openai_embeddings,
+)
+from aidial_adapter_openai.utils.auth import get_credentials
+from aidial_adapter_openai.utils.parsers import parse_body
+from aidial_adapter_openai.utils.request import (
+    get_api_version,
+    get_request_app_config,
+)
+
+
+async def embedding(deployment_id: str, request: Request):
+    app_config = get_request_app_config(request)
+    data = await parse_body(request)
+
+    # See note for /chat/completions endpoint
+    data["model"] = deployment_id
+
+    creds = await get_credentials(request)
+    api_version = get_api_version(request)
+    upstream_endpoint = request.headers["X-UPSTREAM-ENDPOINT"]
+
+    if deployment_id in app_config.AZURE_AI_VISION_DEPLOYMENTS:
+        storage = create_file_storage("images", request.headers)
+        return await azure_ai_vision_embeddings(
+            creds, deployment_id, upstream_endpoint, storage, data
+        )
+
+    return await openai_embeddings(creds, upstream_endpoint, api_version, data)
diff --git a/aidial_adapter_openai/endpoints/health.py b/aidial_adapter_openai/endpoints/health.py
new file mode 100644
index 0000000..834f2a3
--- /dev/null
+++ b/aidial_adapter_openai/endpoints/health.py
@@ -0,0 +1,2 @@
+def health():
+    return {"status": "ok"}
diff --git a/aidial_adapter_openai/env.py b/aidial_adapter_openai/env.py
deleted file mode 100644
index e55ddea..0000000
--- a/aidial_adapter_openai/env.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import json
-import os
-from typing import Dict
-
-from aidial_adapter_openai.utils.env import get_env_bool
-from aidial_adapter_openai.utils.log_config import logger
-from aidial_adapter_openai.utils.parsers import parse_deployment_list
-
-MODEL_ALIASES: Dict[str, str] = json.loads(os.getenv("MODEL_ALIASES", "{}"))
-DALLE3_DEPLOYMENTS = parse_deployment_list(os.getenv("DALLE3_DEPLOYMENTS"))
-GPT4_VISION_DEPLOYMENTS = parse_deployment_list(
-    os.getenv("GPT4_VISION_DEPLOYMENTS")
-)
-MISTRAL_DEPLOYMENTS = parse_deployment_list(os.getenv("MISTRAL_DEPLOYMENTS"))
-DATABRICKS_DEPLOYMENTS = parse_deployment_list(
-    os.getenv("DATABRICKS_DEPLOYMENTS")
-)
-GPT4O_DEPLOYMENTS = parse_deployment_list(os.getenv("GPT4O_DEPLOYMENTS"))
-GPT4O_MINI_DEPLOYMENTS = parse_deployment_list(
-    os.getenv("GPT4O_MINI_DEPLOYMENTS")
-)
-API_VERSIONS_MAPPING: Dict[str, str] = json.loads(
-    os.getenv("API_VERSIONS_MAPPING", "{}")
-)
-COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES: Dict[str, str] = json.loads(
-    os.getenv("COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES") or "{}"
-)
-DALLE3_AZURE_API_VERSION = os.getenv("DALLE3_AZURE_API_VERSION", "2024-02-01")
-NON_STREAMING_DEPLOYMENTS = parse_deployment_list(
-    os.getenv("NON_STREAMING_DEPLOYMENTS")
-)
-AZURE_AI_VISION_DEPLOYMENTS = parse_deployment_list(
-    os.getenv("AZURE_AI_VISION_DEPLOYMENTS")
-)
-
-
-def get_eliminate_empty_choices() -> bool:
-    old_name = "FIX_STREAMING_ISSUES_IN_NEW_API_VERSIONS"
-    new_name = "ELIMINATE_EMPTY_CHOICES"
-
-    if old_name in os.environ:
-        logger.warning(
-            f"{old_name} environment variable is deprecated. Use {new_name} instead."
-        )
-        return get_env_bool(old_name, False)
-
-    return get_env_bool(new_name, False)
diff --git a/aidial_adapter_openai/exception_handlers.py b/aidial_adapter_openai/exception_handlers.py
new file mode 100644
index 0000000..c98c122
--- /dev/null
+++ b/aidial_adapter_openai/exception_handlers.py
@@ -0,0 +1,56 @@
+import pydantic
+from aidial_sdk._errors import pydantic_validation_exception_handler
+from aidial_sdk.exceptions import HTTPException as DialException
+from fastapi import Request
+from fastapi.responses import Response
+from openai import APIConnectionError, APIError, APIStatusError, APITimeoutError, OpenAIError
+
+
+def openai_exception_handler(request: Request, e: OpenAIError):
+    if isinstance(e, APIStatusError):
+        r = e.response
+        headers = r.headers
+
+        # Avoid encoding the error message when the original response was encoded.
+        if "Content-Encoding" in headers:
+            del headers["Content-Encoding"]
+
+        return Response(
+            content=r.content,
+            status_code=r.status_code,
+            headers=headers,
+        )
+
+    if isinstance(e, APITimeoutError):
+        raise DialException(
+            status_code=504,
+            type="timeout",
+            message="Request timed out",
+            display_message="Request timed out. Please try again later.",
+        )
+
+    if isinstance(e, APIConnectionError):
+        raise DialException(
+            status_code=502,
+            type="connection",
+            message="Error communicating with OpenAI",
+            display_message="OpenAI server is not responsive. Please try again later.",
+        )
+
+    if isinstance(e, APIError):
+        raise DialException(
+            status_code=getattr(e, "status_code", None) or 500,
+            message=e.message,
+            type=e.type,
+            code=e.code,
+            param=e.param,
+            display_message=None,
+        )
+
+
+def pydantic_exception_handler(request: Request, exc: pydantic.ValidationError):
+    return pydantic_validation_exception_handler(request, exc)
+
+
+def dial_exception_handler(request: Request, exc: DialException):
+    return exc.to_fastapi_response()
diff --git a/aidial_adapter_openai/gpt.py b/aidial_adapter_openai/gpt.py
index 5d6610d..d4c6cde 100644
--- a/aidial_adapter_openai/gpt.py
+++ b/aidial_adapter_openai/gpt.py
@@ -44,6 +44,7 @@ async def gpt_chat_completion(
     creds: OpenAICreds,
     api_version: str,
     tokenizer: PlainTextTokenizer,
+    eliminate_empty_choices: bool,
 ):
     discarded_messages = None
     estimated_prompt_tokens = None
@@ -83,6 +84,7 @@ async def gpt_chat_completion(
             deployment=deployment_id,
             discarded_messages=discarded_messages,
             stream=map_stream(chunk_to_dict, response),
+            eliminate_empty_choices=eliminate_empty_choices,
         )
     else:
         rest = response.to_dict()
diff --git a/aidial_adapter_openai/gpt4_multi_modal/chat_completion.py b/aidial_adapter_openai/gpt4_multi_modal/chat_completion.py
index 26d82c7..216137d 100644
--- a/aidial_adapter_openai/gpt4_multi_modal/chat_completion.py
+++ b/aidial_adapter_openai/gpt4_multi_modal/chat_completion.py
@@ -143,6 +143,7 @@ async def gpt4o_chat_completion(
     file_storage: Optional[FileStorage],
     api_version: str,
     tokenizer: MultiModalTokenizer,
+    eliminate_empty_choices: bool,
 ):
     return await chat_completion(
         request,
@@ -155,6 +156,7 @@ async def gpt4o_chat_completion(
         tokenizer,
         lambda x: x,
         None,
+        eliminate_empty_choices,
     )
 
 
@@ -167,6 +169,7 @@ async def gpt4_vision_chat_completion(
     file_storage: Optional[FileStorage],
     api_version: str,
     tokenizer: MultiModalTokenizer,
+    eliminate_empty_choices: bool,
 ):
     return await chat_completion(
         request,
@@ -179,6 +182,7 @@ async def gpt4_vision_chat_completion(
         tokenizer,
         convert_gpt4v_to_gpt4_chunk,
         GPT4V_DEFAULT_MAX_TOKENS,
+        eliminate_empty_choices,
     )
 
 
@@ -193,6 +197,7 @@ async def chat_completion(
     tokenizer: MultiModalTokenizer,
     response_transformer: Callable[[dict], dict | None],
     default_max_tokens: Optional[int],
+    eliminate_empty_choices: bool,
 ):
     if request.get("n", 1) > 1:
         raise RequestValidationError("The deployment doesn't support n > 1")
@@ -265,6 +270,7 @@ def debug_print(chunk: T) -> T:
                     response_transformer,
                     parse_openai_sse_stream(response),
                 ),
+                eliminate_empty_choices=eliminate_empty_choices,
             ),
         )
     else:
diff --git a/aidial_adapter_openai/utils/image_tokenizer.py b/aidial_adapter_openai/utils/image_tokenizer.py
index 0fe5bf9..7f9eb79 100644
--- a/aidial_adapter_openai/utils/image_tokenizer.py
+++ b/aidial_adapter_openai/utils/image_tokenizer.py
@@ -4,15 +4,11 @@
 """
 
 import math
-from typing import List, Tuple, assert_never
+from typing import Literal, assert_never
 
 from pydantic import BaseModel
 
-from aidial_adapter_openai.env import (
-    GPT4_VISION_DEPLOYMENTS,
-    GPT4O_DEPLOYMENTS,
-    GPT4O_MINI_DEPLOYMENTS,
-)
+from aidial_adapter_openai.constant import ChatCompletionDeploymentType
 from aidial_adapter_openai.utils.image import ImageDetail, resolve_detail_level
 
 
@@ -58,18 +54,25 @@ def _compute_high_detail_tokens(self, width: int, height: int) -> int:
     low_detail_tokens=2833, tokens_per_tile=5667
 )
 
-_TOKENIZERS: List[Tuple[ImageTokenizer, List[str]]] = [
-    (GPT4O_IMAGE_TOKENIZER, GPT4O_DEPLOYMENTS),
-    (GPT4O_MINI_IMAGE_TOKENIZER, GPT4O_MINI_DEPLOYMENTS),
-    (GPT4_VISION_IMAGE_TOKENIZER, GPT4_VISION_DEPLOYMENTS),
+MultiModalDeployments = Literal[
+    ChatCompletionDeploymentType.GPT4O,
+    ChatCompletionDeploymentType.GPT4O_MINI,
+    ChatCompletionDeploymentType.GPT4_VISION,
 ]
 
 
-def get_image_tokenizer(deployment_id: str) -> ImageTokenizer | None:
-    for tokenizer, ids in _TOKENIZERS:
-        if deployment_id in ids:
-            return tokenizer
-    return None
+def get_image_tokenizer(
+    deployment_type: MultiModalDeployments,
+) -> ImageTokenizer:
+    match deployment_type:
+        case ChatCompletionDeploymentType.GPT4O:
+            return GPT4O_IMAGE_TOKENIZER
+        case ChatCompletionDeploymentType.GPT4O_MINI:
+            return GPT4O_MINI_IMAGE_TOKENIZER
+        case ChatCompletionDeploymentType.GPT4_VISION:
+            return GPT4_VISION_IMAGE_TOKENIZER
+        case _:
+            assert_never(deployment_type)
 
 
 def _fit_longest(width: int, height: int, size: int) -> tuple[int, int]:
diff --git a/aidial_adapter_openai/utils/json.py b/aidial_adapter_openai/utils/json.py
new file mode 100644
index 0000000..2a7b8a6
--- /dev/null
+++ b/aidial_adapter_openai/utils/json.py
@@ -0,0 +1,2 @@
+def remove_nones(d: dict) -> dict:
+    return {k: v for k, v in d.items() if v is not None}
diff --git a/aidial_adapter_openai/utils/parsers.py b/aidial_adapter_openai/utils/parsers.py
index 8975093..3591ff9 100644
--- a/aidial_adapter_openai/utils/parsers.py
+++ b/aidial_adapter_openai/utils/parsers.py
@@ -1,7 +1,7 @@
 import re
 from abc import ABC, abstractmethod
 from json import JSONDecodeError
-from typing import Any, Dict, List, TypedDict
+from typing import Any, Dict, TypedDict
 
 from aidial_sdk.exceptions import InvalidRequestError
 from fastapi import Request
@@ -110,7 +110,3 @@ async def parse_body(request: Request) -> Dict[str, Any]:
         raise InvalidRequestError(str(data) + " is not of type 'object'")
 
     return data
-
-
-def parse_deployment_list(deployments: str | None) -> List[str]:
-    return list(map(str.strip, (deployments or "").split(",")))
diff --git a/aidial_adapter_openai/utils/request.py b/aidial_adapter_openai/utils/request.py
new file mode 100644
index 0000000..c8d6457
--- /dev/null
+++ b/aidial_adapter_openai/utils/request.py
@@ -0,0 +1,27 @@
+from aidial_sdk.exceptions import InvalidRequestError
+from fastapi import FastAPI, Request
+
+from aidial_adapter_openai.app_config import ApplicationConfig
+
+
+def set_app_config(app: FastAPI, app_config: ApplicationConfig):
+    app.state.app_config = app_config
+
+
+def get_app_config(app: FastAPI) -> ApplicationConfig:
+    return app.state.app_config
+
+
+def get_request_app_config(request: Request) -> ApplicationConfig:
+    return get_app_config(request.app)
+
+
+def get_api_version(request: Request) -> str:
+    api_version = request.query_params.get("api-version", "")
+    app_config = get_request_app_config(request)
+    api_version = app_config.API_VERSIONS_MAPPING.get(api_version, api_version)
+
+    if api_version == "":
+        raise InvalidRequestError("api-version is a required query parameter")
+
+    return api_version
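The new helpers make the config travel with the application object rather than the process environment. A minimal sketch of the round trip, with an illustrative mapping (not a shipped default):

```python
from fastapi import FastAPI

from aidial_adapter_openai.app_config import ApplicationConfig
from aidial_adapter_openai.utils.request import get_app_config, set_app_config

app = FastAPI()
set_app_config(
    app,
    ApplicationConfig(API_VERSIONS_MAPPING={"": "2024-02-01"}),
)
# Stored on `app.state`, so request handlers can recover it through
# `get_request_app_config(request)` without any global state.
assert get_app_config(app).API_VERSIONS_MAPPING[""] == "2024-02-01"
```

Note that mapping the empty string, as sketched here, would let a missing `api-version` query parameter resolve to a default before the emptiness check in `get_api_version`.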
diff --git a/aidial_adapter_openai/utils/streaming.py b/aidial_adapter_openai/utils/streaming.py
index d677ef7..724ae00 100644
--- a/aidial_adapter_openai/utils/streaming.py
+++ b/aidial_adapter_openai/utils/streaming.py
@@ -10,7 +10,6 @@
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
 from pydantic import BaseModel
 
-from aidial_adapter_openai.env import get_eliminate_empty_choices
 from aidial_adapter_openai.utils.chat_completion_response import (
     ChatCompletionResponse,
     ChatCompletionStreamingChunk,
@@ -18,8 +17,6 @@
 from aidial_adapter_openai.utils.log_config import logger
 from aidial_adapter_openai.utils.sse_stream import to_openai_sse_stream
 
-ELIMINATE_EMPTY_CHOICES = get_eliminate_empty_choices()
-
 
 def generate_id() -> str:
     return "chatcmpl-" + str(uuid4())
@@ -62,6 +59,7 @@ async def generate_stream(
     deployment: str,
     discarded_messages: Optional[list[int]],
     stream: AsyncIterator[dict],
+    eliminate_empty_choices: bool,
 ) -> AsyncIterator[dict]:
 
     empty_chunk = build_chunk(
@@ -116,7 +114,7 @@ def set_discarded_messages(chunk: dict | None, indices: list[int]) -> dict:
         # when content filtering is enabled for a corresponding deployment.
         # The safety rating of the request is reported in this first chunk.
         # Here we withhold such a chunk and merge it later with a follow-up chunk.
-        if len(choices) == 0 and ELIMINATE_EMPTY_CHOICES:
+        if len(choices) == 0 and eliminate_empty_choices:
             buffer_chunk = chunk
         else:
             if last_chunk is not None:
diff --git a/tests/conftest.py b/tests/conftest.py
index 811cab1..0a51517 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,25 +1,29 @@
-from unittest.mock import patch
-
 import httpx
 import pytest
 import pytest_asyncio
 from httpx import ASGITransport
 
-from aidial_adapter_openai.app import app
+from aidial_adapter_openai.app import create_app
+from aidial_adapter_openai.utils.request import get_app_config
 
 
 @pytest.fixture
-def eliminate_empty_choices():
-    with patch(
-        "aidial_adapter_openai.utils.streaming.ELIMINATE_EMPTY_CHOICES", True
-    ):
-        yield
+def _app_instance():
+    return create_app(init_telemetry=False)
 
 
 @pytest_asyncio.fixture
-async def test_app():
+async def test_app(_app_instance):
     async with httpx.AsyncClient(
-        transport=ASGITransport(app=app),  # type: ignore
+        transport=ASGITransport(app=_app_instance),
         base_url="http://test-app.com",
     ) as client:
         yield client
+
+
+@pytest.fixture
+def eliminate_empty_choices(_app_instance):
+    app_config = get_app_config(_app_instance)
+    app_config.ELIMINATE_EMPTY_CHOICES = True
+    yield
+    app_config.ELIMINATE_EMPTY_CHOICES = False
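For reference, a sketch of a test exercising the new fixtures; the health route is real, but a chat-completion test would additionally need to mock the upstream endpoint (e.g. with respx), which is omitted here:

```python
import pytest


@pytest.mark.asyncio
async def test_health(test_app):
    # `test_app` is the httpx client built in conftest.py above.
    response = await test_app.get("/health")
    assert response.status_code == 200
    assert response.json() == {"status": "ok"}
```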