Skip to content

Commit

Permalink
Replace usage of env variables with application config in app state
Browse files Browse the repository at this point in the history
  • Loading branch information
roman-romanov-o committed Nov 26, 2024
1 parent 2378606 commit 63c39af
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 72 deletions.
165 changes: 107 additions & 58 deletions aidial_adapter_openai/app.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from contextlib import asynccontextmanager
from typing import Annotated

import pydantic
from aidial_sdk._errors import pydantic_validation_exception_handler
from aidial_sdk.exceptions import HTTPException as DialException
from aidial_sdk.exceptions import InvalidRequestError
from aidial_sdk.telemetry.init import init_telemetry
from aidial_sdk.telemetry.types import TelemetryConfig
from fastapi import FastAPI, Request
from fastapi import Depends, FastAPI, Request
from fastapi.responses import Response
from openai import (
APIConnectionError,
Expand All @@ -16,6 +17,7 @@
OpenAIError,
)

from aidial_adapter_openai.app_config import ApplicationConfig
from aidial_adapter_openai.completions import chat_completion as completion
from aidial_adapter_openai.dalle3 import (
chat_completion as dalle3_chat_completion,
Expand All @@ -30,17 +32,6 @@
from aidial_adapter_openai.embeddings.openai import (
embeddings as openai_embeddings,
)
from aidial_adapter_openai.env import (
API_VERSIONS_MAPPING,
AZURE_AI_VISION_DEPLOYMENTS,
DALLE3_AZURE_API_VERSION,
DALLE3_DEPLOYMENTS,
DATABRICKS_DEPLOYMENTS,
GPT4_VISION_DEPLOYMENTS,
MISTRAL_DEPLOYMENTS,
MODEL_ALIASES,
NON_STREAMING_DEPLOYMENTS,
)
from aidial_adapter_openai.gpt import gpt_chat_completion
from aidial_adapter_openai.gpt4_multi_modal.chat_completion import (
gpt4_vision_chat_completion,
Expand All @@ -51,7 +42,10 @@
)
from aidial_adapter_openai.utils.auth import get_credentials
from aidial_adapter_openai.utils.http_client import get_http_client
from aidial_adapter_openai.utils.image_tokenizer import get_image_tokenizer
from aidial_adapter_openai.utils.image_tokenizer import (
ImageTokenizer,
get_image_tokenizer,
)
from aidial_adapter_openai.utils.log_config import configure_loggers, logger
from aidial_adapter_openai.utils.parsers import completions_parser, parse_body
from aidial_adapter_openai.utils.streaming import create_server_response
Expand All @@ -68,43 +62,88 @@ async def lifespan(app: FastAPI):
await get_http_client().aclose()


app = FastAPI(lifespan=lifespan)
def create_app(
app_config: ApplicationConfig | None = None,
to_init_telemetry: bool = True,
to_configure_loggers: bool = True,
) -> FastAPI:
app = FastAPI(lifespan=lifespan)

if app_config is None:
app_config = ApplicationConfig.from_env()

app.state.app_config = app_config

if to_init_telemetry:
init_telemetry(app, TelemetryConfig())

if to_configure_loggers:
configure_loggers()

init_telemetry(app, TelemetryConfig())
configure_loggers()
return app


def get_api_version(request: Request):
def get_app_config(request: Request) -> ApplicationConfig:
    """Return the ApplicationConfig stashed on the FastAPI app state.

    The config is placed there once by ``create_app``; this accessor exists
    so route handlers can receive it via ``Depends``.
    """
    state = request.app.state
    return state.app_config


def get_api_version(request: Request) -> str:
    """Resolve the effective upstream API version for the request.

    Reads the ``api-version`` query parameter and maps it through the
    application's ``API_VERSIONS_MAPPING`` (unknown values pass through
    unchanged).

    Raises:
        InvalidRequestError: if the query parameter is missing or empty.
    """
    api_version = request.query_params.get("api-version", "")
    # Look the alias up on the per-app config, not a module-level global:
    # the stale global lookup was removed along with the env-based imports.
    app_config = get_app_config(request)
    api_version = app_config.API_VERSIONS_MAPPING.get(api_version, api_version)

    if api_version == "":
        raise InvalidRequestError("api-version is a required query parameter")

    return api_version


def _get_image_tokenizer(
    deployment_id: str, app_config: ApplicationConfig
) -> ImageTokenizer:
    """Look up the image tokenizer for a deployment, failing loudly.

    Thin wrapper over ``get_image_tokenizer`` that converts the "no
    tokenizer registered" case into an error instead of returning ``None``.

    Raises:
        RuntimeError: when the deployment has no image tokenizer configured.
    """
    tokenizer = get_image_tokenizer(deployment_id, app_config)
    if not tokenizer:
        raise RuntimeError(
            f"No image tokenizer found for deployment {deployment_id}"
        )
    return tokenizer


app = create_app()


@app.post("/openai/deployments/{deployment_id:path}/chat/completions")
async def chat_completion(deployment_id: str, request: Request):
async def chat_completion(
deployment_id: str,
request: Request,
app_config: Annotated[ApplicationConfig, Depends(get_app_config)],
):

data = await parse_body(request)

is_stream = bool(data.get("stream"))

emulate_streaming = deployment_id in NON_STREAMING_DEPLOYMENTS and is_stream
emulate_streaming = (
deployment_id in app_config.NON_STREAMING_DEPLOYMENTS and is_stream
)

if emulate_streaming:
data["stream"] = False

return create_server_response(
emulate_streaming,
await call_chat_completion(deployment_id, data, is_stream, request),
await call_chat_completion(
deployment_id, data, is_stream, request, app_config
),
)


async def call_chat_completion(
deployment_id: str, data: dict, is_stream: bool, request: Request
deployment_id: str,
data: dict,
is_stream: bool,
request: Request,
app_config: ApplicationConfig,
):

# Azure OpenAI deployments ignore "model" request field,
Expand All @@ -129,56 +168,62 @@ async def call_chat_completion(
creds,
api_version,
deployment_id,
app_config,
)

if deployment_id in DALLE3_DEPLOYMENTS:
if deployment_id in app_config.DALLE3_DEPLOYMENTS:
storage = create_file_storage("images", request.headers)
return await dalle3_chat_completion(
data,
upstream_endpoint,
creds,
is_stream,
storage,
DALLE3_AZURE_API_VERSION,
app_config.DALLE3_AZURE_API_VERSION,
)

if deployment_id in MISTRAL_DEPLOYMENTS:
if deployment_id in app_config.MISTRAL_DEPLOYMENTS:
return await mistral_chat_completion(data, upstream_endpoint, creds)

if deployment_id in DATABRICKS_DEPLOYMENTS:
if deployment_id in app_config.DATABRICKS_DEPLOYMENTS:
return await databricks_chat_completion(data, upstream_endpoint, creds)

text_tokenizer_model = MODEL_ALIASES.get(deployment_id, deployment_id)
text_tokenizer_model = app_config.MODEL_ALIASES.get(
deployment_id, deployment_id
)

if image_tokenizer := get_image_tokenizer(deployment_id):
storage = create_file_storage("images", request.headers)
if deployment_id in app_config.GPT4_VISION_DEPLOYMENTS:
tokenizer = MultiModalTokenizer(
"gpt-4", _get_image_tokenizer(deployment_id, app_config)
)
return await gpt4_vision_chat_completion(
data,
deployment_id,
upstream_endpoint,
creds,
is_stream,
create_file_storage("images", request.headers),
api_version,
tokenizer,
)

if deployment_id in GPT4_VISION_DEPLOYMENTS:
tokenizer = MultiModalTokenizer("gpt-4", image_tokenizer)
return await gpt4_vision_chat_completion(
data,
deployment_id,
upstream_endpoint,
creds,
is_stream,
storage,
api_version,
tokenizer,
)
else:
tokenizer = MultiModalTokenizer(
text_tokenizer_model, image_tokenizer
)
return await gpt4o_chat_completion(
data,
deployment_id,
upstream_endpoint,
creds,
is_stream,
storage,
api_version,
tokenizer,
)
if deployment_id in (
*app_config.GPT4O_DEPLOYMENTS,
*app_config.GPT4O_MINI_DEPLOYMENTS,
):
tokenizer = MultiModalTokenizer(
text_tokenizer_model,
_get_image_tokenizer(deployment_id, app_config),
)
return await gpt4o_chat_completion(
data,
deployment_id,
upstream_endpoint,
creds,
is_stream,
create_file_storage("images", request.headers),
api_version,
tokenizer,
)

tokenizer = PlainTextTokenizer(model=text_tokenizer_model)
return await gpt_chat_completion(
Expand All @@ -192,7 +237,11 @@ async def call_chat_completion(


@app.post("/openai/deployments/{deployment_id:path}/embeddings")
async def embedding(deployment_id: str, request: Request):
async def embedding(
deployment_id: str,
request: Request,
app_config: Annotated[ApplicationConfig, Depends(get_app_config)],
):
data = await parse_body(request)

# See note for /chat/completions endpoint
Expand All @@ -202,7 +251,7 @@ async def embedding(deployment_id: str, request: Request):
api_version = get_api_version(request)
upstream_endpoint = request.headers["X-UPSTREAM-ENDPOINT"]

if deployment_id in AZURE_AI_VISION_DEPLOYMENTS:
if deployment_id in app_config.AZURE_AI_VISION_DEPLOYMENTS:
storage = create_file_storage("images", request.headers)
return await azure_ai_vision_embeddings(
creds, deployment_id, upstream_endpoint, storage, data
Expand Down
39 changes: 39 additions & 0 deletions aidial_adapter_openai/app_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import Dict, List

from pydantic import BaseModel, Field

import aidial_adapter_openai.env as env


class ApplicationConfig(BaseModel):
    """Application-level settings, decoupled from the process environment.

    Field names deliberately mirror the environment variables they were
    historically read from; ``from_env`` reproduces the old env-based
    behavior, while tests may construct instances directly.
    """

    # Deployment-id → tokenizer-model aliases and deployment routing lists.
    MODEL_ALIASES: Dict[str, str] = Field(default_factory=dict)
    DALLE3_DEPLOYMENTS: List[str] = Field(default_factory=list)
    GPT4_VISION_DEPLOYMENTS: List[str] = Field(default_factory=list)
    MISTRAL_DEPLOYMENTS: List[str] = Field(default_factory=list)
    DATABRICKS_DEPLOYMENTS: List[str] = Field(default_factory=list)
    GPT4O_DEPLOYMENTS: List[str] = Field(default_factory=list)
    GPT4O_MINI_DEPLOYMENTS: List[str] = Field(default_factory=list)
    AZURE_AI_VISION_DEPLOYMENTS: List[str] = Field(default_factory=list)
    # Alias → concrete upstream api-version mapping.
    API_VERSIONS_MAPPING: Dict[str, str] = Field(default_factory=dict)
    COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES: Dict[str, str] = Field(
        default_factory=dict
    )
    DALLE3_AZURE_API_VERSION: str = Field(default="2024-02-01")
    NON_STREAMING_DEPLOYMENTS: List[str] = Field(default_factory=list)

    @classmethod
    def from_env(cls) -> "ApplicationConfig":
        """Snapshot the legacy env-module values into a config instance.

        Each field shares its name with the corresponding attribute of
        ``aidial_adapter_openai.env``, so the copy is a simple getattr walk.
        """
        field_names = (
            "MODEL_ALIASES",
            "DALLE3_DEPLOYMENTS",
            "GPT4_VISION_DEPLOYMENTS",
            "MISTRAL_DEPLOYMENTS",
            "DATABRICKS_DEPLOYMENTS",
            "GPT4O_DEPLOYMENTS",
            "GPT4O_MINI_DEPLOYMENTS",
            "AZURE_AI_VISION_DEPLOYMENTS",
            "API_VERSIONS_MAPPING",
            "COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES",
            "DALLE3_AZURE_API_VERSION",
            "NON_STREAMING_DEPLOYMENTS",
        )
        return cls(**{name: getattr(env, name) for name in field_names})
7 changes: 5 additions & 2 deletions aidial_adapter_openai/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from openai import AsyncStream
from openai.types import Completion

from aidial_adapter_openai.env import COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES
from aidial_adapter_openai.app_config import ApplicationConfig
from aidial_adapter_openai.utils.auth import OpenAICreds
from aidial_adapter_openai.utils.parsers import (
AzureOpenAIEndpoint,
Expand Down Expand Up @@ -46,6 +46,7 @@ async def chat_completion(
creds: OpenAICreds,
api_version: str,
deployment_id: str,
app_config: ApplicationConfig,
):

if data.get("n") or 1 > 1:
Expand All @@ -60,7 +61,9 @@ async def chat_completion(
prompt = messages[-1].get("content") or ""

if (
template := COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES.get(deployment_id)
template := app_config.COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES.get(
deployment_id
)
) is not None:
prompt = template.format(prompt=prompt)

Expand Down
Empty file removed aidial_adapter_openai/constant.py
Empty file.
24 changes: 12 additions & 12 deletions aidial_adapter_openai/utils/image_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,7 @@

from pydantic import BaseModel

from aidial_adapter_openai.env import (
GPT4_VISION_DEPLOYMENTS,
GPT4O_DEPLOYMENTS,
GPT4O_MINI_DEPLOYMENTS,
)
from aidial_adapter_openai.app_config import ApplicationConfig
from aidial_adapter_openai.utils.image import ImageDetail, resolve_detail_level


Expand Down Expand Up @@ -58,14 +54,18 @@ def _compute_high_detail_tokens(self, width: int, height: int) -> int:
low_detail_tokens=2833, tokens_per_tile=5667
)

_TOKENIZERS: List[Tuple[ImageTokenizer, List[str]]] = [
(GPT4O_IMAGE_TOKENIZER, GPT4O_DEPLOYMENTS),
(GPT4O_MINI_IMAGE_TOKENIZER, GPT4O_MINI_DEPLOYMENTS),
(GPT4_VISION_IMAGE_TOKENIZER, GPT4_VISION_DEPLOYMENTS),
]


def get_image_tokenizer(deployment_id: str) -> ImageTokenizer | None:
def get_image_tokenizer(
deployment_id: str, app_config: ApplicationConfig
) -> ImageTokenizer | None:
_TOKENIZERS: List[Tuple[ImageTokenizer, List[str]]] = [
(GPT4O_IMAGE_TOKENIZER, app_config.GPT4O_DEPLOYMENTS),
(GPT4O_MINI_IMAGE_TOKENIZER, app_config.GPT4O_MINI_DEPLOYMENTS),
(
GPT4_VISION_IMAGE_TOKENIZER,
app_config.GPT4_VISION_DEPLOYMENTS,
),
]
for tokenizer, ids in _TOKENIZERS:
if deployment_id in ids:
return tokenizer
Expand Down

0 comments on commit 63c39af

Please sign in to comment.