feat: Replace usage of env variables with application config in app state #178

Merged
merged 10 commits on Dec 2, 2024
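The change replaces module-level environment constants with an ApplicationConfig object stored on app.state and resolved in route handlers through a FastAPI dependency. Below is a minimal, self-contained sketch of that pattern; the Config model, create_app helper, and route are illustrative stand-ins, not the adapter's real code (that follows in the diff).

```python
# Minimal sketch of the "config on app.state + Depends" pattern, assuming
# nothing beyond FastAPI and pydantic. All names here are illustrative.
from typing import Annotated, List

from fastapi import Depends, FastAPI, Request
from pydantic import BaseModel, Field


class Config(BaseModel):
    NON_STREAMING_DEPLOYMENTS: List[str] = Field(default_factory=list)


def get_config(request: Request) -> Config:
    # Same idea as get_app_config in the diff: read the config from app.state.
    return request.app.state.app_config


def create_app(config: Config) -> FastAPI:
    app = FastAPI()
    app.state.app_config = config
    return app


app = create_app(Config(NON_STREAMING_DEPLOYMENTS=["dall-e-3"]))


@app.get("/deployments/non-streaming")
async def non_streaming(config: Annotated[Config, Depends(get_config)]):
    # The dependency makes per-app configuration testable without env vars.
    return config.NON_STREAMING_DEPLOYMENTS
```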
165 changes: 107 additions & 58 deletions aidial_adapter_openai/app.py
@@ -1,12 +1,13 @@
from contextlib import asynccontextmanager
from typing import Annotated

import pydantic
from aidial_sdk._errors import pydantic_validation_exception_handler
from aidial_sdk.exceptions import HTTPException as DialException
from aidial_sdk.exceptions import InvalidRequestError
from aidial_sdk.telemetry.init import init_telemetry
from aidial_sdk.telemetry.types import TelemetryConfig
from fastapi import FastAPI, Request
from fastapi import Depends, FastAPI, Request
from fastapi.responses import Response
from openai import (
APIConnectionError,
@@ -16,6 +17,7 @@
OpenAIError,
)

from aidial_adapter_openai.app_config import ApplicationConfig
from aidial_adapter_openai.completions import chat_completion as completion
from aidial_adapter_openai.dalle3 import (
chat_completion as dalle3_chat_completion,
@@ -30,17 +32,6 @@
from aidial_adapter_openai.embeddings.openai import (
embeddings as openai_embeddings,
)
from aidial_adapter_openai.env import (
API_VERSIONS_MAPPING,
AZURE_AI_VISION_DEPLOYMENTS,
DALLE3_AZURE_API_VERSION,
DALLE3_DEPLOYMENTS,
DATABRICKS_DEPLOYMENTS,
GPT4_VISION_DEPLOYMENTS,
MISTRAL_DEPLOYMENTS,
MODEL_ALIASES,
NON_STREAMING_DEPLOYMENTS,
)
from aidial_adapter_openai.gpt import gpt_chat_completion
from aidial_adapter_openai.gpt4_multi_modal.chat_completion import (
gpt4_vision_chat_completion,
@@ -51,7 +42,10 @@
)
from aidial_adapter_openai.utils.auth import get_credentials
from aidial_adapter_openai.utils.http_client import get_http_client
from aidial_adapter_openai.utils.image_tokenizer import get_image_tokenizer
from aidial_adapter_openai.utils.image_tokenizer import (
ImageTokenizer,
get_image_tokenizer,
)
from aidial_adapter_openai.utils.log_config import configure_loggers, logger
from aidial_adapter_openai.utils.parsers import completions_parser, parse_body
from aidial_adapter_openai.utils.streaming import create_server_response
@@ -68,43 +62,88 @@ async def lifespan(app: FastAPI):
await get_http_client().aclose()


app = FastAPI(lifespan=lifespan)
def create_app(
app_config: ApplicationConfig | None = None,
to_init_telemetry: bool = True,
to_configure_loggers: bool = True,
) -> FastAPI:
app = FastAPI(lifespan=lifespan)

if app_config is None:
app_config = ApplicationConfig.from_env()

app.state.app_config = app_config

if to_init_telemetry:
init_telemetry(app, TelemetryConfig())

if to_configure_loggers:
configure_loggers()

init_telemetry(app, TelemetryConfig())
configure_loggers()
return app


def get_api_version(request: Request):
def get_app_config(request: Request) -> ApplicationConfig:
return request.app.state.app_config


def get_api_version(request: Request) -> str:
api_version = request.query_params.get("api-version", "")
api_version = API_VERSIONS_MAPPING.get(api_version, api_version)
app_config = get_app_config(request)
api_version = app_config.API_VERSIONS_MAPPING.get(api_version, api_version)

if api_version == "":
raise InvalidRequestError("api-version is a required query parameter")

return api_version


def _get_image_tokenizer(
deployment_id: str, app_config: ApplicationConfig
) -> ImageTokenizer:
image_tokenizer = get_image_tokenizer(deployment_id, app_config)
if not image_tokenizer:
raise RuntimeError(
f"No image tokenizer found for deployment {deployment_id}"
)
return image_tokenizer


app = create_app()


@app.post("/openai/deployments/{deployment_id:path}/chat/completions")
async def chat_completion(deployment_id: str, request: Request):
async def chat_completion(
deployment_id: str,
request: Request,
app_config: Annotated[ApplicationConfig, Depends(get_app_config)],
):

data = await parse_body(request)

is_stream = bool(data.get("stream"))

emulate_streaming = deployment_id in NON_STREAMING_DEPLOYMENTS and is_stream
emulate_streaming = (
deployment_id in app_config.NON_STREAMING_DEPLOYMENTS and is_stream
)

if emulate_streaming:
data["stream"] = False

return create_server_response(
emulate_streaming,
await call_chat_completion(deployment_id, data, is_stream, request),
await call_chat_completion(
deployment_id, data, is_stream, request, app_config
),
)


async def call_chat_completion(
deployment_id: str, data: dict, is_stream: bool, request: Request
deployment_id: str,
data: dict,
is_stream: bool,
request: Request,
app_config: ApplicationConfig,
):

# Azure OpenAI deployments ignore "model" request field,
@@ -129,56 +168,62 @@ async def call_chat_completion(
creds,
api_version,
deployment_id,
app_config,
)

if deployment_id in DALLE3_DEPLOYMENTS:
if deployment_id in app_config.DALLE3_DEPLOYMENTS:
storage = create_file_storage("images", request.headers)
return await dalle3_chat_completion(
data,
upstream_endpoint,
creds,
is_stream,
storage,
DALLE3_AZURE_API_VERSION,
app_config.DALLE3_AZURE_API_VERSION,
)

if deployment_id in MISTRAL_DEPLOYMENTS:
if deployment_id in app_config.MISTRAL_DEPLOYMENTS:
return await mistral_chat_completion(data, upstream_endpoint, creds)

if deployment_id in DATABRICKS_DEPLOYMENTS:
if deployment_id in app_config.DATABRICKS_DEPLOYMENTS:
return await databricks_chat_completion(data, upstream_endpoint, creds)

text_tokenizer_model = MODEL_ALIASES.get(deployment_id, deployment_id)
text_tokenizer_model = app_config.MODEL_ALIASES.get(
deployment_id, deployment_id
)

if image_tokenizer := get_image_tokenizer(deployment_id):
storage = create_file_storage("images", request.headers)
if deployment_id in app_config.GPT4_VISION_DEPLOYMENTS:
tokenizer = MultiModalTokenizer(
"gpt-4", _get_image_tokenizer(deployment_id, app_config)
)
return await gpt4_vision_chat_completion(
data,
deployment_id,
upstream_endpoint,
creds,
is_stream,
create_file_storage("images", request.headers),
api_version,
tokenizer,
)

if deployment_id in GPT4_VISION_DEPLOYMENTS:
tokenizer = MultiModalTokenizer("gpt-4", image_tokenizer)
return await gpt4_vision_chat_completion(
data,
deployment_id,
upstream_endpoint,
creds,
is_stream,
storage,
api_version,
tokenizer,
)
else:
tokenizer = MultiModalTokenizer(
text_tokenizer_model, image_tokenizer
)
return await gpt4o_chat_completion(
data,
deployment_id,
upstream_endpoint,
creds,
is_stream,
storage,
api_version,
tokenizer,
)
if deployment_id in (
*app_config.GPT4O_DEPLOYMENTS,
*app_config.GPT4O_MINI_DEPLOYMENTS,
):
tokenizer = MultiModalTokenizer(
text_tokenizer_model,
_get_image_tokenizer(deployment_id, app_config),
)
return await gpt4o_chat_completion(
data,
deployment_id,
upstream_endpoint,
creds,
is_stream,
create_file_storage("images", request.headers),
api_version,
tokenizer,
)

tokenizer = PlainTextTokenizer(model=text_tokenizer_model)
return await gpt_chat_completion(
@@ -192,7 +237,11 @@ async def call_chat_completion(


@app.post("/openai/deployments/{deployment_id:path}/embeddings")
async def embedding(deployment_id: str, request: Request):
async def embedding(
deployment_id: str,
request: Request,
app_config: Annotated[ApplicationConfig, Depends(get_app_config)],
):
data = await parse_body(request)

# See note for /chat/completions endpoint
@@ -202,7 +251,7 @@ async def embedding(deployment_id: str, request: Request):
api_version = get_api_version(request)
upstream_endpoint = request.headers["X-UPSTREAM-ENDPOINT"]

if deployment_id in AZURE_AI_VISION_DEPLOYMENTS:
if deployment_id in app_config.AZURE_AI_VISION_DEPLOYMENTS:
storage = create_file_storage("images", request.headers)
return await azure_ai_vision_embeddings(
creds, deployment_id, upstream_endpoint, storage, data
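With the create_app factory, an application (or a test) can presumably be built with an explicit config instead of reading the process environment. A sketch under that assumption; the deployment names and API-version mapping are invented, while create_app, ApplicationConfig, and the two boolean flags come from the diff above.

```python
# Hypothetical test setup: build the adapter app around an explicit config.
from aidial_adapter_openai.app import create_app
from aidial_adapter_openai.app_config import ApplicationConfig

test_config = ApplicationConfig(
    MISTRAL_DEPLOYMENTS=["mistral-large"],       # invented deployment name
    NON_STREAMING_DEPLOYMENTS=["dalle-3"],       # invented deployment name
    API_VERSIONS_MAPPING={"2023-03-15-preview": "2024-02-01"},
)

app = create_app(
    app_config=test_config,
    to_init_telemetry=False,     # skip telemetry wiring in tests
    to_configure_loggers=False,  # keep the test runner's logging config
)

# Routes resolve the config through Depends(get_app_config), which simply
# returns request.app.state.app_config set by create_app.
assert app.state.app_config is test_config
```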
39 changes: 39 additions & 0 deletions aidial_adapter_openai/app_config.py
@@ -0,0 +1,39 @@
from typing import Dict, List

from pydantic import BaseModel, Field

import aidial_adapter_openai.env as env


class ApplicationConfig(BaseModel):
MODEL_ALIASES: Dict[str, str] = Field(default_factory=dict)
DALLE3_DEPLOYMENTS: List[str] = Field(default_factory=list)
GPT4_VISION_DEPLOYMENTS: List[str] = Field(default_factory=list)
MISTRAL_DEPLOYMENTS: List[str] = Field(default_factory=list)
DATABRICKS_DEPLOYMENTS: List[str] = Field(default_factory=list)
GPT4O_DEPLOYMENTS: List[str] = Field(default_factory=list)
GPT4O_MINI_DEPLOYMENTS: List[str] = Field(default_factory=list)
AZURE_AI_VISION_DEPLOYMENTS: List[str] = Field(default_factory=list)
API_VERSIONS_MAPPING: Dict[str, str] = Field(default_factory=dict)
COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES: Dict[str, str] = Field(
default_factory=dict
)
DALLE3_AZURE_API_VERSION: str = Field(default="2024-02-01")
NON_STREAMING_DEPLOYMENTS: List[str] = Field(default_factory=list)

@classmethod
def from_env(cls) -> "ApplicationConfig":
return cls(
MODEL_ALIASES=env.MODEL_ALIASES,
DALLE3_DEPLOYMENTS=env.DALLE3_DEPLOYMENTS,
GPT4_VISION_DEPLOYMENTS=env.GPT4_VISION_DEPLOYMENTS,
MISTRAL_DEPLOYMENTS=env.MISTRAL_DEPLOYMENTS,
DATABRICKS_DEPLOYMENTS=env.DATABRICKS_DEPLOYMENTS,
GPT4O_DEPLOYMENTS=env.GPT4O_DEPLOYMENTS,
GPT4O_MINI_DEPLOYMENTS=env.GPT4O_MINI_DEPLOYMENTS,
AZURE_AI_VISION_DEPLOYMENTS=env.AZURE_AI_VISION_DEPLOYMENTS,
API_VERSIONS_MAPPING=env.API_VERSIONS_MAPPING,
COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES=env.COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES,
DALLE3_AZURE_API_VERSION=env.DALLE3_AZURE_API_VERSION,
NON_STREAMING_DEPLOYMENTS=env.NON_STREAMING_DEPLOYMENTS,
)
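Every field has a default, so an empty config is valid, and from_env() keeps the old environment-driven behaviour by copying values from the existing env module. A small sketch of both construction paths (the GPT-4o deployment name is invented):

```python
# Sketch of the two construction paths defined above.
from aidial_adapter_openai.app_config import ApplicationConfig

# Explicit construction: unset fields fall back to empty dicts/lists,
# and DALLE3_AZURE_API_VERSION to the pinned "2024-02-01".
config = ApplicationConfig(GPT4O_DEPLOYMENTS=["gpt-4o"])
assert config.MODEL_ALIASES == {}
assert config.DALLE3_AZURE_API_VERSION == "2024-02-01"

# Environment-driven construction: values are copied from the existing
# aidial_adapter_openai.env module, preserving the previous behaviour.
env_config = ApplicationConfig.from_env()
```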
7 changes: 5 additions & 2 deletions aidial_adapter_openai/completions.py
@@ -4,7 +4,7 @@
from openai import AsyncStream
from openai.types import Completion

from aidial_adapter_openai.env import COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES
from aidial_adapter_openai.app_config import ApplicationConfig
from aidial_adapter_openai.utils.auth import OpenAICreds
from aidial_adapter_openai.utils.parsers import (
AzureOpenAIEndpoint,
@@ -46,6 +46,7 @@ async def chat_completion(
creds: OpenAICreds,
api_version: str,
deployment_id: str,
app_config: ApplicationConfig,
):

if (data.get("n") or 1) > 1:
@@ -60,7 +61,9 @@
prompt = messages[-1].get("content") or ""

if (
template := COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES.get(deployment_id)
template := app_config.COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES.get(
deployment_id
)
) is not None:
prompt = template.format(prompt=prompt)

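The prompt-template lookup for completion deployments now goes through the config object; the formatting itself is unchanged. An illustrative example with a made-up deployment name and template string:

```python
# Illustrative only: per-deployment prompt templating, mirroring the
# template.format(prompt=...) call above.
from aidial_adapter_openai.app_config import ApplicationConfig

app_config = ApplicationConfig(
    COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES={
        "my-completions-model": "<|user|>{prompt}<|assistant|>",  # invented
    }
)

prompt = "Hello"
if (
    template := app_config.COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES.get(
        "my-completions-model"
    )
) is not None:
    prompt = template.format(prompt=prompt)

print(prompt)  # <|user|>Hello<|assistant|>
```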
Empty file.
24 changes: 12 additions & 12 deletions aidial_adapter_openai/utils/image_tokenizer.py
@@ -8,11 +8,7 @@

from pydantic import BaseModel

from aidial_adapter_openai.env import (
GPT4_VISION_DEPLOYMENTS,
GPT4O_DEPLOYMENTS,
GPT4O_MINI_DEPLOYMENTS,
)
from aidial_adapter_openai.app_config import ApplicationConfig
from aidial_adapter_openai.utils.image import ImageDetail, resolve_detail_level


@@ -58,14 +54,18 @@ def _compute_high_detail_tokens(self, width: int, height: int) -> int:
low_detail_tokens=2833, tokens_per_tile=5667
)

_TOKENIZERS: List[Tuple[ImageTokenizer, List[str]]] = [
(GPT4O_IMAGE_TOKENIZER, GPT4O_DEPLOYMENTS),
(GPT4O_MINI_IMAGE_TOKENIZER, GPT4O_MINI_DEPLOYMENTS),
(GPT4_VISION_IMAGE_TOKENIZER, GPT4_VISION_DEPLOYMENTS),
]


def get_image_tokenizer(deployment_id: str) -> ImageTokenizer | None:
def get_image_tokenizer(
deployment_id: str, app_config: ApplicationConfig
) -> ImageTokenizer | None:
_TOKENIZERS: List[Tuple[ImageTokenizer, List[str]]] = [
(GPT4O_IMAGE_TOKENIZER, app_config.GPT4O_DEPLOYMENTS),
(GPT4O_MINI_IMAGE_TOKENIZER, app_config.GPT4O_MINI_DEPLOYMENTS),
(
GPT4_VISION_IMAGE_TOKENIZER,
app_config.GPT4_VISION_DEPLOYMENTS,
),
]
for tokenizer, ids in _TOKENIZERS:
if deployment_id in ids:
return tokenizer
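Image-tokenizer lookup is likewise driven by the deployment lists carried on the config rather than module-level constants. A small sketch with invented deployment names:

```python
# Hypothetical lookup: the tokenizer is chosen from the deployment lists
# on ApplicationConfig; unknown deployments yield None.
from aidial_adapter_openai.app_config import ApplicationConfig
from aidial_adapter_openai.utils.image_tokenizer import get_image_tokenizer

config = ApplicationConfig(GPT4O_DEPLOYMENTS=["my-gpt-4o"])

assert get_image_tokenizer("my-gpt-4o", config) is not None      # GPT-4o tokenizer
assert get_image_tokenizer("plain-text-model", config) is None   # no image support
```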