feat: supported streaming emulation (#157)

adubovik authored Oct 7, 2024
1 parent 3bc8706 commit a848f52
Showing 13 changed files with 281 additions and 181 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -73,6 +73,7 @@ Copy `.env.example` to `.env` and customize it for your environment:
 |MISTRAL_DEPLOYMENTS|``|Comma-separated list of deployments that support Mistral Large Azure API. Example: `mistral-large-azure,mistral-large`|
 |DATABRICKS_DEPLOYMENTS|``|Comma-separated list of Databricks chat completion deployments. Example: `databricks-dbrx-instruct,databricks-mixtral-8x7b-instruct,databricks-llama-2-70b-chat`|
 |GPT4O_DEPLOYMENTS|``|Comma-separated list of GPT-4o chat completion deployments. Example: `gpt-4o-2024-05-13`|
+|NON_STREAMING_DEPLOYMENTS|``|Comma-separated list of deployments which do not support streaming. The adapter emulates streaming by calling the model and converting its response into a single-chunk stream. Example: `o1-mini,o1-preview`|
 
 ### Docker
 
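To make the new option concrete, here is a minimal sketch of what "converting its response into a single-chunk stream" amounts to. It is illustrative only: the helper name and the exact chunk layout are assumptions, not code from this commit (the adapter's real logic lives in `aidial_adapter_openai/utils/streaming.py`).

```python
# Illustrative sketch only: replay a complete (non-streaming) chat
# completion as a one-chunk SSE stream. The helper name and chunk layout
# are assumptions; the adapter's actual implementation is not shown here.
import json
from typing import Any, AsyncIterator


async def emulate_single_chunk_stream(
    response: dict[str, Any],
) -> AsyncIterator[str]:
    chunk = {
        "id": response["id"],
        "object": "chat.completion.chunk",
        "created": response["created"],
        "model": response["model"],
        "choices": [
            {
                "index": choice["index"],
                # The full "message" is emitted as a single "delta".
                "delta": choice["message"],
                "finish_reason": choice["finish_reason"],
            }
            for choice in response["choices"]
        ],
        "usage": response.get("usage"),
    }
    yield f"data: {json.dumps(chunk)}\n\n"
    yield "data: [DONE]\n\n"
```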
152 changes: 86 additions & 66 deletions aidial_adapter_openai/app.py
@@ -1,13 +1,18 @@
 from contextlib import asynccontextmanager
-from typing import Awaitable, TypeVar
 
 from aidial_sdk.exceptions import HTTPException as DialException
 from aidial_sdk.exceptions import InvalidRequestError
 from aidial_sdk.telemetry.init import init_telemetry
 from aidial_sdk.telemetry.types import TelemetryConfig
 from fastapi import FastAPI, Request
 from fastapi.responses import Response
-from openai import APIConnectionError, APIStatusError, APITimeoutError
+from openai import (
+    APIConnectionError,
+    APIError,
+    APIStatusError,
+    APITimeoutError,
+    OpenAIError,
+)
 
 from aidial_adapter_openai.completions import chat_completion as completion
 from aidial_adapter_openai.dalle3 import (
@@ -25,6 +30,7 @@
     GPT4O_DEPLOYMENTS,
     MISTRAL_DEPLOYMENTS,
     MODEL_ALIASES,
+    NON_STREAMING_DEPLOYMENTS,
 )
 from aidial_adapter_openai.gpt import gpt_chat_completion
 from aidial_adapter_openai.gpt4_multi_modal.chat_completion import (
@@ -44,6 +50,7 @@
 )
 from aidial_adapter_openai.utils.reflection import call_with_extra_body
 from aidial_adapter_openai.utils.storage import create_file_storage
+from aidial_adapter_openai.utils.streaming import create_server_response
 from aidial_adapter_openai.utils.tokenizer import (
     MultiModalTokenizer,
     PlainTextTokenizer,
@@ -63,34 +70,6 @@ async def lifespan(app: FastAPI):
     init_telemetry(app, TelemetryConfig())
     configure_loggers()
 
-T = TypeVar("T")
-
-
-async def handle_exceptions(call: Awaitable[T]) -> T | Response:
-    try:
-        return await call
-    except APIStatusError as e:
-        r = e.response
-        return Response(
-            content=r.content,
-            status_code=r.status_code,
-            headers=r.headers,
-        )
-    except APITimeoutError:
-        raise DialException(
-            status_code=504,
-            type="timeout",
-            message="Request timed out",
-            display_message="Request timed out. Please try again later.",
-        )
-    except APIConnectionError:
-        raise DialException(
-            status_code=502,
-            type="connection",
-            message="Error communicating with OpenAI",
-            display_message="OpenAI server is not responsive. Please try again later.",
-        )
-
 
 def get_api_version(request: Request):
     api_version = request.query_params.get("api-version", "")
@@ -104,8 +83,26 @@ def get_api_version(request: Request):
 
 @app.post("/openai/deployments/{deployment_id:path}/chat/completions")
 async def chat_completion(deployment_id: str, request: Request):
+
     data = await parse_body(request)
 
+    is_stream = bool(data.get("stream"))
+
+    emulate_streaming = deployment_id in NON_STREAMING_DEPLOYMENTS and is_stream
+
+    if emulate_streaming:
+        data["stream"] = False
+
+    return create_server_response(
+        emulate_streaming,
+        await call_chat_completion(deployment_id, data, is_stream, request),
+    )
+
+
+async def call_chat_completion(
+    deployment_id: str, data: dict, is_stream: bool, request: Request
+):
+
     # Azure OpenAI deployments ignore "model" request field,
     # since the deployment id is already encoded in the endpoint path.
     # This is not the case for non-Azure OpenAI deployments, so
@@ -116,22 +113,18 @@ async def chat_completion(deployment_id: str, request: Request):
     # The same goes for /embeddings endpoint.
     data["model"] = deployment_id
 
-    is_stream = data.get("stream", False)
-
     creds = await get_credentials(request)
     api_version = get_api_version(request)
 
     upstream_endpoint = request.headers["X-UPSTREAM-ENDPOINT"]
 
     if completions_endpoint := completions_parser.parse(upstream_endpoint):
-        return await handle_exceptions(
-            completion(
-                data,
-                completions_endpoint,
-                creds,
-                api_version,
-                deployment_id,
-            )
+        return await completion(
+            data,
+            completions_endpoint,
+            creds,
+            api_version,
+            deployment_id,
         )
 
     if deployment_id in DALLE3_DEPLOYMENTS:
@@ -146,14 +139,10 @@ async def chat_completion(deployment_id: str, request: Request):
         )
 
     if deployment_id in MISTRAL_DEPLOYMENTS:
-        return await handle_exceptions(
-            mistral_chat_completion(data, upstream_endpoint, creds)
-        )
+        return await mistral_chat_completion(data, upstream_endpoint, creds)
 
     if deployment_id in DATABRICKS_DEPLOYMENTS:
-        return await handle_exceptions(
-            databricks_chat_completion(data, upstream_endpoint, creds)
-        )
+        return await databricks_chat_completion(data, upstream_endpoint, creds)
 
     if deployment_id in GPT4_VISION_DEPLOYMENTS:
         storage = create_file_storage("images", request.headers)
@@ -171,29 +160,25 @@ async def chat_completion(deployment_id: str, request: Request):
     if deployment_id in GPT4O_DEPLOYMENTS:
         tokenizer = MultiModalTokenizer(openai_model_name)
         storage = create_file_storage("images", request.headers)
-        return await handle_exceptions(
-            gpt4o_chat_completion(
-                data,
-                deployment_id,
-                upstream_endpoint,
-                creds,
-                is_stream,
-                storage,
-                api_version,
-                tokenizer,
-            )
-        )
-
-    tokenizer = PlainTextTokenizer(model=openai_model_name)
-    return await handle_exceptions(
-        gpt_chat_completion(
+        return await gpt4o_chat_completion(
             data,
             deployment_id,
             upstream_endpoint,
             creds,
+            is_stream,
+            storage,
             api_version,
             tokenizer,
         )
+
+    tokenizer = PlainTextTokenizer(model=openai_model_name)
+    return await gpt_chat_completion(
+        data,
+        deployment_id,
+        upstream_endpoint,
+        creds,
+        api_version,
+        tokenizer,
     )
 
 
@@ -212,13 +197,48 @@ async def embedding(deployment_id: str, request: Request):
         {**creds, "api_version": api_version}
     )
 
-    return await handle_exceptions(
-        call_with_extra_body(client.embeddings.create, data)
-    )
+    return await call_with_extra_body(client.embeddings.create, data)
 
 
+@app.exception_handler(OpenAIError)
+def openai_exception_handler(request: Request, e: OpenAIError):
+    if isinstance(e, APIStatusError):
+        r = e.response
+        return Response(
+            content=r.content,
+            status_code=r.status_code,
+            headers=r.headers,
+        )
+
+    if isinstance(e, APITimeoutError):
+        raise DialException(
+            status_code=504,
+            type="timeout",
+            message="Request timed out",
+            display_message="Request timed out. Please try again later.",
+        )
+
+    if isinstance(e, APIConnectionError):
+        raise DialException(
+            status_code=502,
+            type="connection",
+            message="Error communicating with OpenAI",
+            display_message="OpenAI server is not responsive. Please try again later.",
+        )
+
+    if isinstance(e, APIError):
+        raise DialException(
+            status_code=getattr(e, "status_code", None) or 500,
+            message=e.message,
+            type=e.type,
+            code=e.code,
+            param=e.param,
+            display_message=None,
+        )
+
+
 @app.exception_handler(DialException)
-def exception_handler(request: Request, exc: DialException):
+def dial_exception_handler(request: Request, exc: DialException):
     return exc.to_fastapi_response()
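`create_server_response` is imported from `aidial_adapter_openai.utils.streaming`, but its body is outside this diff. Judging only from the call sites above, where handlers now return a plain dict, an async iterator of chunk dicts, or a ready-made `Response`, a plausible sketch of its contract is the following (an assumption, not the commit's actual code):

```python
# Hypothetical sketch of the create_server_response contract, inferred
# from its call site in app.py; the real helper in utils/streaming.py is
# not shown in this diff and may differ.
import json
from collections.abc import AsyncIterator
from typing import Any

from fastapi.responses import JSONResponse, Response, StreamingResponse


async def _to_sse(chunks: AsyncIterator[dict]) -> AsyncIterator[str]:
    # Serialize each chunk dict as a server-sent event.
    async for chunk in chunks:
        yield f"data: {json.dumps(chunk)}\n\n"
    yield "data: [DONE]\n\n"


async def _single_chunk(response: dict) -> AsyncIterator[dict]:
    # Emulated streaming: the whole response becomes one chunk
    # (message-to-delta conversion omitted for brevity).
    yield response


def create_server_response(emulate_streaming: bool, response: Any) -> Response:
    if isinstance(response, Response):
        # Already materialized (e.g. a proxied upstream error body).
        return response
    if isinstance(response, AsyncIterator):
        # Genuine streaming: package the chunk stream as SSE.
        return StreamingResponse(
            _to_sse(response), media_type="text/event-stream"
        )
    if emulate_streaming:
        # stream=True requested from a non-streaming deployment.
        return StreamingResponse(
            _to_sse(_single_chunk(response)), media_type="text/event-stream"
        )
    return JSONResponse(response)
```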
28 changes: 10 additions & 18 deletions aidial_adapter_openai/completions.py
@@ -1,7 +1,6 @@
 from typing import Any, Dict
 
 from aidial_sdk.exceptions import RequestValidationError
-from fastapi.responses import JSONResponse, StreamingResponse
 from openai import AsyncStream
 from openai.types import Completion
 
@@ -12,7 +11,6 @@
     OpenAIEndpoint,
 )
 from aidial_adapter_openai.utils.reflection import call_with_extra_body
-from aidial_adapter_openai.utils.sse_stream import to_openai_sse_stream
 from aidial_adapter_openai.utils.streaming import (
     build_chunk,
     debug_print,
@@ -48,15 +46,15 @@ async def chat_completion(
     creds: OpenAICreds,
     api_version: str,
     deployment_id: str,
-) -> Any:
+):
 
-    if data.get("n", 1) > 1:  # type: ignore
+    if (data.get("n") or 1) > 1:
         raise RequestValidationError("The deployment doesn't support n > 1")
 
     client = endpoint.get_client({**creds, "api_version": api_version})
 
-    messages = data.get("messages", [])
-    if len(messages) == 0:
+    messages = data.get("messages") or []
+    if not messages:
         raise RequestValidationError("The request doesn't contain any messages")
 
     prompt = messages[-1].get("content") or ""
@@ -67,24 +65,18 @@
         prompt = template.format(prompt=prompt)
 
     del data["messages"]
+
     response = await call_with_extra_body(
         client.completions.create,
         {"prompt": prompt, **data},
     )
 
     if isinstance(response, AsyncStream):
-        return StreamingResponse(
-            to_openai_sse_stream(
-                map_stream(
-                    lambda item: convert_to_chat_completions_response(
-                        item, is_stream=True
-                    ),
-                    response,
-                )
-            ),
-            media_type="text/event-stream",
-        )
+        return map_stream(
+            lambda item: convert_to_chat_completions_response(
+                item, is_stream=True
+            ),
+            response,
+        )
     else:
-        return JSONResponse(
-            convert_to_chat_completions_response(response, is_stream=False)
-        )
+        return convert_to_chat_completions_response(response, is_stream=False)
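With the FastAPI response types gone from this module, `chat_completion` now returns either a plain dict or a mapped chunk stream, and the SSE/JSON packaging happens once in `create_server_response`. For reference, `map_stream` behaves like an async `map`; a minimal stand-in (assumed, not the adapter's own utility) is:

```python
# Minimal stand-in for map_stream, shown only to make the streaming
# return value above concrete; the adapter's own utility may differ.
from collections.abc import AsyncIterator
from typing import Callable, TypeVar

A = TypeVar("A")
B = TypeVar("B")


async def map_stream(
    fn: Callable[[A], B], stream: AsyncIterator[A]
) -> AsyncIterator[B]:
    # Lazily apply fn to every chunk of the upstream stream.
    async for item in stream:
        yield fn(item)
```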
28 changes: 11 additions & 17 deletions aidial_adapter_openai/dalle3.py
@@ -3,10 +3,9 @@
 import aiohttp
 from aidial_sdk.exceptions import HTTPException as DIALException
 from aidial_sdk.exceptions import RequestValidationError
-from fastapi.responses import JSONResponse, Response, StreamingResponse
+from fastapi.responses import JSONResponse
 
 from aidial_adapter_openai.utils.auth import OpenAICreds, get_auth_headers
-from aidial_adapter_openai.utils.sse_stream import to_openai_sse_stream
 from aidial_adapter_openai.utils.storage import FileStorage
 from aidial_adapter_openai.utils.streaming import build_chunk, generate_id
@@ -111,7 +110,7 @@ async def chat_completion(
     is_stream: bool,
     file_storage: Optional[FileStorage],
     api_version: str,
-) -> Response:
+):
     if data.get("n", 1) > 1:
         raise RequestValidationError("The deployment doesn't support n > 1")
 
@@ -133,19 +132,14 @@
     if file_storage is not None:
         await move_attachments_data_to_storage(custom_content, file_storage)
 
-    if not is_stream:
-        return JSONResponse(
-            content=build_chunk(
-                id,
-                "stop",
-                {"role": "assistant", "content": "", **custom_content},
-                created,
-                False,
-                usage=IMG_USAGE,
-            )
-        )
+    if is_stream:
+        return generate_stream(id, created, custom_content)
     else:
-        return StreamingResponse(
-            to_openai_sse_stream(generate_stream(id, created, custom_content)),
-            media_type="text/event-stream",
-        )
+        return build_chunk(
+            id,
+            "stop",
+            {"role": "assistant", "content": "", **custom_content},
+            created,
+            False,
+            usage=IMG_USAGE,
+        )
7 changes: 1 addition & 6 deletions aidial_adapter_openai/databricks.py
@@ -1,6 +1,5 @@
 from typing import Any, cast
 
-from fastapi.responses import StreamingResponse
 from openai import AsyncStream
 from openai.types.chat.chat_completion import ChatCompletion
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
@@ -11,7 +10,6 @@
     chat_completions_parser,
 )
 from aidial_adapter_openai.utils.reflection import call_with_extra_body
-from aidial_adapter_openai.utils.sse_stream import to_openai_sse_stream
 from aidial_adapter_openai.utils.streaming import chunk_to_dict, map_stream
 
 
@@ -27,9 +25,6 @@ async def chat_completion(
     )
 
     if isinstance(response, AsyncStream):
-        return StreamingResponse(
-            to_openai_sse_stream(map_stream(chunk_to_dict, response)),
-            media_type="text/event-stream",
-        )
+        return map_stream(chunk_to_dict, response)
     else:
         return response
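The Databricks adapter follows the same pattern: a genuine `AsyncStream` is mapped chunk-by-chunk into plain dicts for `create_server_response` to package, while a non-streaming response is returned as-is. Assuming `chunk_to_dict` simply dumps the SDK's pydantic chunk model (its definition is outside this diff), it could look like:

```python
# Assumed behavior of chunk_to_dict: serialize the OpenAI SDK's pydantic
# chunk model into a plain dict for SSE packaging. The exclude_none
# option is a guess; the adapter's actual helper is not shown here.
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk


def chunk_to_dict(chunk: ChatCompletionChunk) -> dict:
    return chunk.model_dump(exclude_none=True)
```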