Skip to content

Commit 63c39af

Browse files
author
Roman Romanov
committed
Replace usage of env variables with application config in app state
1 parent 2378606 commit 63c39af

File tree

5 files changed

+163
-72
lines changed

5 files changed

+163
-72
lines changed

aidial_adapter_openai/app.py

Lines changed: 107 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from contextlib import asynccontextmanager
2+
from typing import Annotated
23

34
import pydantic
45
from aidial_sdk._errors import pydantic_validation_exception_handler
56
from aidial_sdk.exceptions import HTTPException as DialException
67
from aidial_sdk.exceptions import InvalidRequestError
78
from aidial_sdk.telemetry.init import init_telemetry
89
from aidial_sdk.telemetry.types import TelemetryConfig
9-
from fastapi import FastAPI, Request
10+
from fastapi import Depends, FastAPI, Request
1011
from fastapi.responses import Response
1112
from openai import (
1213
APIConnectionError,
@@ -16,6 +17,7 @@
1617
OpenAIError,
1718
)
1819

20+
from aidial_adapter_openai.app_config import ApplicationConfig
1921
from aidial_adapter_openai.completions import chat_completion as completion
2022
from aidial_adapter_openai.dalle3 import (
2123
chat_completion as dalle3_chat_completion,
@@ -30,17 +32,6 @@
3032
from aidial_adapter_openai.embeddings.openai import (
3133
embeddings as openai_embeddings,
3234
)
33-
from aidial_adapter_openai.env import (
34-
API_VERSIONS_MAPPING,
35-
AZURE_AI_VISION_DEPLOYMENTS,
36-
DALLE3_AZURE_API_VERSION,
37-
DALLE3_DEPLOYMENTS,
38-
DATABRICKS_DEPLOYMENTS,
39-
GPT4_VISION_DEPLOYMENTS,
40-
MISTRAL_DEPLOYMENTS,
41-
MODEL_ALIASES,
42-
NON_STREAMING_DEPLOYMENTS,
43-
)
4435
from aidial_adapter_openai.gpt import gpt_chat_completion
4536
from aidial_adapter_openai.gpt4_multi_modal.chat_completion import (
4637
gpt4_vision_chat_completion,
@@ -51,7 +42,10 @@
5142
)
5243
from aidial_adapter_openai.utils.auth import get_credentials
5344
from aidial_adapter_openai.utils.http_client import get_http_client
54-
from aidial_adapter_openai.utils.image_tokenizer import get_image_tokenizer
45+
from aidial_adapter_openai.utils.image_tokenizer import (
46+
ImageTokenizer,
47+
get_image_tokenizer,
48+
)
5549
from aidial_adapter_openai.utils.log_config import configure_loggers, logger
5650
from aidial_adapter_openai.utils.parsers import completions_parser, parse_body
5751
from aidial_adapter_openai.utils.streaming import create_server_response
@@ -68,43 +62,88 @@ async def lifespan(app: FastAPI):
6862
await get_http_client().aclose()
6963

7064

71-
app = FastAPI(lifespan=lifespan)
65+
def create_app(
66+
app_config: ApplicationConfig | None = None,
67+
to_init_telemetry: bool = True,
68+
to_configure_loggers: bool = True,
69+
) -> FastAPI:
70+
app = FastAPI(lifespan=lifespan)
71+
72+
if app_config is None:
73+
app_config = ApplicationConfig.from_env()
74+
75+
app.state.app_config = app_config
76+
77+
if to_init_telemetry:
78+
init_telemetry(app, TelemetryConfig())
7279

80+
if to_configure_loggers:
81+
configure_loggers()
7382

74-
init_telemetry(app, TelemetryConfig())
75-
configure_loggers()
83+
return app
7684

7785

78-
def get_api_version(request: Request):
86+
def get_app_config(request: Request) -> ApplicationConfig:
    """Return the application config stored on the app that serves *request*.

    The value is read from ``request.app.state.app_config``.
    """
    app_state = request.app.state
    return app_state.app_config
88+
89+
90+
def get_api_version(request: Request) -> str:
7991
api_version = request.query_params.get("api-version", "")
80-
api_version = API_VERSIONS_MAPPING.get(api_version, api_version)
92+
app_config = get_app_config(request)
93+
api_version = app_config.API_VERSIONS_MAPPING.get(api_version, api_version)
8194

8295
if api_version == "":
8396
raise InvalidRequestError("api-version is a required query parameter")
8497

8598
return api_version
8699

87100

101+
def _get_image_tokenizer(
    deployment_id: str, app_config: ApplicationConfig
) -> ImageTokenizer:
    """Return the image tokenizer for *deployment_id*, failing if there is none.

    Wraps ``get_image_tokenizer`` so that callers always receive a tokenizer
    rather than an optional value.

    Raises:
        RuntimeError: if ``get_image_tokenizer`` yields a falsy result for
            this deployment.
    """
    found = get_image_tokenizer(deployment_id, app_config)
    if found:
        return found

    raise RuntimeError(
        f"No image tokenizer found for deployment {deployment_id}"
    )
110+
111+
112+
app = create_app()
113+
114+
88115
@app.post("/openai/deployments/{deployment_id:path}/chat/completions")
89-
async def chat_completion(deployment_id: str, request: Request):
116+
async def chat_completion(
117+
deployment_id: str,
118+
request: Request,
119+
app_config: Annotated[ApplicationConfig, Depends(get_app_config)],
120+
):
90121

91122
data = await parse_body(request)
92123

93124
is_stream = bool(data.get("stream"))
94125

95-
emulate_streaming = deployment_id in NON_STREAMING_DEPLOYMENTS and is_stream
126+
emulate_streaming = (
127+
deployment_id in app_config.NON_STREAMING_DEPLOYMENTS and is_stream
128+
)
96129

97130
if emulate_streaming:
98131
data["stream"] = False
99132

100133
return create_server_response(
101134
emulate_streaming,
102-
await call_chat_completion(deployment_id, data, is_stream, request),
135+
await call_chat_completion(
136+
deployment_id, data, is_stream, request, app_config
137+
),
103138
)
104139

105140

106141
async def call_chat_completion(
107-
deployment_id: str, data: dict, is_stream: bool, request: Request
142+
deployment_id: str,
143+
data: dict,
144+
is_stream: bool,
145+
request: Request,
146+
app_config: ApplicationConfig,
108147
):
109148

110149
# Azure OpenAI deployments ignore "model" request field,
@@ -129,56 +168,62 @@ async def call_chat_completion(
129168
creds,
130169
api_version,
131170
deployment_id,
171+
app_config,
132172
)
133-
134-
if deployment_id in DALLE3_DEPLOYMENTS:
173+
if deployment_id in app_config.DALLE3_DEPLOYMENTS:
135174
storage = create_file_storage("images", request.headers)
136175
return await dalle3_chat_completion(
137176
data,
138177
upstream_endpoint,
139178
creds,
140179
is_stream,
141180
storage,
142-
DALLE3_AZURE_API_VERSION,
181+
app_config.DALLE3_AZURE_API_VERSION,
143182
)
144183

145-
if deployment_id in MISTRAL_DEPLOYMENTS:
184+
if deployment_id in app_config.MISTRAL_DEPLOYMENTS:
146185
return await mistral_chat_completion(data, upstream_endpoint, creds)
147186

148-
if deployment_id in DATABRICKS_DEPLOYMENTS:
187+
if deployment_id in app_config.DATABRICKS_DEPLOYMENTS:
149188
return await databricks_chat_completion(data, upstream_endpoint, creds)
150189

151-
text_tokenizer_model = MODEL_ALIASES.get(deployment_id, deployment_id)
190+
text_tokenizer_model = app_config.MODEL_ALIASES.get(
191+
deployment_id, deployment_id
192+
)
152193

153-
if image_tokenizer := get_image_tokenizer(deployment_id):
154-
storage = create_file_storage("images", request.headers)
194+
if deployment_id in app_config.GPT4_VISION_DEPLOYMENTS:
195+
tokenizer = MultiModalTokenizer(
196+
"gpt-4", _get_image_tokenizer(deployment_id, app_config)
197+
)
198+
return await gpt4_vision_chat_completion(
199+
data,
200+
deployment_id,
201+
upstream_endpoint,
202+
creds,
203+
is_stream,
204+
create_file_storage("images", request.headers),
205+
api_version,
206+
tokenizer,
207+
)
155208

156-
if deployment_id in GPT4_VISION_DEPLOYMENTS:
157-
tokenizer = MultiModalTokenizer("gpt-4", image_tokenizer)
158-
return await gpt4_vision_chat_completion(
159-
data,
160-
deployment_id,
161-
upstream_endpoint,
162-
creds,
163-
is_stream,
164-
storage,
165-
api_version,
166-
tokenizer,
167-
)
168-
else:
169-
tokenizer = MultiModalTokenizer(
170-
text_tokenizer_model, image_tokenizer
171-
)
172-
return await gpt4o_chat_completion(
173-
data,
174-
deployment_id,
175-
upstream_endpoint,
176-
creds,
177-
is_stream,
178-
storage,
179-
api_version,
180-
tokenizer,
181-
)
209+
if deployment_id in (
210+
*app_config.GPT4O_DEPLOYMENTS,
211+
*app_config.GPT4O_MINI_DEPLOYMENTS,
212+
):
213+
tokenizer = MultiModalTokenizer(
214+
text_tokenizer_model,
215+
_get_image_tokenizer(deployment_id, app_config),
216+
)
217+
return await gpt4o_chat_completion(
218+
data,
219+
deployment_id,
220+
upstream_endpoint,
221+
creds,
222+
is_stream,
223+
create_file_storage("images", request.headers),
224+
api_version,
225+
tokenizer,
226+
)
182227

183228
tokenizer = PlainTextTokenizer(model=text_tokenizer_model)
184229
return await gpt_chat_completion(
@@ -192,7 +237,11 @@ async def call_chat_completion(
192237

193238

194239
@app.post("/openai/deployments/{deployment_id:path}/embeddings")
195-
async def embedding(deployment_id: str, request: Request):
240+
async def embedding(
241+
deployment_id: str,
242+
request: Request,
243+
app_config: Annotated[ApplicationConfig, Depends(get_app_config)],
244+
):
196245
data = await parse_body(request)
197246

198247
# See note for /chat/completions endpoint
@@ -202,7 +251,7 @@ async def embedding(deployment_id: str, request: Request):
202251
api_version = get_api_version(request)
203252
upstream_endpoint = request.headers["X-UPSTREAM-ENDPOINT"]
204253

205-
if deployment_id in AZURE_AI_VISION_DEPLOYMENTS:
254+
if deployment_id in app_config.AZURE_AI_VISION_DEPLOYMENTS:
206255
storage = create_file_storage("images", request.headers)
207256
return await azure_ai_vision_embeddings(
208257
creds, deployment_id, upstream_endpoint, storage, data

aidial_adapter_openai/app_config.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from typing import Dict, List
2+
3+
from pydantic import BaseModel, Field
4+
5+
import aidial_adapter_openai.env as env
6+
7+
8+
class ApplicationConfig(BaseModel):
    """Adapter settings gathered into a single pydantic model.

    Every field has a default (an empty container, or a fixed version
    string), so an instance can be built from any subset of the settings.
    """

    MODEL_ALIASES: Dict[str, str] = Field(default_factory=dict)
    DALLE3_DEPLOYMENTS: List[str] = Field(default_factory=list)
    GPT4_VISION_DEPLOYMENTS: List[str] = Field(default_factory=list)
    MISTRAL_DEPLOYMENTS: List[str] = Field(default_factory=list)
    DATABRICKS_DEPLOYMENTS: List[str] = Field(default_factory=list)
    GPT4O_DEPLOYMENTS: List[str] = Field(default_factory=list)
    GPT4O_MINI_DEPLOYMENTS: List[str] = Field(default_factory=list)
    AZURE_AI_VISION_DEPLOYMENTS: List[str] = Field(default_factory=list)
    API_VERSIONS_MAPPING: Dict[str, str] = Field(default_factory=dict)
    COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES: Dict[str, str] = Field(
        default_factory=dict
    )
    DALLE3_AZURE_API_VERSION: str = "2024-02-01"
    NON_STREAMING_DEPLOYMENTS: List[str] = Field(default_factory=list)

    @classmethod
    def from_env(cls) -> "ApplicationConfig":
        """Build a config from the same-named attributes of the ``env`` module."""
        # Each setting is copied from the ``env`` attribute of the same name.
        setting_names = (
            "MODEL_ALIASES",
            "DALLE3_DEPLOYMENTS",
            "GPT4_VISION_DEPLOYMENTS",
            "MISTRAL_DEPLOYMENTS",
            "DATABRICKS_DEPLOYMENTS",
            "GPT4O_DEPLOYMENTS",
            "GPT4O_MINI_DEPLOYMENTS",
            "AZURE_AI_VISION_DEPLOYMENTS",
            "API_VERSIONS_MAPPING",
            "COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES",
            "DALLE3_AZURE_API_VERSION",
            "NON_STREAMING_DEPLOYMENTS",
        )
        return cls(**{name: getattr(env, name) for name in setting_names})

aidial_adapter_openai/completions.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from openai import AsyncStream
55
from openai.types import Completion
66

7-
from aidial_adapter_openai.env import COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES
7+
from aidial_adapter_openai.app_config import ApplicationConfig
88
from aidial_adapter_openai.utils.auth import OpenAICreds
99
from aidial_adapter_openai.utils.parsers import (
1010
AzureOpenAIEndpoint,
@@ -46,6 +46,7 @@ async def chat_completion(
4646
creds: OpenAICreds,
4747
api_version: str,
4848
deployment_id: str,
49+
app_config: ApplicationConfig,
4950
):
5051

5152
if data.get("n") or 1 > 1:
@@ -60,7 +61,9 @@ async def chat_completion(
6061
prompt = messages[-1].get("content") or ""
6162

6263
if (
63-
template := COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES.get(deployment_id)
64+
template := app_config.COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES.get(
65+
deployment_id
66+
)
6467
) is not None:
6568
prompt = template.format(prompt=prompt)
6669

aidial_adapter_openai/constant.py

Whitespace-only changes.

aidial_adapter_openai/utils/image_tokenizer.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,7 @@
88

99
from pydantic import BaseModel
1010

11-
from aidial_adapter_openai.env import (
12-
GPT4_VISION_DEPLOYMENTS,
13-
GPT4O_DEPLOYMENTS,
14-
GPT4O_MINI_DEPLOYMENTS,
15-
)
11+
from aidial_adapter_openai.app_config import ApplicationConfig
1612
from aidial_adapter_openai.utils.image import ImageDetail, resolve_detail_level
1713

1814

@@ -58,14 +54,18 @@ def _compute_high_detail_tokens(self, width: int, height: int) -> int:
5854
low_detail_tokens=2833, tokens_per_tile=5667
5955
)
6056

61-
_TOKENIZERS: List[Tuple[ImageTokenizer, List[str]]] = [
62-
(GPT4O_IMAGE_TOKENIZER, GPT4O_DEPLOYMENTS),
63-
(GPT4O_MINI_IMAGE_TOKENIZER, GPT4O_MINI_DEPLOYMENTS),
64-
(GPT4_VISION_IMAGE_TOKENIZER, GPT4_VISION_DEPLOYMENTS),
65-
]
66-
6757

68-
def get_image_tokenizer(deployment_id: str) -> ImageTokenizer | None:
58+
def get_image_tokenizer(
59+
deployment_id: str, app_config: ApplicationConfig
60+
) -> ImageTokenizer | None:
61+
_TOKENIZERS: List[Tuple[ImageTokenizer, List[str]]] = [
62+
(GPT4O_IMAGE_TOKENIZER, app_config.GPT4O_DEPLOYMENTS),
63+
(GPT4O_MINI_IMAGE_TOKENIZER, app_config.GPT4O_MINI_DEPLOYMENTS),
64+
(
65+
GPT4_VISION_IMAGE_TOKENIZER,
66+
app_config.GPT4_VISION_DEPLOYMENTS,
67+
),
68+
]
6969
for tokenizer, ids in _TOKENIZERS:
7070
if deployment_id in ids:
7171
return tokenizer

0 commit comments

Comments
 (0)