Skip to content

Commit

Permalink
feat: add legacy completions API support (#124)
Browse files Browse the repository at this point in the history
* Add new type of parser

* add completions endpoint

* - add prompt template to legacy completions
-move env variables to distinct file

* hotfix

* another hotfix

* pr fixes

* refactor due to pr comment

* another refactor due to PR comments

* another minor refactor

* Fix model field ingestion for embeddings endpoint
  • Loading branch information
roman-romanov-o authored Jun 26, 2024
1 parent bb5403e commit 29c1684
Show file tree
Hide file tree
Showing 5 changed files with 286 additions and 49 deletions.
70 changes: 35 additions & 35 deletions aidial_adapter_openai/app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import json
import os
from contextlib import asynccontextmanager
from typing import Awaitable, Dict, TypeVar
from typing import Awaitable, TypeVar

from aidial_sdk.exceptions import HTTPException as DialException
from aidial_sdk.telemetry.init import init_telemetry
Expand All @@ -11,12 +9,23 @@
from fastapi.responses import JSONResponse, Response
from openai import APIConnectionError, APIStatusError, APITimeoutError

from aidial_adapter_openai.completions import chat_completion as completion
from aidial_adapter_openai.dalle3 import (
chat_completion as dalle3_chat_completion,
)
from aidial_adapter_openai.databricks import (
chat_completion as databricks_chat_completion,
)
from aidial_adapter_openai.env import (
API_VERSIONS_MAPPING,
DALLE3_AZURE_API_VERSION,
DALLE3_DEPLOYMENTS,
DATABRICKS_DEPLOYMENTS,
GPT4_VISION_DEPLOYMENTS,
GPT4O_DEPLOYMENTS,
MISTRAL_DEPLOYMENTS,
MODEL_ALIASES,
)
from aidial_adapter_openai.gpt import gpt_chat_completion
from aidial_adapter_openai.gpt4_multi_modal.chat_completion import (
gpt4_vision_chat_completion,
Expand All @@ -29,9 +38,9 @@
from aidial_adapter_openai.utils.http_client import get_http_client
from aidial_adapter_openai.utils.log_config import configure_loggers, logger
from aidial_adapter_openai.utils.parsers import (
completions_parser,
embeddings_parser,
parse_body,
parse_deployment_list,
)
from aidial_adapter_openai.utils.reflection import call_with_extra_body
from aidial_adapter_openai.utils.storage import create_file_storage
Expand All @@ -51,25 +60,6 @@ async def lifespan(app: FastAPI):
init_telemetry(app, TelemetryConfig())
configure_loggers()

model_aliases: Dict[str, str] = json.loads(os.getenv("MODEL_ALIASES", "{}"))
dalle3_deployments = parse_deployment_list(
os.getenv("DALLE3_DEPLOYMENTS") or ""
)
gpt4_vision_deployments = parse_deployment_list(
os.getenv("GPT4_VISION_DEPLOYMENTS") or ""
)
mistral_deployments = parse_deployment_list(
os.getenv("MISTRAL_DEPLOYMENTS") or ""
)
databricks_deployments = parse_deployment_list(
os.getenv("DATABRICKS_DEPLOYMENTS") or ""
)
gpt4o_deployments = parse_deployment_list(os.getenv("GPT4O_DEPLOYMENTS") or "")
api_versions_mapping: Dict[str, str] = json.loads(
os.getenv("API_VERSIONS_MAPPING", "{}")
)
dalle3_azure_api_version = os.getenv("DALLE3_AZURE_API_VERSION", "2024-02-01")

T = TypeVar("T")


Expand Down Expand Up @@ -101,7 +91,7 @@ async def handle_exceptions(call: Awaitable[T]) -> T | Response:

def get_api_version(request: Request):
api_version = request.query_params.get("api-version", "")
api_version = api_versions_mapping.get(api_version, api_version)
api_version = API_VERSIONS_MAPPING.get(api_version, api_version)

if api_version == "":
raise DialException(
Expand All @@ -125,38 +115,48 @@ async def chat_completion(deployment_id: str, request: Request):
# Azure and non-Azure deployments.
# Therefore, we provide the "model" field for all deployments here.
# The same goes for /embeddings endpoint.
data["model"] = deployment_id
data["model"] = data.get("model") or deployment_id

is_stream = data.get("stream", False)

creds = await get_credentials(request)
api_version = get_api_version(request)

upstream_endpoint = request.headers["X-UPSTREAM-ENDPOINT"]

if deployment_id in dalle3_deployments:
if completions_endpoint := completions_parser.parse(upstream_endpoint):
return await handle_exceptions(
completion(
data,
completions_endpoint,
creds,
api_version,
deployment_id,
)
)

if deployment_id in DALLE3_DEPLOYMENTS:
storage = create_file_storage("images", request.headers)
return await dalle3_chat_completion(
data,
upstream_endpoint,
creds,
is_stream,
storage,
dalle3_azure_api_version,
DALLE3_AZURE_API_VERSION,
)

if deployment_id in mistral_deployments:
if deployment_id in MISTRAL_DEPLOYMENTS:
return await handle_exceptions(
mistral_chat_completion(data, upstream_endpoint, creds)
)

if deployment_id in databricks_deployments:
if deployment_id in DATABRICKS_DEPLOYMENTS:
return await handle_exceptions(
databricks_chat_completion(data, upstream_endpoint, creds)
)

api_version = get_api_version(request)

if deployment_id in gpt4_vision_deployments:
if deployment_id in GPT4_VISION_DEPLOYMENTS:
storage = create_file_storage("images", request.headers)
return await gpt4_vision_chat_completion(
data,
Expand All @@ -168,10 +168,10 @@ async def chat_completion(deployment_id: str, request: Request):
api_version,
)

openai_model_name = model_aliases.get(deployment_id, deployment_id)
openai_model_name = MODEL_ALIASES.get(deployment_id, deployment_id)
tokenizer = Tokenizer(model=openai_model_name)

if deployment_id in gpt4o_deployments:
if deployment_id in GPT4O_DEPLOYMENTS:
storage = create_file_storage("images", request.headers)
return await handle_exceptions(
gpt4o_chat_completion(
Expand Down Expand Up @@ -203,7 +203,7 @@ async def embedding(deployment_id: str, request: Request):
data = await parse_body(request)

# See note for /chat/completions endpoint
data["model"] = deployment_id
data["model"] = data.get("model") or deployment_id

creds = await get_credentials(request)
api_version = get_api_version(request)
Expand Down
84 changes: 84 additions & 0 deletions aidial_adapter_openai/completions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from typing import Any, Dict

from aidial_sdk.exceptions import HTTPException as DialException
from fastapi.responses import JSONResponse, StreamingResponse
from openai import AsyncStream
from openai.types import Completion

from aidial_adapter_openai.env import COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES
from aidial_adapter_openai.utils.auth import OpenAICreds
from aidial_adapter_openai.utils.parsers import (
AzureOpenAIEndpoint,
OpenAIEndpoint,
)
from aidial_adapter_openai.utils.reflection import call_with_extra_body
from aidial_adapter_openai.utils.sse_stream import to_openai_sse_stream
from aidial_adapter_openai.utils.streaming import build_chunk, map_stream


def convert_to_chat_completions_response(
    chunk: Completion, is_stream: bool
) -> Dict[str, Any]:
    """Translate a legacy /completions result into chat-completions shape.

    The first choice's text becomes the chat "delta"; usage is forwarded
    when present. `is_stream` controls whether the chunk is formatted as a
    streaming delta or a full message.
    """
    first_choice = chunk.choices[0]
    usage_dict = chunk.usage.to_dict() if chunk.usage else None
    return build_chunk(
        id=chunk.id,
        finish_reason=first_choice.finish_reason,
        delta=first_choice.text,
        created=str(chunk.created),
        is_stream=is_stream,
        usage=usage_dict,
    )


async def chat_completion(
    data: Dict[str, Any],
    endpoint: OpenAIEndpoint | AzureOpenAIEndpoint,
    creds: OpenAICreds,
    api_version: str,
    deployment_id: str,
) -> Any:
    """Serve a chat-completions request through the legacy /completions API.

    The content of the last chat message is used as the prompt (optionally
    wrapped in a per-deployment template), and the upstream response is
    converted back into chat-completions format, streaming or not.

    Raises DialException (422) when n > 1 is requested or when the request
    carries no messages.
    """
    if data.get("n", 1) > 1:  # type: ignore
        raise DialException(
            status_code=422,
            message="The deployment doesn't support n > 1",
            type="invalid_request_error",
        )

    client = endpoint.get_client({**creds, "api_version": api_version})

    chat_messages = data.get("messages", [])
    if len(chat_messages) == 0:
        raise DialException(
            status_code=422,
            message="The request doesn't contain any messages",
            type="invalid_request_error",
        )

    # Only the final message is forwarded as the completion prompt.
    prompt = chat_messages[-1].get("content") or ""

    template = COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES.get(deployment_id)
    if template is not None:
        prompt = template.format(prompt=prompt)

    # The legacy endpoint takes "prompt", not "messages" — drop the latter
    # before forwarding the remaining request fields.
    del data["messages"]
    response = await call_with_extra_body(
        client.completions.create,
        {"prompt": prompt, **data},
    )

    if not isinstance(response, AsyncStream):
        return JSONResponse(
            convert_to_chat_completions_response(response, is_stream=False)
        )

    def _to_chat_chunk(item):
        return convert_to_chat_completions_response(item, is_stream=True)

    return StreamingResponse(
        to_openai_sse_stream(map_stream(_to_chat_chunk, response))
    )
27 changes: 27 additions & 0 deletions aidial_adapter_openai/env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import json
import os
from typing import Dict, List

from aidial_adapter_openai.utils.parsers import parse_deployment_list


def _json_dict_env(name: str) -> Dict[str, str]:
    """Read env var *name* as a JSON object; unset or empty means {}.

    Using `or "{}"` (rather than a getenv default) keeps an explicitly
    empty variable from crashing json.loads at import time, matching the
    tolerant handling used for COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES.
    """
    return json.loads(os.getenv(name) or "{}")


def _deployments_env(name: str) -> List[str]:
    """Read env var *name* as a deployment list; unset or empty means none."""
    return parse_deployment_list(os.getenv(name) or "")


# Mapping of DIAL deployment ids to OpenAI model names (used for tokenization).
MODEL_ALIASES: Dict[str, str] = _json_dict_env("MODEL_ALIASES")
# Deployment ids routed to their specialized adapters.
DALLE3_DEPLOYMENTS = _deployments_env("DALLE3_DEPLOYMENTS")
GPT4_VISION_DEPLOYMENTS = _deployments_env("GPT4_VISION_DEPLOYMENTS")
MISTRAL_DEPLOYMENTS = _deployments_env("MISTRAL_DEPLOYMENTS")
DATABRICKS_DEPLOYMENTS = _deployments_env("DATABRICKS_DEPLOYMENTS")
GPT4O_DEPLOYMENTS = _deployments_env("GPT4O_DEPLOYMENTS")
# Rewrites of client-supplied api-version query values.
API_VERSIONS_MAPPING: Dict[str, str] = _json_dict_env("API_VERSIONS_MAPPING")
# Per-deployment prompt templates for the legacy /completions endpoint;
# each template receives the user prompt via "{prompt}".
COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES: Dict[str, str] = _json_dict_env(
    "COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES"
)
DALLE3_AZURE_API_VERSION = os.getenv("DALLE3_AZURE_API_VERSION", "2024-02-01")
42 changes: 28 additions & 14 deletions aidial_adapter_openai/utils/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,32 +52,46 @@ def get_client(self, params: OpenAIParams) -> AsyncOpenAI:
)


def _parse_endpoint(
    name, endpoint
) -> AzureOpenAIEndpoint | OpenAIEndpoint | None:
    """Extract an upstream endpoint config from a URL, or None if no match.

    Azure-style URLs (".../openai/deployments/<dep>/<name>") take priority;
    otherwise any ".../<name>" URL is treated as a plain OpenAI endpoint.
    """
    azure_match = re.search(
        f"(.+?)/openai/deployments/(.+?)/{name}", endpoint
    )
    if azure_match:
        return AzureOpenAIEndpoint(
            azure_endpoint=azure_match[1],
            azure_deployment=azure_match[2],
        )

    openai_match = re.search(f"(.+?)/{name}", endpoint)
    if openai_match:
        return OpenAIEndpoint(base_url=openai_match[1])

    return None


class EndpointParser(BaseModel):
    """Parser for a fixed endpoint suffix (e.g. "chat/completions")."""

    name: str

    def parse(self, endpoint: str) -> AzureOpenAIEndpoint | OpenAIEndpoint:
        """Parse *endpoint*; raise DialException(400) when it doesn't match."""
        parsed = _parse_endpoint(self.name, endpoint)
        if not parsed:
            raise DialException(
                "Invalid upstream endpoint format", 400, "invalid_request_error"
            )
        return parsed

if match:
return AzureOpenAIEndpoint(
azure_endpoint=match[1],
azure_deployment=match[2],
)

match = re.search(f"(.+?)/{self.name}", endpoint)
class CompletionsParser(BaseModel):
    """Parser for the legacy /completions endpoint.

    Unlike EndpointParser, a non-matching URL yields None rather than an
    exception, so callers can fall through to chat-completions handling.
    """

    def parse(
        self, endpoint: str
    ) -> AzureOpenAIEndpoint | OpenAIEndpoint | None:
        # A chat endpoint also ends in "completions"; reject it explicitly
        # so the non-greedy match below can't misclassify it.
        if "/chat/completions" in endpoint:
            return None

        return _parse_endpoint("completions", endpoint)


# Module-level parser singletons, one per supported upstream endpoint kind.
chat_completions_parser = EndpointParser(name="chat/completions")
embeddings_parser = EndpointParser(name="embeddings")
completions_parser = CompletionsParser()


async def parse_body(
Expand Down
Loading

0 comments on commit 29c1684

Please sign in to comment.