fix: fixing performance regression (#121)
adubovik authored Jun 21, 2024
1 parent 0b7a393 commit bb5403e
Showing 13 changed files with 112 additions and 80 deletions.
8 changes: 4 additions & 4 deletions .ort.yml
@@ -22,9 +22,9 @@ resolutions:
   - message: ".*PyPI::tiktoken:0\\.7\\.0.*"
     reason: "CANT_FIX_EXCEPTION"
     comment: "MIT License: https://github.com/openai/tiktoken/blob/0.7.0/LICENSE"
-  - message: ".*PyPI::httpcore:0\\.18\\.0.*"
+  - message: ".*PyPI::httpcore:1\\.0\\.5.*"
     reason: "CANT_FIX_EXCEPTION"
-    comment: "BSD 3-Clause New or Revised License: https://github.com/encode/httpcore/blob/0.18.0/LICENSE.md"
-  - message: ".*PyPI::httpx:0\\.25\\.0.*"
+    comment: "BSD 3-Clause New or Revised License: https://github.com/encode/httpcore/blob/1.0.5/LICENSE.md"
+  - message: ".*PyPI::httpx:0\\.27\\.0.*"
     reason: "CANT_FIX_EXCEPTION"
-    comment: "BSD 3-Clause New or Revised License: https://github.com/encode/httpx/blob/0.25.0/LICENSE.md"
+    comment: "BSD 3-Clause New or Revised License: https://github.com/encode/httpx/blob/0.27.0/LICENSE.md"
18 changes: 14 additions & 4 deletions aidial_adapter_openai/app.py
@@ -1,5 +1,6 @@
 import json
 import os
+from contextlib import asynccontextmanager
 from typing import Awaitable, Dict, TypeVar
 
 from aidial_sdk.exceptions import HTTPException as DialException
@@ -10,7 +11,6 @@
 from fastapi.responses import JSONResponse, Response
 from openai import APIConnectionError, APIStatusError, APITimeoutError
 
-from aidial_adapter_openai.constant import DEFAULT_TIMEOUT
 from aidial_adapter_openai.dalle3 import (
     chat_completion as dalle3_chat_completion,
 )
@@ -26,7 +26,8 @@
     chat_completion as mistral_chat_completion,
 )
 from aidial_adapter_openai.utils.auth import get_credentials
-from aidial_adapter_openai.utils.log_config import configure_loggers
+from aidial_adapter_openai.utils.http_client import get_http_client
+from aidial_adapter_openai.utils.log_config import configure_loggers, logger
 from aidial_adapter_openai.utils.parsers import (
     embeddings_parser,
     parse_body,
@@ -36,7 +37,16 @@
 from aidial_adapter_openai.utils.storage import create_file_storage
 from aidial_adapter_openai.utils.tokens import Tokenizer
 
-app = FastAPI()
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    yield
+    logger.info("Application shutdown")
+    await get_http_client().aclose()
+
+
+app = FastAPI(lifespan=lifespan)
+
 
 init_telemetry(app, TelemetryConfig())
 configure_loggers()
@@ -200,7 +210,7 @@ async def embedding(deployment_id: str, request: Request):
     upstream_endpoint = request.headers["X-UPSTREAM-ENDPOINT"]
 
     client = embeddings_parser.parse(upstream_endpoint).get_client(
-        {**creds, "api_version": api_version, "timeout": DEFAULT_TIMEOUT}
+        {**creds, "api_version": api_version}
     )
 
     return await handle_exceptions(
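
Note: everything after the yield in a FastAPI lifespan handler runs at application shutdown, which is where the shared HTTP client's connection pool gets closed. A minimal, self-contained sketch of the same pattern, with a module-level shared_client standing in for the adapter's get_http_client():

from contextlib import asynccontextmanager

import httpx
from fastapi import FastAPI

shared_client = httpx.AsyncClient()  # one pool for the app's lifetime


@asynccontextmanager
async def lifespan(app: FastAPI):
    # code before the yield runs at startup
    yield
    # code after the yield runs at shutdown
    await shared_client.aclose()


app = FastAPI(lifespan=lifespan)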
4 changes: 0 additions & 4 deletions aidial_adapter_openai/constant.py
@@ -1,4 +0,0 @@
-from openai import Timeout
-
-# connect timeout and total timeout
-DEFAULT_TIMEOUT = Timeout(600, connect=10)
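
The module is removed outright; the timeout constant reappears below in aidial_adapter_openai/utils/http_client.py as an httpx.Timeout, alongside the shared client that now owns it.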
25 changes: 8 additions & 17 deletions aidial_adapter_openai/databricks.py
@@ -1,29 +1,25 @@
-from typing import Any
+from typing import Any, cast
 
 from fastapi.responses import StreamingResponse
 from openai import AsyncStream
 from openai.types.chat.chat_completion import ChatCompletion
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
 
-from aidial_adapter_openai.constant import DEFAULT_TIMEOUT
 from aidial_adapter_openai.utils.auth import OpenAICreds
-from aidial_adapter_openai.utils.log_config import logger
-from aidial_adapter_openai.utils.parsers import chat_completions_parser
+from aidial_adapter_openai.utils.parsers import (
+    OpenAIParams,
+    chat_completions_parser,
+)
 from aidial_adapter_openai.utils.reflection import call_with_extra_body
 from aidial_adapter_openai.utils.sse_stream import to_openai_sse_stream
-from aidial_adapter_openai.utils.streaming import map_stream
-
-
-def debug_print(chunk):
-    logger.debug(f"chunk: {chunk}")
-    return chunk
+from aidial_adapter_openai.utils.streaming import chunk_to_dict, map_stream
 
 
 async def chat_completion(
     data: Any, upstream_endpoint: str, creds: OpenAICreds
 ):
     client = chat_completions_parser.parse(upstream_endpoint).get_client(
-        {**creds, "timeout": DEFAULT_TIMEOUT}
+        cast(OpenAIParams, creds)
     )
 
     response: AsyncStream[ChatCompletionChunk] | ChatCompletion = (
@@ -32,12 +28,7 @@
 
     if isinstance(response, AsyncStream):
         return StreamingResponse(
-            to_openai_sse_stream(
-                map_stream(
-                    debug_print,
-                    map_stream(lambda chunk: chunk.to_dict(), response),
-                )
-            ),
+            to_openai_sse_stream(map_stream(chunk_to_dict, response)),
             media_type="text/event-stream",
         )
     else:
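
Two changes here: the per-client timeout argument is gone (the client built by chat_completions_parser now receives the shared http_client, which carries the timeout configuration), and the local debug_print helper is replaced by the shared chunk_to_dict from utils/streaming.py, which formats chunks only when DEBUG logging is enabled.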
36 changes: 16 additions & 20 deletions aidial_adapter_openai/gpt.py
@@ -4,21 +4,19 @@
 from openai.types.chat.chat_completion import ChatCompletion
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
 
-from aidial_adapter_openai.constant import DEFAULT_TIMEOUT
 from aidial_adapter_openai.utils.auth import OpenAICreds
-from aidial_adapter_openai.utils.log_config import logger
 from aidial_adapter_openai.utils.parsers import chat_completions_parser
 from aidial_adapter_openai.utils.reflection import call_with_extra_body
 from aidial_adapter_openai.utils.sse_stream import to_openai_sse_stream
-from aidial_adapter_openai.utils.streaming import generate_stream, map_stream
+from aidial_adapter_openai.utils.streaming import (
+    chunk_to_dict,
+    debug_print,
+    generate_stream,
+    map_stream,
+)
 from aidial_adapter_openai.utils.tokens import Tokenizer, discard_messages
 
 
-def debug_print(chunk):
-    logger.debug(f"chunk: {chunk}")
-    return chunk
-
-
 async def gpt_chat_completion(
     data: dict,
     deployment_id: str,
@@ -49,33 +47,31 @@ async def gpt_chat_completion(
     )
 
     client = chat_completions_parser.parse(upstream_endpoint).get_client(
-        {**creds, "api_version": api_version, "timeout": DEFAULT_TIMEOUT}
+        {**creds, "api_version": api_version}
     )
 
     response: AsyncStream[ChatCompletionChunk] | ChatCompletion = (
         await call_with_extra_body(client.chat.completions.create, data)
     )
 
     if isinstance(response, AsyncStream):
+
         prompt_tokens = tokenizer.calculate_prompt_tokens(data["messages"])
         return StreamingResponse(
             to_openai_sse_stream(
-                map_stream(
-                    debug_print,
-                    generate_stream(
-                        prompt_tokens,
-                        map_stream(lambda obj: obj.to_dict(), response),
-                        tokenizer,
-                        deployment_id,
-                        discarded_messages,
-                    ),
-                )
+                generate_stream(
+                    prompt_tokens,
+                    map_stream(chunk_to_dict, response),
+                    tokenizer,
+                    deployment_id,
+                    discarded_messages,
+                ),
             ),
             media_type="text/event-stream",
         )
     else:
         resp = response.to_dict()
         if discarded_messages is not None:
             resp |= {"statistics": {"discarded_messages": discarded_messages}}
-        debug_print(resp)
+        debug_print("response", resp)
         return resp
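
The streaming path drops one wrapper layer: each chunk previously flowed through map_stream(debug_print, map_stream(lambda obj: obj.to_dict(), ...)), paying an extra async-generator hop plus an unconditional f-string per chunk. The new chunk_to_dict converts and logs in a single pass, and only formats when DEBUG logging is on (see utils/streaming.py below).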
19 changes: 4 additions & 15 deletions aidial_adapter_openai/mistral.py
@@ -5,17 +5,11 @@
 from openai.types.chat.chat_completion import ChatCompletion
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
 
-from aidial_adapter_openai.constant import DEFAULT_TIMEOUT
 from aidial_adapter_openai.utils.auth import OpenAICreds
-from aidial_adapter_openai.utils.log_config import logger
+from aidial_adapter_openai.utils.http_client import get_http_client
 from aidial_adapter_openai.utils.reflection import call_with_extra_body
 from aidial_adapter_openai.utils.sse_stream import to_openai_sse_stream
-from aidial_adapter_openai.utils.streaming import map_stream
-
-
-def debug_print(chunk):
-    logger.debug(f"chunk: {chunk}")
-    return chunk
+from aidial_adapter_openai.utils.streaming import chunk_to_dict, map_stream
 
 
 async def chat_completion(
@@ -24,8 +18,8 @@ async def chat_completion(
 
     client = AsyncOpenAI(
         base_url=upstream_endpoint,
-        timeout=DEFAULT_TIMEOUT,
         api_key=creds.get("api_key"),
+        http_client=get_http_client(),
     )
 
     response: AsyncStream[ChatCompletionChunk] | ChatCompletion = (
@@ -34,12 +28,7 @@ async def chat_completion(
 
     if isinstance(response, AsyncStream):
         return StreamingResponse(
-            to_openai_sse_stream(
-                map_stream(
-                    debug_print,
-                    map_stream(lambda chunk: chunk.to_dict(), response),
-                )
-            ),
+            to_openai_sse_stream(map_stream(chunk_to_dict, response)),
             media_type="text/event-stream",
         )
     else:
20 changes: 20 additions & 0 deletions aidial_adapter_openai/utils/http_client.py
@@ -0,0 +1,20 @@
+import functools
+
+import httpx
+
+# connect timeout and total timeout
+DEFAULT_TIMEOUT = httpx.Timeout(600, connect=10)
+
+# Borrowed from openai._constants.DEFAULT_CONNECTION_LIMITS
+DEFAULT_CONNECTION_LIMITS = httpx.Limits(
+    max_connections=1000, max_keepalive_connections=100
+)
+
+
+@functools.cache
+def get_http_client() -> httpx.AsyncClient:
+    return httpx.AsyncClient(
+        timeout=DEFAULT_TIMEOUT,
+        limits=DEFAULT_CONNECTION_LIMITS,
+        follow_redirects=True,
+    )
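
This new module is the core of the fix. Since get_http_client takes no arguments, functools.cache makes it a process-wide singleton: all upstream requests share one connection pool and reuse keep-alive connections instead of opening a fresh pool (and new TCP/TLS handshakes) per request. A small illustrative check of the singleton behavior, not part of the commit:

import functools

import httpx


@functools.cache
def get_http_client() -> httpx.AsyncClient:
    # constructed once on the first call; every later call
    # returns the exact same object, hence the same pool
    return httpx.AsyncClient()


assert get_http_client() is get_http_client()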
4 changes: 4 additions & 0 deletions aidial_adapter_openai/utils/parsers.py
@@ -8,6 +8,8 @@
 from openai import AsyncAzureOpenAI, AsyncOpenAI, Timeout
 from pydantic import BaseModel
 
+from aidial_adapter_openai.utils.http_client import get_http_client
+
 
 class OpenAIParams(TypedDict, total=False):
     api_key: str
@@ -34,6 +36,7 @@ def get_client(self, params: OpenAIParams) -> AsyncAzureOpenAI:
             azure_ad_token=params.get("azure_ad_token"),
             api_version=params.get("api_version"),
             timeout=params.get("timeout"),
+            http_client=get_http_client(),
         )
 
 
@@ -45,6 +48,7 @@ def get_client(self, params: OpenAIParams) -> AsyncOpenAI:
             base_url=self.base_url,
             api_key=params.get("api_key"),
             timeout=params.get("timeout"),
+            http_client=get_http_client(),
         )
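
Passing http_client=get_http_client() matters because AsyncAzureOpenAI and AsyncOpenAI otherwise build a private httpx client per instantiation, and these get_client calls happen on every request; with the shared client injected, constructing an SDK client per request stays cheap.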
13 changes: 11 additions & 2 deletions aidial_adapter_openai/utils/reflection.py
@@ -1,8 +1,17 @@
+import functools
 import inspect
 from typing import Any, Callable, Coroutine, TypeVar
 
 from aidial_sdk.exceptions import HTTPException as DialException
 
 
+@functools.lru_cache(maxsize=64)
+def _inspect_signature(
+    func: Callable[..., Coroutine[Any, Any, Any]]
+) -> inspect.Signature:
+    return inspect.signature(func)
+
+
 T = TypeVar("T")
 
@@ -12,7 +21,7 @@ async def call_with_extra_body(
     if has_kwargs_argument(func):
         return await func(**arg)
 
-    expected_args = set(inspect.signature(func).parameters.keys())
+    expected_args = set(_inspect_signature(func).parameters.keys())
     actual_args = set(arg.keys())
 
     extra_args = actual_args - expected_args
@@ -37,7 +46,7 @@ def has_kwargs_argument(func: Callable[..., Coroutine[Any, Any, Any]]) -> bool:
     """
     Determines if the given function accepts a variable keyword argument (**kwargs).
     """
-    signature = inspect.signature(func)
+    signature = _inspect_signature(func)
     for param in signature.parameters.values():
         if param.kind == inspect.Parameter.VAR_KEYWORD:
             return True
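
inspect.signature constructs a fresh Signature object on every call, which is costly for a helper invoked on each request; memoizing it with lru_cache keyed on the function object turns repeat lookups into a dictionary hit. A rough, illustrative benchmark of the effect (timings vary by machine):

import functools
import inspect
import timeit


async def target(a: int, b: str = "x", **kwargs): ...


cached_signature = functools.lru_cache(maxsize=64)(inspect.signature)

uncached = timeit.timeit(lambda: inspect.signature(target), number=10_000)
cached = timeit.timeit(lambda: cached_signature(target), number=10_000)
print(f"uncached: {uncached:.3f}s  cached: {cached:.3f}s")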
13 changes: 13 additions & 0 deletions aidial_adapter_openai/utils/streaming.py
@@ -1,3 +1,4 @@
+import logging
 from time import time
 from typing import Any, AsyncIterator, Callable, Optional, TypeVar
 from uuid import uuid4
@@ -6,6 +7,7 @@
 from aidial_sdk.utils.merge_chunks import merge
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from openai import APIError
+from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
 
 from aidial_adapter_openai.utils.env import get_env_bool
 from aidial_adapter_openai.utils.log_config import logger
@@ -196,3 +198,14 @@ async def map_stream(
         new_item = func(item)
         if new_item is not None:
             yield new_item
+
+
+def debug_print(title: str, chunk: dict) -> None:
+    if logger.isEnabledFor(logging.DEBUG):
+        logger.debug(f"{title}: {chunk}")
+
+
+def chunk_to_dict(chunk: ChatCompletionChunk) -> dict:
+    dict = chunk.to_dict()
+    debug_print("chunk", dict)
+    return dict
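
The logger.isEnabledFor(logging.DEBUG) guard is what makes per-chunk debug logging effectively free in production: without it, the f-string serializes every chunk dict even when DEBUG records are discarded. A minimal, illustrative comparison of the two behaviors:

import logging

logger = logging.getLogger("demo")
logger.setLevel(logging.INFO)  # DEBUG records will be dropped

chunk = {"choices": [{"delta": {"content": "hello"}}]}

# eager: the f-string is built even though the record is discarded
logger.debug(f"chunk: {chunk}")

# guarded: formatting is skipped entirely when DEBUG is off
if logger.isEnabledFor(logging.DEBUG):
    logger.debug(f"chunk: {chunk}")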