diff --git a/aidial_adapter_openai/app.py b/aidial_adapter_openai/app.py index 28e2d44..12fce06 100644 --- a/aidial_adapter_openai/app.py +++ b/aidial_adapter_openai/app.py @@ -1,14 +1,14 @@ import json import os -from typing import Dict +from typing import Awaitable, Dict, TypeVar from aidial_sdk.exceptions import HTTPException as DialException from aidial_sdk.telemetry.init import init_telemetry from aidial_sdk.telemetry.types import TelemetryConfig +from aidial_sdk.utils.errors import json_error from fastapi import FastAPI, Request -from fastapi.responses import JSONResponse, Response, StreamingResponse -from openai import ChatCompletion, Embedding, error -from openai.openai_object import OpenAIObject +from fastapi.responses import JSONResponse, Response +from openai import APIConnectionError, APIStatusError, APITimeoutError from aidial_adapter_openai.constant import DEFAULT_TIMEOUT from aidial_adapter_openai.dalle3 import ( @@ -17,6 +17,7 @@ from aidial_adapter_openai.databricks import ( chat_completion as databricks_chat_completion, ) +from aidial_adapter_openai.gpt import gpt_chat_completion from aidial_adapter_openai.gpt4_multi_modal.chat_completion import ( gpt4_vision_chat_completion, gpt4o_chat_completion, @@ -24,19 +25,16 @@ from aidial_adapter_openai.mistral import ( chat_completion as mistral_chat_completion, ) -from aidial_adapter_openai.openai_override import OpenAIException from aidial_adapter_openai.utils.auth import get_credentials from aidial_adapter_openai.utils.log_config import configure_loggers from aidial_adapter_openai.utils.parsers import ( - chat_completions_parser, embeddings_parser, parse_body, parse_deployment_list, ) -from aidial_adapter_openai.utils.sse_stream import to_openai_sse_stream +from aidial_adapter_openai.utils.reflection import call_with_extra_body from aidial_adapter_openai.utils.storage import create_file_storage -from aidial_adapter_openai.utils.streaming import generate_stream, map_stream -from aidial_adapter_openai.utils.tokens import Tokenizer, discard_messages +from aidial_adapter_openai.utils.tokens import Tokenizer app = FastAPI() @@ -62,20 +60,27 @@ ) dalle3_azure_api_version = os.getenv("DALLE3_AZURE_API_VERSION", "2024-02-01") +T = TypeVar("T") -async def handle_exceptions(call): + +async def handle_exceptions(call: Awaitable[T]) -> T | Response: try: return await call - except OpenAIException as e: - return Response(status_code=e.code, headers=e.headers, content=e.body) - except error.Timeout: + except APIStatusError as e: + r = e.response + return Response( + content=r.content, + status_code=r.status_code, + headers=r.headers, + ) + except APITimeoutError: raise DialException( "Request timed out", 504, "timeout", display_message="Request timed out. Please try again later.", ) - except error.APIConnectionError: + except APIConnectionError: raise DialException( "Error communicating with OpenAI", 502, @@ -90,7 +95,7 @@ def get_api_version(request: Request): if api_version == "": raise DialException( - "Api version is a required query parameter", + "api-version is a required query parameter", 400, "invalid_request_error", ) @@ -102,9 +107,19 @@ def get_api_version(request: Request): async def chat_completion(deployment_id: str, request: Request): data = await parse_body(request) + # Azure OpenAI deployments ignore "model" request field, + # since the deployment id is already encoded in the endpoint path. + # This is not the case for non-Azure OpenAI deployments, so + # they require the "model" field to be set. 
+ # However, openai==1.33.0 requires the "model" field for **both** + # Azure and non-Azure deployments. + # Therefore, we provide the "model" field for all deployments here. + # The same goes for /embeddings endpoint. + data["model"] = deployment_id + is_stream = data.get("stream", False) - api_type, api_key = await get_credentials(request, chat_completions_parser) + creds = await get_credentials(request) upstream_endpoint = request.headers["X-UPSTREAM-ENDPOINT"] @@ -113,25 +128,20 @@ async def chat_completion(deployment_id: str, request: Request): return await dalle3_chat_completion( data, upstream_endpoint, - api_key, + creds, is_stream, storage, - api_type, dalle3_azure_api_version, ) - elif deployment_id in mistral_deployments: + + if deployment_id in mistral_deployments: return await handle_exceptions( - mistral_chat_completion(data, upstream_endpoint, api_key) + mistral_chat_completion(data, upstream_endpoint, creds) ) - elif deployment_id in databricks_deployments: + + if deployment_id in databricks_deployments: return await handle_exceptions( - databricks_chat_completion( - data, - deployment_id, - upstream_endpoint, - api_key, - api_type, - ) + databricks_chat_completion(data, upstream_endpoint, creds) ) api_version = get_api_version(request) @@ -142,10 +152,9 @@ async def chat_completion(deployment_id: str, request: Request): data, deployment_id, upstream_endpoint, - api_key, + creds, is_stream, storage, - api_type, api_version, ) @@ -159,97 +168,43 @@ async def chat_completion(deployment_id: str, request: Request): data, deployment_id, upstream_endpoint, - api_key, + creds, is_stream, storage, - api_type, api_version, tokenizer, ) ) - discarded_messages = None - if "max_prompt_tokens" in data: - max_prompt_tokens = data["max_prompt_tokens"] - if not isinstance(max_prompt_tokens, int): - raise DialException( - f"'{max_prompt_tokens}' is not of type 'integer' - 'max_prompt_tokens'", - 400, - "invalid_request_error", - ) - if max_prompt_tokens < 1: - raise DialException( - f"'{max_prompt_tokens}' is less than the minimum of 1 - 'max_prompt_tokens'", - 400, - "invalid_request_error", - ) - del data["max_prompt_tokens"] - - data["messages"], discarded_messages = discard_messages( - tokenizer, data["messages"], max_prompt_tokens - ) - - request_args = chat_completions_parser.parse( - upstream_endpoint - ).prepare_request_args(deployment_id) - - response = await handle_exceptions( - ChatCompletion().acreate( - api_key=api_key, - api_type=api_type, - api_version=api_version, - request_timeout=DEFAULT_TIMEOUT, - **(data | request_args), + return await handle_exceptions( + gpt_chat_completion( + data, + deployment_id, + upstream_endpoint, + creds, + api_version, + tokenizer, ) ) - if isinstance(response, Response): - return response - - if is_stream: - prompt_tokens = tokenizer.calculate_prompt_tokens(data["messages"]) - chunk_stream = map_stream(lambda obj: obj.to_dict_recursive(), response) - return StreamingResponse( - to_openai_sse_stream( - generate_stream( - prompt_tokens, - chunk_stream, - tokenizer, - deployment_id, - discarded_messages, - ) - ), - media_type="text/event-stream", - ) - else: - if discarded_messages is not None: - assert isinstance(response, OpenAIObject) - response = response.to_dict() | { - "statistics": {"discarded_messages": discarded_messages} - } - - return response - @app.post("/openai/deployments/{deployment_id}/embeddings") async def embedding(deployment_id: str, request: Request): data = await parse_body(request) - api_type, api_key = await 
get_credentials(request, embeddings_parser) + # See note for /chat/completions endpoint + data["model"] = deployment_id + + creds = await get_credentials(request) api_version = get_api_version(request) + upstream_endpoint = request.headers["X-UPSTREAM-ENDPOINT"] - request_args = embeddings_parser.parse( - request.headers["X-UPSTREAM-ENDPOINT"] - ).prepare_request_args(deployment_id) + client = embeddings_parser.parse(upstream_endpoint).get_client( + {**creds, "api_version": api_version, "timeout": DEFAULT_TIMEOUT} + ) return await handle_exceptions( - Embedding().acreate( - api_key=api_key, - api_type=api_type, - api_version=api_version, - request_timeout=DEFAULT_TIMEOUT, - **(data | request_args), - ) + call_with_extra_body(client.embeddings.create, data) ) @@ -257,15 +212,13 @@ async def embedding(deployment_id: str, request: Request): def exception_handler(request: Request, exc: DialException): return JSONResponse( status_code=exc.status_code, - content={ - "error": { - "message": exc.message, - "type": exc.type, - "param": exc.param, - "code": exc.code, - "display_message": exc.display_message, - } - }, + content=json_error( + message=exc.message, + type=exc.type, + param=exc.param, + code=exc.code, + display_message=exc.display_message, + ), ) diff --git a/aidial_adapter_openai/constant.py b/aidial_adapter_openai/constant.py index 49b8af8..42016e3 100644 --- a/aidial_adapter_openai/constant.py +++ b/aidial_adapter_openai/constant.py @@ -1 +1,4 @@ -DEFAULT_TIMEOUT = (10, 600) # connect timeout and total timeout +from openai import Timeout + +# connect timeout and total timeout +DEFAULT_TIMEOUT = Timeout(600, connect=10) diff --git a/aidial_adapter_openai/dalle3.py b/aidial_adapter_openai/dalle3.py index 27c6e6e..610de1b 100644 --- a/aidial_adapter_openai/dalle3.py +++ b/aidial_adapter_openai/dalle3.py @@ -4,7 +4,7 @@ from aidial_sdk.exceptions import HTTPException as DialException from fastapi.responses import JSONResponse, Response, StreamingResponse -from aidial_adapter_openai.utils.auth import get_auth_header +from aidial_adapter_openai.utils.auth import OpenAICreds, get_auth_headers from aidial_adapter_openai.utils.sse_stream import END_CHUNK from aidial_adapter_openai.utils.storage import FileStorage from aidial_adapter_openai.utils.streaming import ( @@ -21,13 +21,13 @@ async def generate_image( - api_url: str, api_key: str, user_prompt: str, api_type: str + api_url: str, creds: OpenAICreds, user_prompt: str ) -> Any: async with aiohttp.ClientSession() as session: async with session.post( api_url, json={"prompt": user_prompt, "response_format": "b64_json"}, - headers=get_auth_header(api_type, api_key), + headers=get_auth_headers(creds), ) as response: status_code = response.status @@ -114,10 +114,9 @@ async def move_attachments_data_to_storage( async def chat_completion( data: Any, upstream_endpoint: str, - api_key: str, + creds: OpenAICreds, is_stream: bool, file_storage: Optional[FileStorage], - api_type: str, api_version: str, ) -> Response: if data.get("n", 1) > 1: @@ -129,9 +128,7 @@ async def chat_completion( api_url = f"{upstream_endpoint}?api-version={api_version}" user_prompt = get_user_prompt(data) - model_response = await generate_image( - api_url, api_key, user_prompt, api_type - ) + model_response = await generate_image(api_url, creds, user_prompt) if isinstance(model_response, JSONResponse): return model_response diff --git a/aidial_adapter_openai/databricks.py b/aidial_adapter_openai/databricks.py index 84d6577..98cb800 100644 --- 
a/aidial_adapter_openai/databricks.py +++ b/aidial_adapter_openai/databricks.py @@ -1,12 +1,15 @@ -from typing import Any, AsyncIterator, cast +from typing import Any from fastapi.responses import StreamingResponse -from openai import ChatCompletion -from openai.openai_object import OpenAIObject +from openai import AsyncStream +from openai.types.chat.chat_completion import ChatCompletion +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from aidial_adapter_openai.constant import DEFAULT_TIMEOUT +from aidial_adapter_openai.utils.auth import OpenAICreds from aidial_adapter_openai.utils.log_config import logger from aidial_adapter_openai.utils.parsers import chat_completions_parser +from aidial_adapter_openai.utils.reflection import call_with_extra_body from aidial_adapter_openai.utils.sse_stream import to_openai_sse_stream from aidial_adapter_openai.utils.streaming import map_stream @@ -17,33 +20,22 @@ def debug_print(chunk): async def chat_completion( - data: Any, - deployment_id: str, - upstream_endpoint: str, - api_key: str, - api_type: str, + data: Any, upstream_endpoint: str, creds: OpenAICreds ): - request_args = chat_completions_parser.parse( - upstream_endpoint - ).prepare_request_args(deployment_id) - - response = await ChatCompletion().acreate( - api_type=api_type, - api_key=api_key, - request_timeout=DEFAULT_TIMEOUT, - **(data | request_args), + client = chat_completions_parser.parse(upstream_endpoint).get_client( + {**creds, "timeout": DEFAULT_TIMEOUT} ) - if isinstance(response, AsyncIterator): - response = cast(AsyncIterator[OpenAIObject], response) + response: AsyncStream[ChatCompletionChunk] | ChatCompletion = ( + await call_with_extra_body(client.chat.completions.create, data) + ) + if isinstance(response, AsyncStream): return StreamingResponse( to_openai_sse_stream( map_stream( debug_print, - map_stream( - lambda chunk: chunk.to_dict_recursive(), response - ), + map_stream(lambda chunk: chunk.to_dict(), response), ) ), media_type="text/event-stream", diff --git a/aidial_adapter_openai/gpt.py b/aidial_adapter_openai/gpt.py new file mode 100644 index 0000000..2544e84 --- /dev/null +++ b/aidial_adapter_openai/gpt.py @@ -0,0 +1,81 @@ +from aidial_sdk.exceptions import HTTPException as DialException +from fastapi.responses import StreamingResponse +from openai import AsyncStream +from openai.types.chat.chat_completion import ChatCompletion +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk + +from aidial_adapter_openai.constant import DEFAULT_TIMEOUT +from aidial_adapter_openai.utils.auth import OpenAICreds +from aidial_adapter_openai.utils.log_config import logger +from aidial_adapter_openai.utils.parsers import chat_completions_parser +from aidial_adapter_openai.utils.reflection import call_with_extra_body +from aidial_adapter_openai.utils.sse_stream import to_openai_sse_stream +from aidial_adapter_openai.utils.streaming import generate_stream, map_stream +from aidial_adapter_openai.utils.tokens import Tokenizer, discard_messages + + +def debug_print(chunk): + logger.debug(f"chunk: {chunk}") + return chunk + + +async def gpt_chat_completion( + data: dict, + deployment_id: str, + upstream_endpoint: str, + creds: OpenAICreds, + api_version: str, + tokenizer: Tokenizer, +): + discarded_messages = None + if "max_prompt_tokens" in data: + max_prompt_tokens = data["max_prompt_tokens"] + if not isinstance(max_prompt_tokens, int): + raise DialException( + f"'{max_prompt_tokens}' is not of type 'integer' - 'max_prompt_tokens'", + 400, + 
"invalid_request_error", + ) + if max_prompt_tokens < 1: + raise DialException( + f"'{max_prompt_tokens}' is less than the minimum of 1 - 'max_prompt_tokens'", + 400, + "invalid_request_error", + ) + del data["max_prompt_tokens"] + + data["messages"], discarded_messages = discard_messages( + tokenizer, data["messages"], max_prompt_tokens + ) + + client = chat_completions_parser.parse(upstream_endpoint).get_client( + {**creds, "api_version": api_version, "timeout": DEFAULT_TIMEOUT} + ) + + response: AsyncStream[ChatCompletionChunk] | ChatCompletion = ( + await call_with_extra_body(client.chat.completions.create, data) + ) + + if isinstance(response, AsyncStream): + prompt_tokens = tokenizer.calculate_prompt_tokens(data["messages"]) + return StreamingResponse( + to_openai_sse_stream( + map_stream( + debug_print, + generate_stream( + prompt_tokens, + map_stream(lambda obj: obj.to_dict(), response), + tokenizer, + deployment_id, + discarded_messages, + ), + ) + ), + media_type="text/event-stream", + ) + else: + resp = response.to_dict() + if discarded_messages is not None: + resp |= {"statistics": {"discarded_messages": discarded_messages}} + debug_print(resp) + return resp diff --git a/aidial_adapter_openai/gpt4_multi_modal/chat_completion.py b/aidial_adapter_openai/gpt4_multi_modal/chat_completion.py index 51cf046..aca1ddf 100644 --- a/aidial_adapter_openai/gpt4_multi_modal/chat_completion.py +++ b/aidial_adapter_openai/gpt4_multi_modal/chat_completion.py @@ -21,7 +21,7 @@ from aidial_adapter_openai.gpt4_multi_modal.gpt4_vision import ( convert_gpt4v_to_gpt4_chunk, ) -from aidial_adapter_openai.utils.auth import get_auth_header +from aidial_adapter_openai.utils.auth import OpenAICreds, get_auth_headers from aidial_adapter_openai.utils.log_config import logger from aidial_adapter_openai.utils.sse_stream import ( parse_openai_sse_stream, @@ -113,10 +113,9 @@ async def gpt4o_chat_completion( request: Any, deployment: str, upstream_endpoint: str, - api_key: str, + creds: OpenAICreds, is_stream: bool, file_storage: Optional[FileStorage], - api_type: str, api_version: str, tokenizer: Tokenizer, ) -> Response: @@ -124,10 +123,9 @@ async def gpt4o_chat_completion( request, deployment, upstream_endpoint, - api_key, + creds, is_stream, file_storage, - api_type, api_version, tokenizer, lambda x: x, @@ -139,20 +137,18 @@ async def gpt4_vision_chat_completion( request: Any, deployment: str, upstream_endpoint: str, - api_key: str, + creds: OpenAICreds, is_stream: bool, file_storage: Optional[FileStorage], - api_type: str, api_version: str, ) -> Response: return await chat_completion( request, deployment, upstream_endpoint, - api_key, + creds, is_stream, file_storage, - api_type, api_version, Tokenizer("gpt-4"), convert_gpt4v_to_gpt4_chunk, @@ -164,10 +160,9 @@ async def chat_completion( request: Any, deployment: str, upstream_endpoint: str, - api_key: str, + creds: OpenAICreds, is_stream: bool, file_storage: Optional[FileStorage], - api_type: str, api_version: str, tokenizer: Tokenizer, response_transformer: Callable[[dict], dict | None], @@ -219,7 +214,7 @@ async def chat_completion( "messages": new_messages, } - headers = get_auth_header(api_type, api_key) + headers = get_auth_headers(creds) if is_stream: response = await predict_stream(api_url, headers, request) diff --git a/aidial_adapter_openai/mistral.py b/aidial_adapter_openai/mistral.py index c4b4d2e..795053a 100644 --- a/aidial_adapter_openai/mistral.py +++ b/aidial_adapter_openai/mistral.py @@ -1,37 +1,46 @@ -from typing import Any, 
AsyncIterator, cast +from typing import Any from fastapi.responses import StreamingResponse -from openai import ChatCompletion -from openai.openai_object import OpenAIObject +from openai import AsyncOpenAI, AsyncStream +from openai.types.chat.chat_completion import ChatCompletion +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from aidial_adapter_openai.constant import DEFAULT_TIMEOUT -from aidial_adapter_openai.utils.sse_stream import END_CHUNK, format_chunk +from aidial_adapter_openai.utils.auth import OpenAICreds +from aidial_adapter_openai.utils.log_config import logger +from aidial_adapter_openai.utils.reflection import call_with_extra_body +from aidial_adapter_openai.utils.sse_stream import to_openai_sse_stream +from aidial_adapter_openai.utils.streaming import map_stream -async def generate_stream( - stream: AsyncIterator[OpenAIObject], -) -> AsyncIterator[str]: - async for chunk in stream: - yield format_chunk(chunk.to_dict_recursive()) - yield END_CHUNK +def debug_print(chunk): + logger.debug(f"chunk: {chunk}") + return chunk -async def chat_completion(data: Any, upstream_endpoint: str, api_key: str): - data["model"] = "azureai" +async def chat_completion( + data: Any, upstream_endpoint: str, creds: OpenAICreds +): - response = await ChatCompletion().acreate( - api_key=api_key, - api_base=upstream_endpoint, - api_type="openai", - request_timeout=DEFAULT_TIMEOUT, - **data, + client = AsyncOpenAI( + base_url=upstream_endpoint, + timeout=DEFAULT_TIMEOUT, + api_key=creds.get("api_key"), ) - if isinstance(response, AsyncIterator): - response = cast(AsyncIterator[OpenAIObject], response) + response: AsyncStream[ChatCompletionChunk] | ChatCompletion = ( + await call_with_extra_body(client.chat.completions.create, data) + ) + if isinstance(response, AsyncStream): return StreamingResponse( - generate_stream(response), media_type="text/event-stream" + to_openai_sse_stream( + map_stream( + debug_print, + map_stream(lambda chunk: chunk.to_dict(), response), + ) + ), + media_type="text/event-stream", ) else: return response diff --git a/aidial_adapter_openai/openai_override.py b/aidial_adapter_openai/openai_override.py deleted file mode 100644 index 6a3dd7b..0000000 --- a/aidial_adapter_openai/openai_override.py +++ /dev/null @@ -1,77 +0,0 @@ -""" -OpenAI SDK translates various HTTP errors received from OpenAI API -into Python exceptions: error.RateLimitError, error.InvalidRequestError, -error.AuthenticationError etc. - -We want to retranslate the original HTTP errors to the user. -So the standard error handlers in the openai.api_requestor.APIRequestor class -are rewritten to wrap the original HTTP errors into OpenAIException and raise it. -""" - -import json -from json import JSONDecodeError - -import wrapt -from openai.api_requestor import APIRequestor -from openai.openai_response import OpenAIResponse - - -class OpenAIException(Exception): - def __init__(self, body, code, resp, headers): - self.body = body - self.code = code - self.resp = resp - self.headers = headers - - super().__init__(resp) - - -# Overridden to proxy original errors -def handle_error_response_wrapper(wrapped, self, args, kwargs): - raise OpenAIException(*args) - - -# Overridden to proxy original errors -def interpret_response_line_wrapper(wrapped, self: APIRequestor, args, kwargs): - rbody, rcode, rheaders = args - stream = kwargs.get("stream", False) - - # HTTP 204 response code does not have any content in the body. 
- if rcode == 204: - return OpenAIResponse(None, rheaders) - - if rcode == 503: - raise self.handle_error_response( # overridden - rbody, rcode, None, rheaders, stream_error=False - ) - try: - if "text/plain" in rheaders.get("Content-Type", ""): - data = rbody - else: - data = json.loads(rbody) - except (JSONDecodeError, UnicodeDecodeError): - raise self.handle_error_response( # overridden - rbody, - rcode, - None, - rheaders, - stream_error=False, - ) - resp = OpenAIResponse(data, rheaders) - # In the future, we might add a "status" parameter to errors - # to better handle the "error while streaming" case. - stream_error = stream and "error" in resp.data - if stream_error or not 200 <= rcode < 300: - raise self.handle_error_response( - rbody, rcode, resp.data, rheaders, stream_error=stream_error - ) - return resp - - -wrapt.wrap_function_wrapper( - APIRequestor, "handle_error_response", handle_error_response_wrapper -) - -wrapt.wrap_function_wrapper( - APIRequestor, "_interpret_response_line", interpret_response_line_wrapper -) diff --git a/aidial_adapter_openai/utils/auth.py b/aidial_adapter_openai/utils/auth.py index e2294e3..907f39c 100644 --- a/aidial_adapter_openai/utils/auth.py +++ b/aidial_adapter_openai/utils/auth.py @@ -1,18 +1,15 @@ import os import time -from typing import Mapping, Optional +from typing import Mapping, Optional, TypedDict from aidial_sdk.exceptions import HTTPException as DialException from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError from azure.identity.aio import DefaultAzureCredential from fastapi import Request -from openai import util -from openai.util import ApiType from pydantic import BaseModel from aidial_adapter_openai.utils.log_config import logger -from aidial_adapter_openai.utils.parsers import EndpointParser default_credential = DefaultAzureCredential() access_token: AccessToken | None = None @@ -47,26 +44,27 @@ async def get_api_key() -> str: return access_token.token -async def get_credentials( - request: Request, parser: EndpointParser -) -> tuple[str, str]: +class OpenAICreds(TypedDict, total=False): + api_key: str + azure_ad_token: str + + +async def get_credentials(request: Request) -> OpenAICreds: api_key = request.headers.get("X-UPSTREAM-KEY") if api_key is None: - return "azure_ad", await get_api_key() + return {"azure_ad_token": await get_api_key()} + else: + return {"api_key": api_key} + - try: - api_type = parser.parse( - request.headers["X-UPSTREAM-ENDPOINT"] - ).get_api_type() - return api_type, api_key - except Exception: - # fallback to the 'azure' api-type if the endpoint - # doesn't follow the expected format - return "azure", api_key +def get_auth_headers(creds: OpenAICreds) -> dict[str, str]: + if "api_key" in creds: + return {"api-key": creds["api_key"]} + if "azure_ad_token" in creds: + return {"Authorization": f"Bearer {creds['azure_ad_token']}"} -def get_auth_header(api_type: str, api_key: str) -> dict[str, str]: - return util.api_key_to_header(ApiType.from_str(api_type), api_key) + raise ValueError("Invalid credentials") class Auth(BaseModel): diff --git a/aidial_adapter_openai/utils/exceptions.py b/aidial_adapter_openai/utils/exceptions.py deleted file mode 100644 index 472abfd..0000000 --- a/aidial_adapter_openai/utils/exceptions.py +++ /dev/null @@ -1,12 +0,0 @@ -from typing import Any - - -def create_error(message: str, type: str, param: Any = None, code: Any = None): - return { - "error": { - "message": message, - "type": type, - "param": param, - "code": code, 
- } - } diff --git a/aidial_adapter_openai/utils/parsers.py b/aidial_adapter_openai/utils/parsers.py index 7d26974..73d4a04 100644 --- a/aidial_adapter_openai/utils/parsers.py +++ b/aidial_adapter_openai/utils/parsers.py @@ -1,42 +1,51 @@ import re from abc import ABC, abstractmethod from json import JSONDecodeError -from typing import Any, Dict, List +from typing import Any, Dict, List, TypedDict from aidial_sdk.exceptions import HTTPException as DialException from fastapi import Request +from openai import AsyncAzureOpenAI, AsyncOpenAI, Timeout from pydantic import BaseModel -class Endpoint(ABC): - @abstractmethod - def prepare_request_args(self, deployment_id: str) -> Dict[str, str]: - pass +class OpenAIParams(TypedDict, total=False): + api_key: str + azure_ad_token: str + api_version: str + timeout: Timeout + +class Endpoint(ABC): @abstractmethod - def get_api_type(self) -> str: + def get_client(self, params: OpenAIParams) -> AsyncOpenAI: pass class AzureOpenAIEndpoint(BaseModel): - api_base: str - deployment_id: str - - def prepare_request_args(self, deployment_id: str) -> Dict[str, str]: - return {"api_base": self.api_base, "engine": self.deployment_id} - - def get_api_type(self) -> str: - return "azure" + azure_endpoint: str + azure_deployment: str + + def get_client(self, params: OpenAIParams) -> AsyncAzureOpenAI: + return AsyncAzureOpenAI( + azure_endpoint=self.azure_endpoint, + azure_deployment=self.azure_deployment, + api_key=params.get("api_key"), + azure_ad_token=params.get("azure_ad_token"), + api_version=params.get("api_version"), + timeout=params.get("timeout"), + ) class OpenAIEndpoint(BaseModel): - api_base: str + base_url: str - def prepare_request_args(self, deployment_id: str) -> Dict[str, str]: - return {"api_base": self.api_base, "model": deployment_id} - - def get_api_type(self) -> str: - return "open_ai" + def get_client(self, params: OpenAIParams) -> AsyncOpenAI: + return AsyncOpenAI( + base_url=self.base_url, + api_key=params.get("api_key"), + timeout=params.get("timeout"), + ) class EndpointParser(BaseModel): @@ -49,13 +58,14 @@ def parse(self, endpoint: str) -> AzureOpenAIEndpoint | OpenAIEndpoint: if match: return AzureOpenAIEndpoint( - api_base=match[1], deployment_id=match[2] + azure_endpoint=match[1], + azure_deployment=match[2], ) match = re.search(f"(.+?)/{self.name}", endpoint) if match: - return OpenAIEndpoint(api_base=match[1]) + return OpenAIEndpoint(base_url=match[1]) raise DialException( "Invalid upstream endpoint format", 400, "invalid_request_error" diff --git a/aidial_adapter_openai/utils/reflection.py b/aidial_adapter_openai/utils/reflection.py new file mode 100644 index 0000000..916a0f7 --- /dev/null +++ b/aidial_adapter_openai/utils/reflection.py @@ -0,0 +1,44 @@ +import inspect +from typing import Any, Callable, Coroutine, TypeVar + +from aidial_sdk.exceptions import HTTPException as DialException + +T = TypeVar("T") + + +async def call_with_extra_body( + func: Callable[..., Coroutine[Any, Any, T]], arg: dict +) -> T: + if has_kwargs_argument(func): + return await func(**arg) + + expected_args = set(inspect.signature(func).parameters.keys()) + actual_args = set(arg.keys()) + + extra_args = actual_args - expected_args + + if extra_args and "extra_body" not in expected_args: + raise DialException( + f"Extra arguments aren't supported: {extra_args}.", + 400, + "invalid_request_error", + ) + + arg["extra_body"] = arg.get("extra_body") or {} + + for extra_arg in extra_args: + arg["extra_body"][extra_arg] = arg[extra_arg] + del arg[extra_arg] + + 
return await func(**arg) + + +def has_kwargs_argument(func: Callable[..., Coroutine[Any, Any, Any]]) -> bool: + """ + Determines if the given function accepts a variable keyword argument (**kwargs). + """ + signature = inspect.signature(func) + for param in signature.parameters.values(): + if param.kind == inspect.Parameter.VAR_KEYWORD: + return True + return False diff --git a/aidial_adapter_openai/utils/sse_stream.py b/aidial_adapter_openai/utils/sse_stream.py index a78cf56..9ca07a3 100644 --- a/aidial_adapter_openai/utils/sse_stream.py +++ b/aidial_adapter_openai/utils/sse_stream.py @@ -1,7 +1,7 @@ import json from typing import Any, AsyncIterator, Mapping -from aidial_adapter_openai.utils.exceptions import create_error +from aidial_sdk.utils.errors import json_error DATA_PREFIX = "data: " OPENAI_END_MARKER = "[DONE]" @@ -24,7 +24,7 @@ async def parse_openai_sse_stream( try: payload = line.decode("utf-8-sig").lstrip() # type: ignore except Exception: - yield create_error( + yield json_error( message="Can't decode chunk to a string", type="runtime_error" ) return @@ -33,7 +33,7 @@ async def parse_openai_sse_stream( continue if not payload.startswith(DATA_PREFIX): - yield create_error( + yield json_error( message="Invalid chunk format", type="runtime_error" ) return @@ -46,7 +46,7 @@ async def parse_openai_sse_stream( try: chunk = json.loads(payload) except json.JSONDecodeError: - yield create_error( + yield json_error( message="Can't parse chunk to JSON", type="runtime_error" ) return diff --git a/aidial_adapter_openai/utils/streaming.py b/aidial_adapter_openai/utils/streaming.py index 4f22857..d1aae2c 100644 --- a/aidial_adapter_openai/utils/streaming.py +++ b/aidial_adapter_openai/utils/streaming.py @@ -2,10 +2,11 @@ from typing import Any, AsyncIterator, Callable, Optional, TypeVar from uuid import uuid4 +from aidial_sdk.utils.errors import json_error from aidial_sdk.utils.merge_chunks import merge from fastapi.responses import JSONResponse, Response, StreamingResponse +from openai import APIError -from aidial_adapter_openai.openai_override import OpenAIException from aidial_adapter_openai.utils.env import get_env_bool from aidial_adapter_openai.utils.log_config import logger from aidial_adapter_openai.utils.sse_stream import END_CHUNK, format_chunk @@ -65,6 +66,9 @@ async def generate_stream( choice = chunk["choices"][0] + content = (choice.get("delta") or {}).get("content") or "" + total_content += content + if choice["finish_reason"] is not None: stream_finished = True completion_tokens = tokenizer.calculate_tokens( @@ -79,9 +83,6 @@ async def generate_stream( chunk["statistics"] = { "discarded_messages": discarded_messages } - else: - content = choice["delta"].get("content") or "" - total_content += content yield chunk else: @@ -91,8 +92,13 @@ async def generate_stream( yield chunk last_chunk = chunk - except OpenAIException as e: - yield e.body + except APIError as e: + yield json_error( + message=e.message, + type=e.type, + param=e.param, + code=e.code, + ) return if not stream_finished: diff --git a/poetry.lock b/poetry.lock index dc243c0..183c4de 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. 
[[package]] name = "aidial-sdk" @@ -131,20 +131,6 @@ yarl = ">=1.0,<2.0" [package.extras] speedups = ["Brotli", "aiodns", "brotlicffi"] -[[package]] -name = "aioresponses" -version = "0.7.6" -description = "Mock out requests made by ClientSession from aiohttp package" -optional = false -python-versions = "*" -files = [ - {file = "aioresponses-0.7.6-py2.py3-none-any.whl", hash = "sha256:d2c26defbb9b440ea2685ec132e90700907fd10bcca3e85ec2f157219f0d26f7"}, - {file = "aioresponses-0.7.6.tar.gz", hash = "sha256:f795d9dbda2d61774840e7e32f5366f45752d1adc1b74c9362afd017296c7ee1"}, -] - -[package.dependencies] -aiohttp = ">=3.3.0,<4.0.0" - [[package]] name = "aiosignal" version = "1.3.1" @@ -628,6 +614,17 @@ files = [ {file = "distlib-0.3.7.tar.gz", hash = "sha256:9dafe54b34a028eafd95039d5e5d4851a13734540f1331060d31c9916e7147a8"}, ] +[[package]] +name = "distro" +version = "1.9.0" +description = "Distro - an OS platform information API" +optional = false +python-versions = ">=3.6" +files = [ + {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, + {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, +] + [[package]] name = "fastapi" version = "0.109.2" @@ -1165,25 +1162,26 @@ files = [ [[package]] name = "openai" -version = "0.28.1" -description = "Python client library for the OpenAI API" +version = "1.33.0" +description = "The official Python library for the openai API" optional = false python-versions = ">=3.7.1" files = [ - {file = "openai-0.28.1-py3-none-any.whl", hash = "sha256:d18690f9e3d31eedb66b57b88c2165d760b24ea0a01f150dd3f068155088ce68"}, - {file = "openai-0.28.1.tar.gz", hash = "sha256:4be1dad329a65b4ce1a660fe6d5431b438f429b5855c883435f0f7fcb6d2dcc8"}, + {file = "openai-1.33.0-py3-none-any.whl", hash = "sha256:621163b56570897ab8389d187f686a53d4771fd6ce95d481c0a9611fe8bc4229"}, + {file = "openai-1.33.0.tar.gz", hash = "sha256:1169211a7b326ecbc821cafb427c29bfd0871f9a3e0947dd9e51acb3b0f1df78"}, ] [package.dependencies] -aiohttp = "*" -requests = ">=2.20" -tqdm = "*" +anyio = ">=3.5.0,<5" +distro = ">=1.7.0,<2" +httpx = ">=0.23.0,<1" +pydantic = ">=1.9.0,<3" +sniffio = "*" +tqdm = ">4" +typing-extensions = ">=4.7,<5" [package.extras] -datalib = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] -dev = ["black (>=21.6b0,<22.0)", "pytest (==6.*)", "pytest-asyncio", "pytest-mock"] -embeddings = ["matplotlib", "numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "plotly", "scikit-learn (>=1.0.2)", "scipy", "tenacity (>=8.0.1)"] -wandb = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "wandb"] +datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] [[package]] name = "opentelemetry-api" @@ -1869,22 +1867,6 @@ pluggy = ">=0.12,<2.0" [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] -[[package]] -name = "pytest-aioresponses" -version = "0.2.0" -description = "py.test integration for aioresponses" -optional = false -python-versions = ">=3.6,<4.0" -files = [ - {file = "pytest-aioresponses-0.2.0.tar.gz", hash = "sha256:61cced206857cb4e7aab10b61600527f505c358d046e7d3ad3ae09455d02d937"}, - {file = "pytest_aioresponses-0.2.0-py3-none-any.whl", hash = "sha256:1a78d1eb76e1bffe7adc83a1bad0d48c373b41289367ae1f5e7ec0fceb60a04d"}, -] - -[package.dependencies] 
-aioresponses = ">=0.7.1,<0.8.0" -pytest = ">=3.5.0" -pytest-asyncio = ">=0.14.0" - [[package]] name = "pytest-asyncio" version = "0.21.1" @@ -2039,13 +2021,13 @@ files = [ [[package]] name = "requests" -version = "2.32.0" +version = "2.32.3" description = "Python HTTP for Humans." optional = false python-versions = ">=3.8" files = [ - {file = "requests-2.32.0-py3-none-any.whl", hash = "sha256:f2c3881dddb70d056c5bd7600a4fae312b2a300e39be6a118d30b90bd27262b5"}, - {file = "requests-2.32.0.tar.gz", hash = "sha256:fa5490319474c82ef1d2c9bc459d3652e3ae4ef4c4ebdd18a21145a47ca4b6b8"}, + {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"}, + {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"}, ] [package.dependencies] @@ -2058,6 +2040,20 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "respx" +version = "0.21.1" +description = "A utility for mocking out the Python HTTPX and HTTP Core libraries." +optional = false +python-versions = ">=3.7" +files = [ + {file = "respx-0.21.1-py2.py3-none-any.whl", hash = "sha256:05f45de23f0c785862a2c92a3e173916e8ca88e4caad715dd5f68584d6053c20"}, + {file = "respx-0.21.1.tar.gz", hash = "sha256:0bd7fe21bfaa52106caa1223ce61224cf30786985f17c63c5d71eff0307ee8af"}, +] + +[package.dependencies] +httpx = ">=0.21.0" + [[package]] name = "setuptools" version = "68.2.2" @@ -2440,4 +2436,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.13" -content-hash = "34b081f03bcb5e6a44d758df4029cbb97250f07b30349e0724107cfcdaa14f4f" +content-hash = "de645b9905e10080c3ae3f9343a8c1001a87046425152dfcf93633db228a7233" diff --git a/pyproject.toml b/pyproject.toml index 8b35cca..8f02b5a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ env_files = [".env"] [tool.poetry.dependencies] python = ">=3.11,<3.13" fastapi = "0.109.2" -openai = "0.28.1" +openai = "1.33.0" tiktoken = "0.7.0" uvicorn = "0.23" wrapt = "^1.15.0" @@ -35,9 +35,9 @@ httpx = "^0.25.0" # TODO: remove once SDK supports conditional instrumentation [tool.poetry.group.test.dependencies] pytest = "7.4.0" +pytest-asyncio = "0.21.1" python-dotenv = "1.0.0" -pytest-aioresponses = "^0.2.0" -aioresponses = "^0.7.6" +respx = "^0.21.1" [tool.poetry.group.lint.dependencies] pyright = "1.1.324" @@ -49,6 +49,12 @@ flake8 = "6.0.0" [tool.poetry.group.dev.dependencies] nox = "^2023.4.22" +[tool.pytest.ini_options] +# muting warnings coming from opentelemetry package +filterwarnings = [ + "ignore::DeprecationWarning:opentelemetry.instrumentation.dependencies" +] + [tool.pyright] typeCheckingMode = "basic" reportUnusedVariable = "error" diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..d3e51af --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,12 @@ +import httpx +import pytest_asyncio + +from aidial_adapter_openai.app import app + + +@pytest_asyncio.fixture +async def test_app(): + async with httpx.AsyncClient( + app=app, base_url="http://test-app.com" + ) as client: + yield client diff --git a/tests/test_errors.py b/tests/test_errors.py index 1d944fe..7840b02 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -1,53 +1,63 @@ import json +from typing import Callable +import 
httpx import pytest -from aioresponses import aioresponses -from httpx import AsyncClient +import respx +from respx.types import SideEffectTypes -from aidial_adapter_openai.app import app +from tests.utils.stream import OpenAIStream -@pytest.mark.asyncio -async def test_error_during_streaming(aioresponses: aioresponses): - aioresponses.post( - "http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview", - status=200, - body="data: " - + json.dumps( - { - "id": "chatcmpl-test", - "object": "chat.completion.chunk", - "created": 1695940483, - "model": "gpt-4", - "choices": [ - { - "index": 0, - "finish_reason": "stop", - "message": { - "role": "assistant", - }, - } - ], - "usage": None, - } +def assert_equal(actual, expected): + assert actual == expected + + +def mock_response( + status_code: int, + content_type: str, + content: str, + check_request: Callable[[httpx.Request], None] = lambda _: None, +) -> SideEffectTypes: + def side_effect(request: httpx.Request): + check_request(request) + return httpx.Response( + status_code=status_code, + headers={"content-type": content_type}, + content=content, ) - + "\n\n" - + "data: " - + json.dumps( - { - "error": { - "message": "Error test", - "type": "runtime_error", - "param": None, - "code": None, + + return side_effect + + +@respx.mock +@pytest.mark.asyncio +async def test_single_chunk_token_counting(test_app: httpx.AsyncClient): + # The adapter tolerates top-level extra fields + # and passes it further to the upstream endpoint. + + mock_stream = OpenAIStream( + { + "id": "chatcmpl-test", + "object": "chat.completion.chunk", + "created": 1695940483, + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "delta": {"role": "assistant", "content": "5"}, } - } - ) - + "\n\n" - + "data: [DONE]\n\n", + ], + }, + ) + + respx.post( + "http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview" + ).respond( + status_code=200, content_type="text/event-stream", + content=mock_stream.to_content(), ) - test_app = AsyncClient(app=app, base_url="http://test.com") response = await test_app.post( "/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview", @@ -62,107 +72,234 @@ async def test_error_during_streaming(aioresponses: aioresponses): ) assert response.status_code == 200 + mock_stream.assert_response_content( + response, + assert_equal, + usages={ + 0: { + "prompt_tokens": 9, + "completion_tokens": 1, + "total_tokens": 10, + } + }, + ) + + +@respx.mock +@pytest.mark.asyncio +async def test_top_level_extra_field(test_app: httpx.AsyncClient): + # The adapter tolerates top-level extra fields + # and passes it further to the upstream endpoint. 
+ + mock_stream = OpenAIStream({"error": {"message": "whatever"}}) - for index, line in enumerate(response.iter_lines()): - if index % 2 == 1: - assert line == "" - continue - - if index == 0: - assert ( - line - == 'data: {"id":"chatcmpl-test","object":"chat.completion.chunk","created":1695940483,"model":"gpt-4","choices":[{"index":0,"finish_reason":"stop","message":{"role":"assistant"}}],"usage":{"completion_tokens":0,"prompt_tokens":9,"total_tokens":9}}' - ) - elif index == 2: - assert ( - line - == 'data: {"error": {"message": "Error test", "type": "runtime_error", "param": null, "code": null}}' - ) - elif index == 4: - assert line == "data: [DONE]" - else: - assert False + def check_request(request: httpx.Request): + assert json.loads(request.content)["extra_field"] == 1 + respx.post( + "http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview" + ).mock( + side_effect=mock_response( + status_code=200, + content_type="text/event-stream", + content=mock_stream.to_content(), + check_request=check_request, + ), + ) + response = await test_app.post( + "/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview", + json={ + "messages": [{"role": "user", "content": "Test content"}], + "stream": True, + "extra_field": 1, + }, + headers={ + "X-UPSTREAM-KEY": "TEST_API_KEY", + "X-UPSTREAM-ENDPOINT": "http://localhost:5001/openai/deployments/gpt-4/chat/completions", + }, + ) + + assert response.status_code == 200 + mock_stream.assert_response_content(response, assert_equal) + + +@respx.mock @pytest.mark.asyncio -async def test_incorrect_upstream_url(aioresponses: aioresponses): - aioresponses.post( - "http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview", - status=200, - body={}, +async def test_nested_extra_field(test_app: httpx.AsyncClient): + # The adapter tolerates nested extra fields + # and passes it further to the upstream endpoint. 
+ + mock_stream = OpenAIStream({"error": {"message": "whatever"}}) + + def check_request(request: httpx.Request): + assert json.loads(request.content)["messages"][0]["extra_field"] == 1 + + respx.post( + "http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview" + ).mock( + side_effect=mock_response( + status_code=200, + content_type="text/event-stream", + content=mock_stream.to_content(), + check_request=check_request, + ), ) - test_app = AsyncClient(app=app, base_url="http://test.com") response = await test_app.post( "/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview", - json={"messages": [{"role": "user", "content": "Test content"}]}, + json={ + "messages": [ + {"role": "user", "content": "2+3=?", "extra_field": 1} + ], + "stream": True, + }, headers={ "X-UPSTREAM-KEY": "TEST_API_KEY", - "X-UPSTREAM-ENDPOINT": "http://localhost:5001", # upstream endpoint should contain the full path + "X-UPSTREAM-ENDPOINT": "http://localhost:5001/openai/deployments/gpt-4/chat/completions", + }, + ) + + assert response.status_code == 200 + mock_stream.assert_response_content(response, assert_equal) + + +@respx.mock +@pytest.mark.asyncio +async def test_missing_api_version(test_app: httpx.AsyncClient): + + response = await test_app.post( + "/openai/deployments/gpt-4/chat/completions", + json={ + "messages": [{"role": "user", "content": "Test content"}], + "stream": True, + }, + headers={ + "X-UPSTREAM-KEY": "TEST_API_KEY", + "X-UPSTREAM-ENDPOINT": "http://localhost:5001/openai/deployments/gpt-4/chat/completions", }, ) assert response.status_code == 400 assert response.json() == { "error": { - "message": "Invalid upstream endpoint format", + "message": "api-version is a required query parameter", "type": "invalid_request_error", - "param": None, - "code": None, - "display_message": None, } } +@respx.mock @pytest.mark.asyncio -async def test_incorrect_format(aioresponses: aioresponses): - aioresponses.post( - "http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview", - status=400, - body="Incorrect format", +async def test_error_during_streaming(test_app: httpx.AsyncClient): + mock_stream = OpenAIStream( + { + "id": "chatcmpl-test", + "object": "chat.completion.chunk", + "created": 1695940483, + "model": "gpt-4", + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "delta": {"role": "assistant"}, + } + ], + "usage": None, + }, + { + "error": { + "message": "Error test", + "type": "runtime_error", + } + }, + ) + + respx.post( + "http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview" + ).respond( + status_code=200, + content_type="text/event-stream", + content=mock_stream.to_content(), ) - test_app = AsyncClient(app=app, base_url="http://test.com") response = await test_app.post( "/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview", - json={"messages": [{"role": "user", "content": "Test content"}]}, + json={ + "messages": [{"role": "user", "content": "Test content"}], + "stream": True, + }, headers={ "X-UPSTREAM-KEY": "TEST_API_KEY", "X-UPSTREAM-ENDPOINT": "http://localhost:5001/openai/deployments/gpt-4/chat/completions", }, ) - assert response.status_code == 400 + assert response.status_code == 200 + mock_stream.assert_response_content( + response, + assert_equal, + usages={ + 0: { + "prompt_tokens": 9, + "completion_tokens": 0, + "total_tokens": 9, + } + }, + ) + + +@respx.mock +@pytest.mark.asyncio +async def 
test_incorrect_upstream_url(test_app: httpx.AsyncClient): + response = await test_app.post( + "/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview", + json={"messages": [{"role": "user", "content": "Test content"}]}, + headers={ + "X-UPSTREAM-KEY": "TEST_API_KEY", + # upstream endpoint should contain the full path + "X-UPSTREAM-ENDPOINT": "http://localhost:5001", + }, + ) - assert response.content == b"Incorrect format" + assert response.status_code == 400 + assert response.json() == { + "error": { + "message": "Invalid upstream endpoint format", + "type": "invalid_request_error", + } + } +@respx.mock @pytest.mark.asyncio -async def test_incorrect_streaming_request(aioresponses: aioresponses): - aioresponses.post( - "http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview", - status=400, - body=json.dumps( - { - "error": { - "message": "0 is less than the minimum of 1 - 'n'", - "type": "invalid_request_error", - "param": None, - "code": None, - "display_message": None, - } - } - ), - content_type="application/json", +async def test_correct_upstream_url(test_app: httpx.AsyncClient): + respx.post( + "http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview" + ).respond(status_code=400, content="whatever") + + response = await test_app.post( + "/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview", + json={"messages": [{"role": "user", "content": "Test content"}]}, + headers={ + "X-UPSTREAM-KEY": "TEST_API_KEY", + "X-UPSTREAM-ENDPOINT": "http://localhost:5001/openai/deployments/gpt-4/chat/completions", + }, ) - test_app = AsyncClient(app=app, base_url="http://test.com") + assert response.status_code == 400 + assert response.content == b"whatever" + + +@respx.mock +@pytest.mark.asyncio +async def test_incorrect_streaming_request(test_app: httpx.AsyncClient): response = await test_app.post( "/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview", json={ "messages": [{"role": "user", "content": "Test content"}], "stream": True, - "n": 0, + "max_prompt_tokens": 0, }, headers={ "X-UPSTREAM-KEY": "TEST_API_KEY", @@ -170,13 +307,12 @@ async def test_incorrect_streaming_request(aioresponses: aioresponses): }, ) - assert response.status_code == 400 - assert response.json() == { + expected_response = { "error": { - "message": "0 is less than the minimum of 1 - 'n'", + "message": "'0' is less than the minimum of 1 - 'max_prompt_tokens'", "type": "invalid_request_error", - "param": None, - "code": None, - "display_message": None, } } + + assert response.status_code == 400 + assert response.json() == expected_response diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 427daf1..41e8c74 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -1,73 +1,67 @@ -import json - +import httpx import pytest -from aioresponses import aioresponses -from httpx import AsyncClient +import respx + +from tests.utils.stream import OpenAIStream -from aidial_adapter_openai.app import app +def assert_equal(actual, expected): + assert actual == expected + +@respx.mock @pytest.mark.asyncio -async def test_streaming(aioresponses: aioresponses): - aioresponses.post( - "http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-06-15", - status=200, - body="data: " - + json.dumps( - { - "id": "chatcmpl-test", - "object": "chat.completion.chunk", - "created": 1695940483, - "model": "gpt-4", - "choices": [ - { - "index": 0, - "finish_reason": 
None, - "delta": { - "role": "assistant", - }, - } - ], - "usage": None, - } - ) - + "\n\n" - + "data: " - + json.dumps( - { - "id": "chatcmpl-test", - "object": "chat.completion.chunk", - "created": 1695940483, - "model": "gpt-4", - "choices": [ - { - "index": 0, - "finish_reason": None, - "delta": { - "content": "Test content", - }, - } - ], - "usage": None, - } - ) - + "\n\n" - + "data: " - + json.dumps( - { - "id": "chatcmpl-test", - "object": "chat.completion.chunk", - "created": 1696245654, - "model": "gpt-4", - "choices": [{"index": 0, "finish_reason": "stop", "delta": {}}], - "usage": None, - } - ) - + "\n\n" - + "data: [DONE]\n\n", +async def test_streaming(test_app: httpx.AsyncClient): + mock_stream = OpenAIStream( + { + "id": "chatcmpl-test", + "object": "chat.completion.chunk", + "created": 1695940483, + "model": "gpt-4", + "choices": [ + { + "index": 0, + "finish_reason": None, + "delta": { + "role": "assistant", + }, + } + ], + "usage": None, + }, + { + "id": "chatcmpl-test", + "object": "chat.completion.chunk", + "created": 1695940483, + "model": "gpt-4", + "choices": [ + { + "index": 0, + "finish_reason": None, + "delta": { + "content": "Test content", + }, + } + ], + "usage": None, + }, + { + "id": "chatcmpl-test", + "object": "chat.completion.chunk", + "created": 1696245654, + "model": "gpt-4", + "choices": [{"index": 0, "finish_reason": "stop", "delta": {}}], + "usage": None, + }, + ) + + respx.post( + "http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-06-15" + ).respond( + status_code=200, + content=mock_stream.to_content(), content_type="text/event-stream", ) - test_app = AsyncClient(app=app, base_url="http://test.com") response = await test_app.post( "/openai/deployments/gpt-4/chat/completions?api-version=2023-06-15", @@ -82,28 +76,14 @@ async def test_streaming(aioresponses: aioresponses): ) assert response.status_code == 200 - - for index, line in enumerate(response.iter_lines()): - if index % 2 == 1: - assert line == "" - continue - - if index == 0: - assert ( - line - == 'data: {"id":"chatcmpl-test","object":"chat.completion.chunk","created":1695940483,"model":"gpt-4","choices":[{"index":0,"finish_reason":null,"delta":{"role":"assistant"}}],"usage":null}' - ) - elif index == 2: - assert ( - line - == 'data: {"id":"chatcmpl-test","object":"chat.completion.chunk","created":1695940483,"model":"gpt-4","choices":[{"index":0,"finish_reason":null,"delta":{"content":"Test content"}}],"usage":null}' - ) - elif index == 4: - assert ( - line - == 'data: {"id":"chatcmpl-test","object":"chat.completion.chunk","created":1696245654,"model":"gpt-4","choices":[{"index":0,"finish_reason":"stop","delta":{}}],"usage":{"completion_tokens":2,"prompt_tokens":9,"total_tokens":11}}' - ) - elif index == 6: - assert line == "data: [DONE]" - else: - assert False + mock_stream.assert_response_content( + response, + assert_equal, + usages={ + 2: { + "completion_tokens": 2, + "prompt_tokens": 9, + "total_tokens": 11, + } + }, + ) diff --git a/tests/utils/stream.py b/tests/utils/stream.py new file mode 100644 index 0000000..2c046d3 --- /dev/null +++ b/tests/utils/stream.py @@ -0,0 +1,45 @@ +import json +from typing import Any, Callable, List + +import httpx + + +class OpenAIStream: + chunks: List[dict] + + def __init__(self, *chunks: dict): + self.chunks = list(chunks) + + def to_content(self) -> str: + ret = "" + for chunk in self.chunks: + ret += f"data: {json.dumps(chunk)}\n\n" + ret += "data: [DONE]\n\n" + return ret + + def assert_response_content( + self, + 
response: httpx.Response, + assert_equality: Callable[[Any, Any], None], + usages: dict[int, dict] = {}, + ): + line_idx = 0 + for line in response.iter_lines(): + chunk_idx = line_idx // 2 + + if line_idx % 2 == 1: + assert_equality(line, "") + + elif chunk_idx < len(self.chunks): + chunk = self.chunks[chunk_idx] + if chunk_idx in usages: + chunk = chunk | {"usage": usages[chunk_idx]} + assert_equality(json.loads(line.removeprefix("data: ")), chunk) + + elif chunk_idx == len(self.chunks): + assert_equality(line, "data: [DONE]") + + else: + assert False + + line_idx += 1
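
Note on the extra-field handling introduced above: the call_with_extra_body helper inspects the target method's signature and moves any unknown top-level request fields into the extra_body keyword accepted by openai 1.x client methods, which is how the adapter keeps forwarding arbitrary fields to the upstream endpoint. Below is a minimal standalone sketch of that routing idea, not the adapter's code: fake_create is a stand-in for client.chat.completions.create and "extra_field" is a hypothetical request field, both used for illustration only.

import inspect
from typing import Any, Callable


def split_extra_fields(func: Callable[..., Any], payload: dict) -> dict:
    # Route request fields unknown to `func` into `extra_body`,
    # mirroring the call_with_extra_body helper added in this patch
    # (the real helper also handles **kwargs and raises a DialException
    # when the target does not accept extra_body at all).
    expected = set(inspect.signature(func).parameters)
    extras = set(payload) - expected
    if not extras:
        return dict(payload)

    routed = {k: v for k, v in payload.items() if k not in extras}
    routed["extra_body"] = {
        **(payload.get("extra_body") or {}),
        **{k: payload[k] for k in extras},
    }
    return routed


# "extra_field" is unknown to fake_create, so it lands in extra_body.
def fake_create(
    *, model: str, messages: list, extra_body: dict | None = None
) -> dict:
    return {"model": model, "extra_body": extra_body}


print(
    split_extra_fields(
        fake_create,
        {"model": "gpt-4", "messages": [], "extra_field": 1},
    )
)
# -> {'model': 'gpt-4', 'extra_body': {'extra_field': 1}}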