From a848f52f41ba13e6a387b3b1ddb866c4a945fe37 Mon Sep 17 00:00:00 2001 From: Anton Dubovik Date: Mon, 7 Oct 2024 13:42:19 +0200 Subject: [PATCH] feat: supported streaming emulation (#157) --- README.md | 1 + aidial_adapter_openai/app.py | 152 ++++++++++-------- aidial_adapter_openai/completions.py | 28 ++-- aidial_adapter_openai/dalle3.py | 28 ++-- aidial_adapter_openai/databricks.py | 7 +- aidial_adapter_openai/env.py | 3 + aidial_adapter_openai/gpt.py | 33 ++-- .../gpt4_multi_modal/chat_completion.py | 42 ++--- aidial_adapter_openai/mistral.py | 8 +- aidial_adapter_openai/utils/streaming.py | 45 ++++++ poetry.lock | 16 +- pyproject.toml | 4 - tests/test_errors.py | 95 ++++++++++- 13 files changed, 281 insertions(+), 181 deletions(-) diff --git a/README.md b/README.md index 427f09c..63a04ac 100644 --- a/README.md +++ b/README.md @@ -73,6 +73,7 @@ Copy `.env.example` to `.env` and customize it for your environment: |MISTRAL_DEPLOYMENTS|``|Comma-separated list of deployments that support Mistral Large Azure API. Example: `mistral-large-azure,mistral-large`| |DATABRICKS_DEPLOYMENTS|``|Comma-separated list of Databricks chat completion deployments. Example: `databricks-dbrx-instruct,databricks-mixtral-8x7b-instruct,databricks-llama-2-70b-chat`| |GPT4O_DEPLOYMENTS|``|Comma-separated list of GPT-4o chat completion deployments. Example: `gpt-4o-2024-05-13`| +|NON_STREAMING_DEPLOYMENTS|``|Comma-separated list of deployments which do not support streaming. The adapter is going to emulate the streaming by calling the model and converting its response into a single-chunk stream. Example: `o1-mini`, `o1-preview`| ### Docker diff --git a/aidial_adapter_openai/app.py b/aidial_adapter_openai/app.py index bc12e3e..d9f9fa2 100644 --- a/aidial_adapter_openai/app.py +++ b/aidial_adapter_openai/app.py @@ -1,5 +1,4 @@ from contextlib import asynccontextmanager -from typing import Awaitable, TypeVar from aidial_sdk.exceptions import HTTPException as DialException from aidial_sdk.exceptions import InvalidRequestError @@ -7,7 +6,13 @@ from aidial_sdk.telemetry.types import TelemetryConfig from fastapi import FastAPI, Request from fastapi.responses import Response -from openai import APIConnectionError, APIStatusError, APITimeoutError +from openai import ( + APIConnectionError, + APIError, + APIStatusError, + APITimeoutError, + OpenAIError, +) from aidial_adapter_openai.completions import chat_completion as completion from aidial_adapter_openai.dalle3 import ( @@ -25,6 +30,7 @@ GPT4O_DEPLOYMENTS, MISTRAL_DEPLOYMENTS, MODEL_ALIASES, + NON_STREAMING_DEPLOYMENTS, ) from aidial_adapter_openai.gpt import gpt_chat_completion from aidial_adapter_openai.gpt4_multi_modal.chat_completion import ( @@ -44,6 +50,7 @@ ) from aidial_adapter_openai.utils.reflection import call_with_extra_body from aidial_adapter_openai.utils.storage import create_file_storage +from aidial_adapter_openai.utils.streaming import create_server_response from aidial_adapter_openai.utils.tokenizer import ( MultiModalTokenizer, PlainTextTokenizer, @@ -63,34 +70,6 @@ async def lifespan(app: FastAPI): init_telemetry(app, TelemetryConfig()) configure_loggers() -T = TypeVar("T") - - -async def handle_exceptions(call: Awaitable[T]) -> T | Response: - try: - return await call - except APIStatusError as e: - r = e.response - return Response( - content=r.content, - status_code=r.status_code, - headers=r.headers, - ) - except APITimeoutError: - raise DialException( - status_code=504, - type="timeout", - message="Request timed out", - display_message="Request 
timed out. Please try again later.", - ) - except APIConnectionError: - raise DialException( - status_code=502, - type="connection", - message="Error communicating with OpenAI", - display_message="OpenAI server is not responsive. Please try again later.", - ) - def get_api_version(request: Request): api_version = request.query_params.get("api-version", "") @@ -104,8 +83,26 @@ def get_api_version(request: Request): @app.post("/openai/deployments/{deployment_id:path}/chat/completions") async def chat_completion(deployment_id: str, request: Request): + data = await parse_body(request) + is_stream = bool(data.get("stream")) + + emulate_streaming = deployment_id in NON_STREAMING_DEPLOYMENTS and is_stream + + if emulate_streaming: + data["stream"] = False + + return create_server_response( + emulate_streaming, + await call_chat_completion(deployment_id, data, is_stream, request), + ) + + +async def call_chat_completion( + deployment_id: str, data: dict, is_stream: bool, request: Request +): + # Azure OpenAI deployments ignore "model" request field, # since the deployment id is already encoded in the endpoint path. # This is not the case for non-Azure OpenAI deployments, so @@ -116,22 +113,18 @@ async def chat_completion(deployment_id: str, request: Request): # The same goes for /embeddings endpoint. data["model"] = deployment_id - is_stream = data.get("stream", False) - creds = await get_credentials(request) api_version = get_api_version(request) upstream_endpoint = request.headers["X-UPSTREAM-ENDPOINT"] if completions_endpoint := completions_parser.parse(upstream_endpoint): - return await handle_exceptions( - completion( - data, - completions_endpoint, - creds, - api_version, - deployment_id, - ) + return await completion( + data, + completions_endpoint, + creds, + api_version, + deployment_id, ) if deployment_id in DALLE3_DEPLOYMENTS: @@ -146,14 +139,10 @@ async def chat_completion(deployment_id: str, request: Request): ) if deployment_id in MISTRAL_DEPLOYMENTS: - return await handle_exceptions( - mistral_chat_completion(data, upstream_endpoint, creds) - ) + return await mistral_chat_completion(data, upstream_endpoint, creds) if deployment_id in DATABRICKS_DEPLOYMENTS: - return await handle_exceptions( - databricks_chat_completion(data, upstream_endpoint, creds) - ) + return await databricks_chat_completion(data, upstream_endpoint, creds) if deployment_id in GPT4_VISION_DEPLOYMENTS: storage = create_file_storage("images", request.headers) @@ -171,29 +160,25 @@ async def chat_completion(deployment_id: str, request: Request): if deployment_id in GPT4O_DEPLOYMENTS: tokenizer = MultiModalTokenizer(openai_model_name) storage = create_file_storage("images", request.headers) - return await handle_exceptions( - gpt4o_chat_completion( - data, - deployment_id, - upstream_endpoint, - creds, - is_stream, - storage, - api_version, - tokenizer, - ) - ) - - tokenizer = PlainTextTokenizer(model=openai_model_name) - return await handle_exceptions( - gpt_chat_completion( + return await gpt4o_chat_completion( data, deployment_id, upstream_endpoint, creds, + is_stream, + storage, api_version, tokenizer, ) + + tokenizer = PlainTextTokenizer(model=openai_model_name) + return await gpt_chat_completion( + data, + deployment_id, + upstream_endpoint, + creds, + api_version, + tokenizer, ) @@ -212,13 +197,48 @@ async def embedding(deployment_id: str, request: Request): {**creds, "api_version": api_version} ) - return await handle_exceptions( - call_with_extra_body(client.embeddings.create, data) - ) + return await 
call_with_extra_body(client.embeddings.create, data)
+
+
+@app.exception_handler(OpenAIError)
+def openai_exception_handler(request: Request, e: OpenAIError):
+    if isinstance(e, APIStatusError):
+        r = e.response
+        return Response(
+            content=r.content,
+            status_code=r.status_code,
+            headers=r.headers,
+        )
+
+    if isinstance(e, APITimeoutError):
+        raise DialException(
+            status_code=504,
+            type="timeout",
+            message="Request timed out",
+            display_message="Request timed out. Please try again later.",
+        )
+
+    if isinstance(e, APIConnectionError):
+        raise DialException(
+            status_code=502,
+            type="connection",
+            message="Error communicating with OpenAI",
+            display_message="OpenAI server is not responsive. Please try again later.",
+        )
+
+    if isinstance(e, APIError):
+        raise DialException(
+            status_code=getattr(e, "status_code", None) or 500,
+            message=e.message,
+            type=e.type,
+            code=e.code,
+            param=e.param,
+            display_message=None,
+        )
 
 
 @app.exception_handler(DialException)
-def exception_handler(request: Request, exc: DialException):
+def dial_exception_handler(request: Request, exc: DialException):
     return exc.to_fastapi_response()
 
 
diff --git a/aidial_adapter_openai/completions.py b/aidial_adapter_openai/completions.py
index 8d2b1f5..90834b5 100644
--- a/aidial_adapter_openai/completions.py
+++ b/aidial_adapter_openai/completions.py
@@ -1,7 +1,6 @@
 from typing import Any, Dict
 
 from aidial_sdk.exceptions import RequestValidationError
-from fastapi.responses import JSONResponse, StreamingResponse
 from openai import AsyncStream
 from openai.types import Completion
 
@@ -12,7 +11,6 @@
     OpenAIEndpoint,
 )
 from aidial_adapter_openai.utils.reflection import call_with_extra_body
-from aidial_adapter_openai.utils.sse_stream import to_openai_sse_stream
 from aidial_adapter_openai.utils.streaming import (
     build_chunk,
     debug_print,
@@ -48,15 +46,15 @@ async def chat_completion(
     creds: OpenAICreds,
     api_version: str,
     deployment_id: str,
-) -> Any:
+):
 
-    if data.get("n", 1) > 1:  # type: ignore
+    if (data.get("n") or 1) > 1:
         raise RequestValidationError("The deployment doesn't support n > 1")
 
     client = endpoint.get_client({**creds, "api_version": api_version})
 
-    messages = data.get("messages", [])
-    if len(messages) == 0:
+    messages = data.get("messages") or []
+    if not messages:
         raise RequestValidationError("The request doesn't contain any messages")
 
     prompt = messages[-1].get("content") or ""
@@ -67,24 +65,18 @@ async def chat_completion(
         prompt = template.format(prompt=prompt)
 
     del data["messages"]
+
     response = await call_with_extra_body(
         client.completions.create,
         {"prompt": prompt, **data},
     )
 
     if isinstance(response, AsyncStream):
-        return StreamingResponse(
-            to_openai_sse_stream(
-                map_stream(
-                    lambda item: convert_to_chat_completions_response(
-                        item, is_stream=True
-                    ),
-                    response,
-                )
+        return map_stream(
+            lambda item: convert_to_chat_completions_response(
+                item, is_stream=True
             ),
-            media_type="text/event-stream",
+            response,
         )
     else:
-        return JSONResponse(
-            convert_to_chat_completions_response(response, is_stream=False)
-        )
+        return convert_to_chat_completions_response(response, is_stream=False)
diff --git a/aidial_adapter_openai/dalle3.py b/aidial_adapter_openai/dalle3.py
index a20a35b..99e1b03 100644
--- a/aidial_adapter_openai/dalle3.py
+++ b/aidial_adapter_openai/dalle3.py
@@ -3,10 +3,9 @@
 import aiohttp
 from aidial_sdk.exceptions import HTTPException as DIALException
 from aidial_sdk.exceptions import RequestValidationError
-from fastapi.responses import JSONResponse, Response, StreamingResponse
+from 
fastapi.responses import JSONResponse from aidial_adapter_openai.utils.auth import OpenAICreds, get_auth_headers -from aidial_adapter_openai.utils.sse_stream import to_openai_sse_stream from aidial_adapter_openai.utils.storage import FileStorage from aidial_adapter_openai.utils.streaming import build_chunk, generate_id @@ -111,7 +110,7 @@ async def chat_completion( is_stream: bool, file_storage: Optional[FileStorage], api_version: str, -) -> Response: +): if data.get("n", 1) > 1: raise RequestValidationError("The deployment doesn't support n > 1") @@ -133,19 +132,14 @@ async def chat_completion( if file_storage is not None: await move_attachments_data_to_storage(custom_content, file_storage) - if not is_stream: - return JSONResponse( - content=build_chunk( - id, - "stop", - {"role": "assistant", "content": "", **custom_content}, - created, - False, - usage=IMG_USAGE, - ) - ) + if is_stream: + return generate_stream(id, created, custom_content) else: - return StreamingResponse( - to_openai_sse_stream(generate_stream(id, created, custom_content)), - media_type="text/event-stream", + return build_chunk( + id, + "stop", + {"role": "assistant", "content": "", **custom_content}, + created, + False, + usage=IMG_USAGE, ) diff --git a/aidial_adapter_openai/databricks.py b/aidial_adapter_openai/databricks.py index c28d75e..2d02f6b 100644 --- a/aidial_adapter_openai/databricks.py +++ b/aidial_adapter_openai/databricks.py @@ -1,6 +1,5 @@ from typing import Any, cast -from fastapi.responses import StreamingResponse from openai import AsyncStream from openai.types.chat.chat_completion import ChatCompletion from openai.types.chat.chat_completion_chunk import ChatCompletionChunk @@ -11,7 +10,6 @@ chat_completions_parser, ) from aidial_adapter_openai.utils.reflection import call_with_extra_body -from aidial_adapter_openai.utils.sse_stream import to_openai_sse_stream from aidial_adapter_openai.utils.streaming import chunk_to_dict, map_stream @@ -27,9 +25,6 @@ async def chat_completion( ) if isinstance(response, AsyncStream): - return StreamingResponse( - to_openai_sse_stream(map_stream(chunk_to_dict, response)), - media_type="text/event-stream", - ) + return map_stream(chunk_to_dict, response) else: return response diff --git a/aidial_adapter_openai/env.py b/aidial_adapter_openai/env.py index 40f9bdd..85302ed 100644 --- a/aidial_adapter_openai/env.py +++ b/aidial_adapter_openai/env.py @@ -23,6 +23,9 @@ os.getenv("COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES") or "{}" ) DALLE3_AZURE_API_VERSION = os.getenv("DALLE3_AZURE_API_VERSION", "2024-02-01") +NON_STREAMING_DEPLOYMENTS = parse_deployment_list( + os.getenv("NON_STREAMING_DEPLOYMENTS") +) def get_eliminate_empty_choices() -> bool: diff --git a/aidial_adapter_openai/gpt.py b/aidial_adapter_openai/gpt.py index 83a3770..5c141fb 100644 --- a/aidial_adapter_openai/gpt.py +++ b/aidial_adapter_openai/gpt.py @@ -1,7 +1,6 @@ -from typing import List, Tuple, cast +from typing import AsyncIterator, List, Tuple, cast from aidial_sdk.exceptions import InvalidRequestError -from fastapi.responses import StreamingResponse from openai import AsyncStream from openai.types.chat.chat_completion import ChatCompletion from openai.types.chat.chat_completion_chunk import ChatCompletionChunk @@ -9,7 +8,6 @@ from aidial_adapter_openai.utils.auth import OpenAICreds from aidial_adapter_openai.utils.parsers import chat_completions_parser from aidial_adapter_openai.utils.reflection import call_with_extra_body -from aidial_adapter_openai.utils.sse_stream import to_openai_sse_stream from 
aidial_adapter_openai.utils.streaming import ( chunk_to_dict, debug_print, @@ -73,23 +71,18 @@ async def gpt_chat_completion( await call_with_extra_body(client.chat.completions.create, data) ) - if isinstance(response, AsyncStream): - return StreamingResponse( - to_openai_sse_stream( - generate_stream( - get_prompt_tokens=lambda: prompt_tokens - or tokenizer.calculate_prompt_tokens(data["messages"]), - tokenize=tokenizer.calculate_text_tokens, - deployment=deployment_id, - discarded_messages=discarded_messages, - stream=map_stream(chunk_to_dict, response), - ), - ), - media_type="text/event-stream", + if isinstance(response, AsyncIterator): + return generate_stream( + get_prompt_tokens=lambda: prompt_tokens + or tokenizer.calculate_prompt_tokens(data["messages"]), + tokenize=tokenizer.calculate_text_tokens, + deployment=deployment_id, + discarded_messages=discarded_messages, + stream=map_stream(chunk_to_dict, response), ) else: - resp = response.to_dict() + rest = response.to_dict() if discarded_messages is not None: - resp |= {"statistics": {"discarded_messages": discarded_messages}} - debug_print("response", resp) - return resp + rest |= {"statistics": {"discarded_messages": discarded_messages}} + debug_print("response", rest) + return rest diff --git a/aidial_adapter_openai/gpt4_multi_modal/chat_completion.py b/aidial_adapter_openai/gpt4_multi_modal/chat_completion.py index 33ee422..8b2851d 100644 --- a/aidial_adapter_openai/gpt4_multi_modal/chat_completion.py +++ b/aidial_adapter_openai/gpt4_multi_modal/chat_completion.py @@ -14,7 +14,7 @@ import aiohttp from aidial_sdk.exceptions import HTTPException as DialException from aidial_sdk.exceptions import RequestValidationError -from fastapi.responses import JSONResponse, Response, StreamingResponse +from fastapi.responses import JSONResponse, Response from aidial_adapter_openai.gpt4_multi_modal.attachment import ( SUPPORTED_FILE_EXTS, @@ -28,10 +28,7 @@ from aidial_adapter_openai.utils.auth import OpenAICreds, get_auth_headers from aidial_adapter_openai.utils.log_config import logger from aidial_adapter_openai.utils.multi_modal_message import MultiModalMessage -from aidial_adapter_openai.utils.sse_stream import ( - parse_openai_sse_stream, - to_openai_sse_stream, -) +from aidial_adapter_openai.utils.sse_stream import parse_openai_sse_stream from aidial_adapter_openai.utils.storage import FileStorage from aidial_adapter_openai.utils.streaming import ( create_response_from_chunk, @@ -145,7 +142,7 @@ async def gpt4o_chat_completion( file_storage: Optional[FileStorage], api_version: str, tokenizer: MultiModalTokenizer, -) -> Response: +): return await chat_completion( request, deployment, @@ -168,7 +165,7 @@ async def gpt4_vision_chat_completion( is_stream: bool, file_storage: Optional[FileStorage], api_version: str, -) -> Response: +): return await chat_completion( request, deployment, @@ -194,7 +191,7 @@ async def chat_completion( tokenizer: MultiModalTokenizer, response_transformer: Callable[[dict], dict | None], default_max_tokens: Optional[int], -) -> Response: +): if request.get("n", 1) > 1: raise RequestValidationError("The deployment doesn't support n > 1") @@ -252,23 +249,18 @@ def debug_print(chunk: T) -> T: logger.debug(f"chunk: {chunk}") return chunk - return StreamingResponse( - to_openai_sse_stream( - map_stream( - debug_print, - generate_stream( - get_prompt_tokens=lambda: estimated_prompt_tokens, - tokenize=tokenizer.calculate_text_tokens, - deployment=deployment, - discarded_messages=discarded_messages, - stream=map_stream( - 
response_transformer,
-                            parse_openai_sse_stream(response),
-                        ),
-                    ),
-                )
+        return map_stream(
+            debug_print,
+            generate_stream(
+                get_prompt_tokens=lambda: estimated_prompt_tokens,
+                tokenize=tokenizer.calculate_text_tokens,
+                deployment=deployment,
+                discarded_messages=discarded_messages,
+                stream=map_stream(
+                    response_transformer,
+                    parse_openai_sse_stream(response),
+                ),
             ),
-            media_type="text/event-stream",
         )
     else:
         response = await predict_non_stream(api_url, headers, request)
@@ -304,4 +296,4 @@ def debug_print(chunk: T) -> T:
             f"Estimated completion tokens ({estimated_completion_tokens}) don't match the actual ones ({actual_completion_tokens})"
         )
 
-    return JSONResponse(content=response)
+    return response
diff --git a/aidial_adapter_openai/mistral.py b/aidial_adapter_openai/mistral.py
index cab8304..0fe727f 100644
--- a/aidial_adapter_openai/mistral.py
+++ b/aidial_adapter_openai/mistral.py
@@ -1,6 +1,5 @@
 from typing import Any
 
-from fastapi.responses import StreamingResponse
 from openai import AsyncOpenAI, AsyncStream
 from openai.types.chat.chat_completion import ChatCompletion
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
@@ -8,14 +7,12 @@
 from aidial_adapter_openai.utils.auth import OpenAICreds
 from aidial_adapter_openai.utils.http_client import get_http_client
 from aidial_adapter_openai.utils.reflection import call_with_extra_body
-from aidial_adapter_openai.utils.sse_stream import to_openai_sse_stream
 from aidial_adapter_openai.utils.streaming import chunk_to_dict, map_stream
 
 
 async def chat_completion(
     data: Any, upstream_endpoint: str, creds: OpenAICreds
 ):
-
     client = AsyncOpenAI(
         base_url=upstream_endpoint,
         api_key=creds.get("api_key"),
@@ -27,9 +24,6 @@ async def chat_completion(
     )
 
     if isinstance(response, AsyncStream):
-        return StreamingResponse(
-            to_openai_sse_stream(map_stream(chunk_to_dict, response)),
-            media_type="text/event-stream",
-        )
+        return map_stream(chunk_to_dict, response)
     else:
         return response
diff --git a/aidial_adapter_openai/utils/streaming.py b/aidial_adapter_openai/utils/streaming.py
index b37f325..038df0b 100644
--- a/aidial_adapter_openai/utils/streaming.py
+++ b/aidial_adapter_openai/utils/streaming.py
@@ -7,6 +7,7 @@
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from openai import APIError, APIStatusError
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
+from pydantic import BaseModel
 
 from aidial_adapter_openai.env import get_eliminate_empty_choices
 from aidial_adapter_openai.utils.log_config import logger
@@ -218,6 +219,50 @@ async def generator() -> AsyncIterator[dict]:
     )
 
 
+def block_response_to_streaming_chunk(response: dict) -> dict:
+    response["object"] = "chat.completion.chunk"
+    for choice in response.get("choices") or []:
+        if message := choice.get("message"):
+            choice["delta"] = message
+            del choice["message"]
+    return response
+
+
+def create_server_response(
+    emulate_stream: bool,
+    response: AsyncIterator[dict] | dict | BaseModel | Response,
+) -> Response:
+
+    def block_to_stream(block: dict) -> AsyncIterator[dict]:
+        async def stream():
+            yield block_response_to_streaming_chunk(block)
+
+        return stream()
+
+    def stream_to_response(stream: AsyncIterator[dict]) -> Response:
+        return StreamingResponse(
+            to_openai_sse_stream(stream),
+            media_type="text/event-stream",
+        )
+
+    def block_to_response(block: dict) -> Response:
+        if emulate_stream:
+            return stream_to_response(block_to_stream(block))
+        else:
+            return JSONResponse(block)
+
+    if isinstance(response, 
AsyncIterator): + return stream_to_response(response) + + if isinstance(response, dict): + return block_to_response(response) + + if isinstance(response, BaseModel): + return block_to_response(response.dict()) + + return response + + T = TypeVar("T") V = TypeVar("V") diff --git a/poetry.lock b/poetry.lock index b6d9f7f..729eb58 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1886,20 +1886,6 @@ pytest = ">=7.0.0" docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"] -[[package]] -name = "python-dotenv" -version = "1.0.0" -description = "Read key-value pairs from a .env file and set them as environment variables" -optional = false -python-versions = ">=3.8" -files = [ - {file = "python-dotenv-1.0.0.tar.gz", hash = "sha256:a8df96034aae6d2d50a4ebe8216326c61c3eb64836776504fcca410e5937a3ba"}, - {file = "python_dotenv-1.0.0-py3-none-any.whl", hash = "sha256:f5971a9226b701070a4bf2c38c89e5a3f0d64de8debda981d1db98583009122a"}, -] - -[package.extras] -cli = ["click (>=5.0)"] - [[package]] name = "pywin32" version = "306" @@ -2436,4 +2422,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.13" -content-hash = "e63a91778f4b0c6b001c143a407108d11ccace9a4d63e056826d55613bd5ffb0" +content-hash = "8fc2dc8ce0ef702221685aa6631e01c92ae0c7ad05a08cfa2b0994e757cf4ba3" diff --git a/pyproject.toml b/pyproject.toml index 4f62215..8ed3929 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,9 +16,6 @@ repository = "https://github.com/epam/ai-dial-adapter-openai" [tool.poetry.scripts] clean = "scripts.clean:main" -[pytest] -env_files = [".env"] - [tool.poetry.dependencies] python = ">=3.11,<3.13" fastapi = "0.109.2" @@ -37,7 +34,6 @@ aidial-sdk = {version = "^0.13.0", extras = ["telemetry"]} [tool.poetry.group.test.dependencies] pytest = "7.4.0" pytest-asyncio = "0.21.1" -python-dotenv = "1.0.0" respx = "^0.21.1" [tool.poetry.group.lint.dependencies] diff --git a/tests/test_errors.py b/tests/test_errors.py index 3ccf236..dcdafce 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -397,10 +397,45 @@ async def test_incorrect_upstream_url(test_app: httpx.AsyncClient): @respx.mock @pytest.mark.asyncio -async def test_correct_upstream_url(test_app: httpx.AsyncClient): +async def test_no_request_response_validation(test_app: httpx.AsyncClient): respx.post( "http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview" - ).respond(status_code=400, content="whatever") + ).respond( + status_code=200, json={"messages": "string", "extra_response": "string"} + ) + + response = await test_app.post( + "/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview", + json={ + "messages": [ + { + "role": "user", + "content": "Test content", + "extra_mesage": "string", + } + ], + "extra_request": "string", + }, + headers={ + "X-UPSTREAM-KEY": "TEST_API_KEY", + "X-UPSTREAM-ENDPOINT": "http://localhost:5001/openai/deployments/gpt-4/chat/completions", + "Content-Type": "application/pdf", + }, + ) + + assert response.status_code == 200 + assert response.json() == { + "messages": "string", + "extra_response": "string", + } + + +@respx.mock +@pytest.mark.asyncio +async def test_status_error_from_upstream(test_app: httpx.AsyncClient): + respx.post( + "http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview" + ).respond(status_code=400, 
content="Bad request") response = await test_app.post( "/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview", @@ -412,7 +447,61 @@ async def test_correct_upstream_url(test_app: httpx.AsyncClient): ) assert response.status_code == 400 - assert response.content == b"whatever" + assert response.content == b"Bad request" + + +@respx.mock +@pytest.mark.asyncio +async def test_timeout_error_from_upstream(test_app: httpx.AsyncClient): + respx.post( + "http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview" + ).mock(side_effect=httpx.ReadTimeout("Timeout error")) + + response = await test_app.post( + "/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview", + json={"messages": [{"role": "user", "content": "Test content"}]}, + headers={ + "X-UPSTREAM-KEY": "TEST_API_KEY", + "X-UPSTREAM-ENDPOINT": "http://localhost:5001/openai/deployments/gpt-4/chat/completions", + }, + ) + + assert response.status_code == 504 + assert response.json() == { + "error": { + "message": "Request timed out", + "type": "timeout", + "code": "504", + "display_message": "Request timed out. Please try again later.", + } + } + + +@respx.mock +@pytest.mark.asyncio +async def test_connection_error_from_upstream(test_app: httpx.AsyncClient): + respx.post( + "http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview" + ).mock(side_effect=httpx.ConnectError("Connection error")) + + response = await test_app.post( + "/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview", + json={"messages": [{"role": "user", "content": "Test content"}]}, + headers={ + "X-UPSTREAM-KEY": "TEST_API_KEY", + "X-UPSTREAM-ENDPOINT": "http://localhost:5001/openai/deployments/gpt-4/chat/completions", + }, + ) + + assert response.status_code == 502 + assert response.json() == { + "error": { + "message": "Error communicating with OpenAI", + "type": "connection", + "code": "502", + "display_message": "OpenAI server is not responsive. Please try again later.", + } + } @respx.mock