Commit 4fe7215

feat: migrated to openai==1.33.0 (#105)

adubovik authored Jun 17, 2024
1 parent 2f7e978 commit 4fe7215
Showing 21 changed files with 730 additions and 556 deletions.
175 changes: 64 additions & 111 deletions aidial_adapter_openai/app.py
@@ -1,14 +1,14 @@
 import json
 import os
-from typing import Dict
+from typing import Awaitable, Dict, TypeVar

 from aidial_sdk.exceptions import HTTPException as DialException
 from aidial_sdk.telemetry.init import init_telemetry
 from aidial_sdk.telemetry.types import TelemetryConfig
+from aidial_sdk.utils.errors import json_error
 from fastapi import FastAPI, Request
-from fastapi.responses import JSONResponse, Response, StreamingResponse
-from openai import ChatCompletion, Embedding, error
-from openai.openai_object import OpenAIObject
+from fastapi.responses import JSONResponse, Response
+from openai import APIConnectionError, APIStatusError, APITimeoutError

 from aidial_adapter_openai.constant import DEFAULT_TIMEOUT
 from aidial_adapter_openai.dalle3 import (
@@ -17,26 +17,24 @@
 from aidial_adapter_openai.databricks import (
     chat_completion as databricks_chat_completion,
 )
+from aidial_adapter_openai.gpt import gpt_chat_completion
 from aidial_adapter_openai.gpt4_multi_modal.chat_completion import (
     gpt4_vision_chat_completion,
     gpt4o_chat_completion,
 )
 from aidial_adapter_openai.mistral import (
     chat_completion as mistral_chat_completion,
 )
-from aidial_adapter_openai.openai_override import OpenAIException
 from aidial_adapter_openai.utils.auth import get_credentials
 from aidial_adapter_openai.utils.log_config import configure_loggers
 from aidial_adapter_openai.utils.parsers import (
-    chat_completions_parser,
     embeddings_parser,
     parse_body,
     parse_deployment_list,
 )
-from aidial_adapter_openai.utils.sse_stream import to_openai_sse_stream
+from aidial_adapter_openai.utils.reflection import call_with_extra_body
 from aidial_adapter_openai.utils.storage import create_file_storage
-from aidial_adapter_openai.utils.streaming import generate_stream, map_stream
-from aidial_adapter_openai.utils.tokens import Tokenizer, discard_messages
+from aidial_adapter_openai.utils.tokens import Tokenizer

 app = FastAPI()

@@ -62,20 +60,27 @@
 )
 dalle3_azure_api_version = os.getenv("DALLE3_AZURE_API_VERSION", "2024-02-01")

+T = TypeVar("T")

-async def handle_exceptions(call):
+
+async def handle_exceptions(call: Awaitable[T]) -> T | Response:
     try:
         return await call
-    except OpenAIException as e:
-        return Response(status_code=e.code, headers=e.headers, content=e.body)
-    except error.Timeout:
+    except APIStatusError as e:
+        r = e.response
+        return Response(
+            content=r.content,
+            status_code=r.status_code,
+            headers=r.headers,
+        )
+    except APITimeoutError:
         raise DialException(
             "Request timed out",
             504,
             "timeout",
             display_message="Request timed out. Please try again later.",
         )
-    except error.APIConnectionError:
+    except APIConnectionError:
         raise DialException(
             "Error communicating with OpenAI",
             502,
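Note on the new exception handling: openai>=1.0 removed the openai.error module in favor of a flat exception hierarchy, and APITimeoutError is a subclass of APIConnectionError, so it must be caught first, exactly as handle_exceptions does above. A minimal standalone sketch of how these exceptions surface; the endpoint, key, and model below are placeholders:

import asyncio

from openai import (
    APIConnectionError,
    APIStatusError,
    APITimeoutError,
    AsyncAzureOpenAI,
)

async def demo() -> None:
    client = AsyncAzureOpenAI(
        azure_endpoint="https://example.openai.azure.com",  # placeholder
        api_key="dummy-key",  # placeholder
        api_version="2024-02-01",
    )
    try:
        await client.chat.completions.create(
            model="gpt-4", messages=[{"role": "user", "content": "hi"}]
        )
    except APIStatusError as e:
        # Upstream returned a non-2xx response; e.response is an httpx.Response,
        # which is why the handler above can forward content/status/headers as-is.
        print(e.response.status_code)
    except APITimeoutError:
        print("request timed out")  # mapped to DIAL 504 above
    except APIConnectionError:
        print("network failure")  # mapped to DIAL 502 above

asyncio.run(demo())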
@@ -90,7 +95,7 @@ def get_api_version(request: Request):

     if api_version == "":
         raise DialException(
-            "Api version is a required query parameter",
+            "api-version is a required query parameter",
             400,
             "invalid_request_error",
         )
@@ -102,9 +107,19 @@ def get_api_version(request: Request):
 async def chat_completion(deployment_id: str, request: Request):
     data = await parse_body(request)

+    # Azure OpenAI deployments ignore "model" request field,
+    # since the deployment id is already encoded in the endpoint path.
+    # This is not the case for non-Azure OpenAI deployments, so
+    # they require the "model" field to be set.
+    # However, openai==1.33.0 requires the "model" field for **both**
+    # Azure and non-Azure deployments.
+    # Therefore, we provide the "model" field for all deployments here.
+    # The same goes for /embeddings endpoint.
+    data["model"] = deployment_id
+
     is_stream = data.get("stream", False)

-    api_type, api_key = await get_credentials(request, chat_completions_parser)
+    creds = await get_credentials(request)

     upstream_endpoint = request.headers["X-UPSTREAM-ENDPOINT"]

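Note on the "model" field: the comment in the hunk above reflects a real signature change. In openai==1.33.0, model is a required argument of chat.completions.create even on Azure clients, where it doubles as the deployment name when azure_deployment is not set on the client. A hedged sketch; the endpoint, key, and deployment id are placeholders:

import asyncio

from openai import AsyncAzureOpenAI

async def main() -> None:
    client = AsyncAzureOpenAI(
        azure_endpoint="https://example.openai.azure.com",  # placeholder
        api_key="dummy-key",  # placeholder
        api_version="2024-02-01",
    )
    # Omitting model= raises a TypeError in the 1.x SDK; for Azure the value
    # is substituted into /openai/deployments/{model}/... in the request path.
    await client.chat.completions.create(
        model="my-deployment",  # placeholder deployment id
        messages=[{"role": "user", "content": "ping"}],
    )

asyncio.run(main())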
@@ -113,25 +128,20 @@ async def chat_completion(deployment_id: str, request: Request):
         return await dalle3_chat_completion(
             data,
             upstream_endpoint,
-            api_key,
+            creds,
             is_stream,
             storage,
-            api_type,
             dalle3_azure_api_version,
         )
-    elif deployment_id in mistral_deployments:
+
+    if deployment_id in mistral_deployments:
         return await handle_exceptions(
-            mistral_chat_completion(data, upstream_endpoint, api_key)
+            mistral_chat_completion(data, upstream_endpoint, creds)
         )
-    elif deployment_id in databricks_deployments:
+
+    if deployment_id in databricks_deployments:
         return await handle_exceptions(
-            databricks_chat_completion(
-                data,
-                deployment_id,
-                upstream_endpoint,
-                api_key,
-                api_type,
-            )
+            databricks_chat_completion(data, upstream_endpoint, creds)
         )

     api_version = get_api_version(request)
@@ -142,10 +152,9 @@ async def chat_completion(deployment_id: str, request: Request):
             data,
             deployment_id,
             upstream_endpoint,
-            api_key,
+            creds,
             is_stream,
             storage,
-            api_type,
             api_version,
         )

@@ -159,113 +168,57 @@ async def chat_completion(deployment_id: str, request: Request):
                 data,
                 deployment_id,
                 upstream_endpoint,
-                api_key,
+                creds,
                 is_stream,
                 storage,
-                api_type,
                 api_version,
                 tokenizer,
             )
         )

-    discarded_messages = None
-    if "max_prompt_tokens" in data:
-        max_prompt_tokens = data["max_prompt_tokens"]
-        if not isinstance(max_prompt_tokens, int):
-            raise DialException(
-                f"'{max_prompt_tokens}' is not of type 'integer' - 'max_prompt_tokens'",
-                400,
-                "invalid_request_error",
-            )
-        if max_prompt_tokens < 1:
-            raise DialException(
-                f"'{max_prompt_tokens}' is less than the minimum of 1 - 'max_prompt_tokens'",
-                400,
-                "invalid_request_error",
-            )
-        del data["max_prompt_tokens"]
-
-        data["messages"], discarded_messages = discard_messages(
-            tokenizer, data["messages"], max_prompt_tokens
-        )
-
-    request_args = chat_completions_parser.parse(
-        upstream_endpoint
-    ).prepare_request_args(deployment_id)
-
-    response = await handle_exceptions(
-        ChatCompletion().acreate(
-            api_key=api_key,
-            api_type=api_type,
-            api_version=api_version,
-            request_timeout=DEFAULT_TIMEOUT,
-            **(data | request_args),
+    return await handle_exceptions(
+        gpt_chat_completion(
+            data,
+            deployment_id,
+            upstream_endpoint,
+            creds,
+            api_version,
+            tokenizer,
         )
     )

-    if isinstance(response, Response):
-        return response
-
-    if is_stream:
-        prompt_tokens = tokenizer.calculate_prompt_tokens(data["messages"])
-        chunk_stream = map_stream(lambda obj: obj.to_dict_recursive(), response)
-        return StreamingResponse(
-            to_openai_sse_stream(
-                generate_stream(
-                    prompt_tokens,
-                    chunk_stream,
-                    tokenizer,
-                    deployment_id,
-                    discarded_messages,
-                )
-            ),
-            media_type="text/event-stream",
-        )
-    else:
-        if discarded_messages is not None:
-            assert isinstance(response, OpenAIObject)
-            response = response.to_dict() | {
-                "statistics": {"discarded_messages": discarded_messages}
-            }
-
-        return response
-

 @app.post("/openai/deployments/{deployment_id}/embeddings")
 async def embedding(deployment_id: str, request: Request):
     data = await parse_body(request)

-    api_type, api_key = await get_credentials(request, embeddings_parser)
+    # See note for /chat/completions endpoint
+    data["model"] = deployment_id
+
+    creds = await get_credentials(request)
     api_version = get_api_version(request)
+    upstream_endpoint = request.headers["X-UPSTREAM-ENDPOINT"]

-    request_args = embeddings_parser.parse(
-        request.headers["X-UPSTREAM-ENDPOINT"]
-    ).prepare_request_args(deployment_id)
+    client = embeddings_parser.parse(upstream_endpoint).get_client(
+        {**creds, "api_version": api_version, "timeout": DEFAULT_TIMEOUT}
+    )

     return await handle_exceptions(
-        Embedding().acreate(
-            api_key=api_key,
-            api_type=api_type,
-            api_version=api_version,
-            request_timeout=DEFAULT_TIMEOUT,
-            **(data | request_args),
-        )
+        call_with_extra_body(client.embeddings.create, data)
     )


 @app.exception_handler(DialException)
 def exception_handler(request: Request, exc: DialException):
     return JSONResponse(
         status_code=exc.status_code,
-        content={
-            "error": {
-                "message": exc.message,
-                "type": exc.type,
-                "param": exc.param,
-                "code": exc.code,
-                "display_message": exc.display_message,
-            }
-        },
+        content=json_error(
+            message=exc.message,
+            type=exc.type,
+            param=exc.param,
+            code=exc.code,
+            display_message=exc.display_message,
+        ),
     )
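Note on get_client: it is a project helper not shown in this diff; presumably it builds the appropriate 1.x async client from the parsed upstream endpoint. A hypothetical sketch of the idea only; the Azure detection and URL handling here are assumptions, not the adapter's actual code:

from openai import AsyncAzureOpenAI, AsyncOpenAI

def get_client_sketch(endpoint: str, api_key: str, api_version: str, timeout):
    # Assumption: Azure endpoints are recognized by their host suffix.
    if ".openai.azure.com" in endpoint:
        return AsyncAzureOpenAI(
            azure_endpoint=endpoint.split("/openai/")[0],
            api_key=api_key,
            api_version=api_version,
            timeout=timeout,
        )
    return AsyncOpenAI(base_url=endpoint, api_key=api_key, timeout=timeout)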


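Note on call_with_extra_body: it comes from the new utils.reflection module, which this diff does not show. A hypothetical sketch of what such a helper may do: forward the request fields the SDK method declares as keyword arguments, and tunnel everything else through the 1.x extra_body escape hatch:

import inspect
from typing import Any, Awaitable, Callable, TypeVar

T = TypeVar("T")

async def call_with_extra_body(
    func: Callable[..., Awaitable[T]], request: dict[str, Any]
) -> T:
    known = set(inspect.signature(func).parameters)
    kwargs = {k: v for k, v in request.items() if k in known}
    extra = {k: v for k, v in request.items() if k not in known}
    if extra:
        # openai>=1.0 merges extra_body into the JSON payload it sends.
        kwargs["extra_body"] = extra
    return await func(**kwargs)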
5 changes: 4 additions & 1 deletion aidial_adapter_openai/constant.py
@@ -1 +1,4 @@
-DEFAULT_TIMEOUT = (10, 600)  # connect timeout and total timeout
+from openai import Timeout
+
+# connect timeout and total timeout
+DEFAULT_TIMEOUT = Timeout(600, connect=10)
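Note: openai.Timeout is a re-export of httpx.Timeout. The positional argument is the default for all four timeout classes and connect=10 overrides one of them, so the new value preserves the meaning of the old (10, 600) tuple:

from httpx import Timeout

# The same configuration spelled out explicitly:
DEFAULT_TIMEOUT = Timeout(connect=10, read=600, write=600, pool=600)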
13 changes: 5 additions & 8 deletions aidial_adapter_openai/dalle3.py
@@ -4,7 +4,7 @@
 from aidial_sdk.exceptions import HTTPException as DialException
 from fastapi.responses import JSONResponse, Response, StreamingResponse

-from aidial_adapter_openai.utils.auth import get_auth_header
+from aidial_adapter_openai.utils.auth import OpenAICreds, get_auth_headers
 from aidial_adapter_openai.utils.sse_stream import END_CHUNK
 from aidial_adapter_openai.utils.storage import FileStorage
 from aidial_adapter_openai.utils.streaming import (
@@ -21,13 +21,13 @@


 async def generate_image(
-    api_url: str, api_key: str, user_prompt: str, api_type: str
+    api_url: str, creds: OpenAICreds, user_prompt: str
 ) -> Any:
     async with aiohttp.ClientSession() as session:
         async with session.post(
             api_url,
             json={"prompt": user_prompt, "response_format": "b64_json"},
-            headers=get_auth_header(api_type, api_key),
+            headers=get_auth_headers(creds),
         ) as response:
             status_code = response.status

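Note on OpenAICreds and get_auth_headers: both live in utils.auth, outside this diff. A hypothetical sketch of the shape implied by the call sites: a single credentials dict replaces the old (api_type, api_key) pair, and the header is chosen by which credential is present:

from typing import TypedDict

class OpenAICreds(TypedDict, total=False):
    api_key: str
    azure_ad_token: str

def get_auth_headers(creds: OpenAICreds) -> dict[str, str]:
    if "api_key" in creds:
        return {"api-key": creds["api_key"]}  # Azure-style key header
    if "azure_ad_token" in creds:
        return {"Authorization": f"Bearer {creds['azure_ad_token']}"}
    return {}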
@@ -114,10 +114,9 @@ async def move_attachments_data_to_storage(
 async def chat_completion(
     data: Any,
     upstream_endpoint: str,
-    api_key: str,
+    creds: OpenAICreds,
     is_stream: bool,
     file_storage: Optional[FileStorage],
-    api_type: str,
     api_version: str,
 ) -> Response:
     if data.get("n", 1) > 1:
@@ -129,9 +128,7 @@ async def chat_completion(

     api_url = f"{upstream_endpoint}?api-version={api_version}"
     user_prompt = get_user_prompt(data)
-    model_response = await generate_image(
-        api_url, api_key, user_prompt, api_type
-    )
+    model_response = await generate_image(api_url, creds, user_prompt)

     if isinstance(model_response, JSONResponse):
         return model_response