from contextlib import asynccontextmanager
+from typing import Annotated

import pydantic
from aidial_sdk._errors import pydantic_validation_exception_handler
from aidial_sdk.exceptions import HTTPException as DialException
from aidial_sdk.exceptions import InvalidRequestError
from aidial_sdk.telemetry.init import init_telemetry
from aidial_sdk.telemetry.types import TelemetryConfig
-from fastapi import FastAPI, Request
+from fastapi import Depends, FastAPI, Request
from fastapi.responses import Response
from openai import (
    APIConnectionError,
    OpenAIError,
)

+from aidial_adapter_openai.app_config import ApplicationConfig
from aidial_adapter_openai.completions import chat_completion as completion
from aidial_adapter_openai.dalle3 import (
    chat_completion as dalle3_chat_completion,
from aidial_adapter_openai.embeddings.openai import (
    embeddings as openai_embeddings,
)
-from aidial_adapter_openai.env import (
-    API_VERSIONS_MAPPING,
-    AZURE_AI_VISION_DEPLOYMENTS,
-    DALLE3_AZURE_API_VERSION,
-    DALLE3_DEPLOYMENTS,
-    DATABRICKS_DEPLOYMENTS,
-    GPT4_VISION_DEPLOYMENTS,
-    MISTRAL_DEPLOYMENTS,
-    MODEL_ALIASES,
-    NON_STREAMING_DEPLOYMENTS,
-)
from aidial_adapter_openai.gpt import gpt_chat_completion
from aidial_adapter_openai.gpt4_multi_modal.chat_completion import (
    gpt4_vision_chat_completion,
)
from aidial_adapter_openai.utils.auth import get_credentials
from aidial_adapter_openai.utils.http_client import get_http_client
-from aidial_adapter_openai.utils.image_tokenizer import get_image_tokenizer
+from aidial_adapter_openai.utils.image_tokenizer import (
+    ImageTokenizer,
+    get_image_tokenizer,
+)
from aidial_adapter_openai.utils.log_config import configure_loggers, logger
from aidial_adapter_openai.utils.parsers import completions_parser, parse_body
from aidial_adapter_openai.utils.streaming import create_server_response
@@ -68,43 +62,88 @@ async def lifespan(app: FastAPI):
    await get_http_client().aclose()


-app = FastAPI(lifespan=lifespan)
+def create_app(
+    app_config: ApplicationConfig | None = None,
+    to_init_telemetry: bool = True,
+    to_configure_loggers: bool = True,
+) -> FastAPI:
+    app = FastAPI(lifespan=lifespan)
+
+    if app_config is None:
+        app_config = ApplicationConfig.from_env()
+
+    app.state.app_config = app_config
+
+    if to_init_telemetry:
+        init_telemetry(app, TelemetryConfig())

+    if to_configure_loggers:
+        configure_loggers()

-init_telemetry(app, TelemetryConfig())
-configure_loggers()
+    return app


-def get_api_version(request: Request):
+def get_app_config(request: Request) -> ApplicationConfig:
+    return request.app.state.app_config
+
+
+def get_api_version(request: Request) -> str:
    api_version = request.query_params.get("api-version", "")
-    api_version = API_VERSIONS_MAPPING.get(api_version, api_version)
+    app_config = get_app_config(request)
+    api_version = app_config.API_VERSIONS_MAPPING.get(api_version, api_version)

    if api_version == "":
        raise InvalidRequestError("api-version is a required query parameter")

    return api_version


+def _get_image_tokenizer(
+    deployment_id: str, app_config: ApplicationConfig
+) -> ImageTokenizer:
+    image_tokenizer = get_image_tokenizer(deployment_id, app_config)
+    if not image_tokenizer:
+        raise RuntimeError(
+            f"No image tokenizer found for deployment {deployment_id}"
+        )
+    return image_tokenizer
+
+
+app = create_app()
+
+
@app.post("/openai/deployments/{deployment_id:path}/chat/completions")
-async def chat_completion(deployment_id: str, request: Request):
+async def chat_completion(
+    deployment_id: str,
+    request: Request,
+    app_config: Annotated[ApplicationConfig, Depends(get_app_config)],
+):

    data = await parse_body(request)

    is_stream = bool(data.get("stream"))

-    emulate_streaming = deployment_id in NON_STREAMING_DEPLOYMENTS and is_stream
+    emulate_streaming = (
+        deployment_id in app_config.NON_STREAMING_DEPLOYMENTS and is_stream
+    )

    if emulate_streaming:
        data["stream"] = False

    return create_server_response(
        emulate_streaming,
-        await call_chat_completion(deployment_id, data, is_stream, request),
+        await call_chat_completion(
+            deployment_id, data, is_stream, request, app_config
+        ),
    )


async def call_chat_completion(
-    deployment_id: str, data: dict, is_stream: bool, request: Request
+    deployment_id: str,
+    data: dict,
+    is_stream: bool,
+    request: Request,
+    app_config: ApplicationConfig,
):

    # Azure OpenAI deployments ignore "model" request field,
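Outside the diff itself: the `create_app` factory above is what makes per-instance configuration possible, since a config object can now be passed in rather than read from process-wide environment variables. A minimal usage sketch for tests, assuming `ApplicationConfig` accepts its deployment fields as keyword arguments and that the module path is `aidial_adapter_openai.app` (both assumptions; the deployment values below are placeholders):

from fastapi.testclient import TestClient

from aidial_adapter_openai.app import create_app
from aidial_adapter_openai.app_config import ApplicationConfig

# Hypothetical config: route "dall-e-3" through the DALL-E branch and
# emulate streaming for it, without touching environment variables.
test_config = ApplicationConfig(
    DALLE3_DEPLOYMENTS=["dall-e-3"],
    NON_STREAMING_DEPLOYMENTS=["dall-e-3"],
)

# Skip telemetry and logger setup so the test app stays self-contained.
client = TestClient(
    create_app(
        app_config=test_config,
        to_init_telemetry=False,
        to_configure_loggers=False,
    )
)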
@@ -129,56 +168,62 @@ async def call_chat_completion(
            creds,
            api_version,
            deployment_id,
+            app_config,
        )
-
-    if deployment_id in DALLE3_DEPLOYMENTS:
+    if deployment_id in app_config.DALLE3_DEPLOYMENTS:
        storage = create_file_storage("images", request.headers)
        return await dalle3_chat_completion(
            data,
            upstream_endpoint,
            creds,
            is_stream,
            storage,
-            DALLE3_AZURE_API_VERSION,
+            app_config.DALLE3_AZURE_API_VERSION,
        )

-    if deployment_id in MISTRAL_DEPLOYMENTS:
+    if deployment_id in app_config.MISTRAL_DEPLOYMENTS:
        return await mistral_chat_completion(data, upstream_endpoint, creds)

-    if deployment_id in DATABRICKS_DEPLOYMENTS:
+    if deployment_id in app_config.DATABRICKS_DEPLOYMENTS:
        return await databricks_chat_completion(data, upstream_endpoint, creds)

-    text_tokenizer_model = MODEL_ALIASES.get(deployment_id, deployment_id)
+    text_tokenizer_model = app_config.MODEL_ALIASES.get(
+        deployment_id, deployment_id
+    )

-    if image_tokenizer := get_image_tokenizer(deployment_id):
-        storage = create_file_storage("images", request.headers)
+    if deployment_id in app_config.GPT4_VISION_DEPLOYMENTS:
+        tokenizer = MultiModalTokenizer(
+            "gpt-4", _get_image_tokenizer(deployment_id, app_config)
+        )
+        return await gpt4_vision_chat_completion(
+            data,
+            deployment_id,
+            upstream_endpoint,
+            creds,
+            is_stream,
+            create_file_storage("images", request.headers),
+            api_version,
+            tokenizer,
+        )

-        if deployment_id in GPT4_VISION_DEPLOYMENTS:
-            tokenizer = MultiModalTokenizer("gpt-4", image_tokenizer)
-            return await gpt4_vision_chat_completion(
-                data,
-                deployment_id,
-                upstream_endpoint,
-                creds,
-                is_stream,
-                storage,
-                api_version,
-                tokenizer,
-            )
-        else:
-            tokenizer = MultiModalTokenizer(
-                text_tokenizer_model, image_tokenizer
-            )
-            return await gpt4o_chat_completion(
-                data,
-                deployment_id,
-                upstream_endpoint,
-                creds,
-                is_stream,
-                storage,
-                api_version,
-                tokenizer,
-            )
+    if deployment_id in (
+        *app_config.GPT4O_DEPLOYMENTS,
+        *app_config.GPT4O_MINI_DEPLOYMENTS,
+    ):
+        tokenizer = MultiModalTokenizer(
+            text_tokenizer_model,
+            _get_image_tokenizer(deployment_id, app_config),
+        )
+        return await gpt4o_chat_completion(
+            data,
+            deployment_id,
+            upstream_endpoint,
+            creds,
+            is_stream,
+            create_file_storage("images", request.headers),
+            api_version,
+            tokenizer,
+        )

    tokenizer = PlainTextTokenizer(model=text_tokenizer_model)
    return await gpt_chat_completion(
@@ -192,7 +237,11 @@ async def call_chat_completion(


@app.post("/openai/deployments/{deployment_id:path}/embeddings")
-async def embedding(deployment_id: str, request: Request):
+async def embedding(
+    deployment_id: str,
+    request: Request,
+    app_config: Annotated[ApplicationConfig, Depends(get_app_config)],
+):
    data = await parse_body(request)

    # See note for /chat/completions endpoint
@@ -202,7 +251,7 @@ async def embedding(deployment_id: str, request: Request):
    api_version = get_api_version(request)
    upstream_endpoint = request.headers["X-UPSTREAM-ENDPOINT"]

-    if deployment_id in AZURE_AI_VISION_DEPLOYMENTS:
+    if deployment_id in app_config.AZURE_AI_VISION_DEPLOYMENTS:
        storage = create_file_storage("images", request.headers)
        return await azure_ai_vision_embeddings(
            creds, deployment_id, upstream_endpoint, storage, data
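For context on the `Annotated[ApplicationConfig, Depends(get_app_config)]` parameters used by the handlers: this is the standard FastAPI pattern of attaching an object to `app.state` when the app is created and resolving it per request through a dependency. A stripped-down sketch of the mechanism, independent of the adapter (all names here are illustrative):

from typing import Annotated

from fastapi import Depends, FastAPI, Request


class Config:
    def __init__(self, greeting: str):
        self.greeting = greeting


def get_config(request: Request) -> Config:
    # The dependency only reads back whatever object was attached to app.state.
    return request.app.state.config


app = FastAPI()
app.state.config = Config(greeting="hello")


@app.get("/greet")
async def greet(config: Annotated[Config, Depends(get_config)]):
    # FastAPI calls get_config(request) for each request and injects the result.
    return {"greeting": config.greeting}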