Skip to content

Commit 9d76934

Browse files
authored
feat: support DALL-E 3 (#26) (#30)
1 parent 0c31697 commit 9d76934

File tree

6 files changed

+282
-6
lines changed

6 files changed

+282
-6
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ Copy `.env.example` to `.env` and customize it for your environment:
5757
|WEB_CONCURRENCY|1|Number of workers for the server|
5858
|AZURE_API_VERSION|2023-03-15-preview|The API version for requests to the Azure OpenAI API|
5959
|MODEL_ALIASES|{}|Mapping request's deployment_id to [model name of tiktoken](https://github.com/openai/tiktoken/blob/main/tiktoken/model.py) for the correct calculation of tokens. Example: `{"gpt-35-turbo":"gpt-3.5-turbo-0301"}`|
60+
|DIAL_USE_FILE_STORAGE|False|Save image model artifacts to DIAL File storage (DALL-E images are uploaded to the file storage and their base64 encodings are replaced with links to the storage)|
61+
|DIAL_URL||URL of the core DIAL server (required when DIAL_USE_FILE_STORAGE=True)|
62+
|DIAL_API_KEY||API Key for DIAL File storage (required when DIAL_USE_FILE_STORAGE=True)|
6063

6164
### Docker
6265

aidial_adapter_openai/app.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,19 @@
99
from openai import ChatCompletion, Embedding, error
1010
from openai.openai_object import OpenAIObject
1111

12+
from aidial_adapter_openai.images import text_to_image_chat_completion
1213
from aidial_adapter_openai.openai_override import OpenAIException
14+
from aidial_adapter_openai.utils.deployment_classifier import (
15+
is_text_to_image_deployment,
16+
)
1317
from aidial_adapter_openai.utils.exceptions import HTTPException
1418
from aidial_adapter_openai.utils.log_config import LogConfig
1519
from aidial_adapter_openai.utils.parsers import (
1620
ApiType,
1721
parse_body,
1822
parse_upstream,
1923
)
24+
from aidial_adapter_openai.utils.storage import FileStorage
2025
from aidial_adapter_openai.utils.streaming import generate_stream
2126
from aidial_adapter_openai.utils.tokens import discard_messages
2227
from aidial_adapter_openai.utils.versions import compare_versions
@@ -26,6 +31,22 @@
2631
model_aliases: Dict[str, str] = json.loads(os.getenv("MODEL_ALIASES", "{}"))
2732
azure_api_version = os.getenv("AZURE_API_VERSION", "2023-03-15-preview")
2833

34+
# Whether generated image artifacts should be persisted to DIAL file storage
# instead of being returned inline as base64.
dial_use_file_storage = os.getenv("DIAL_USE_FILE_STORAGE", "false").lower() == "true"

file_storage = None
if dial_use_file_storage:
    dial_url = os.getenv("DIAL_URL")
    dial_api_key = os.getenv("DIAL_API_KEY")

    # Both settings are mandatory once file storage is switched on.
    if not (dial_url and dial_api_key):
        raise ValueError(
            "DIAL_URL and DIAL_API_KEY environment variables must be initialized if DIAL_USE_FILE_STORAGE is true"
        )

    file_storage = FileStorage(dial_url, "dalle", dial_api_key)
49+
2950

3051
async def handle_exceptions(call):
3152
try:
@@ -46,10 +67,16 @@ async def chat_completion(deployment_id: str, request: Request):
4667

4768
is_stream = data.get("stream", False)
4869
openai_model_name = model_aliases.get(deployment_id, deployment_id)
49-
dial_api_key = request.headers["X-UPSTREAM-KEY"]
70+
api_key = request.headers["X-UPSTREAM-KEY"]
71+
upstream_endpoint = request.headers["X-UPSTREAM-ENDPOINT"]
72+
73+
if is_text_to_image_deployment(deployment_id):
74+
return await text_to_image_chat_completion(
75+
data, upstream_endpoint, api_key, is_stream, file_storage
76+
)
5077

5178
api_base, upstream_deployment = parse_upstream(
52-
request.headers["X-UPSTREAM-ENDPOINT"], ApiType.CHAT_COMPLETION
79+
upstream_endpoint, ApiType.CHAT_COMPLETION
5380
)
5481

5582
api_version = azure_api_version
@@ -87,7 +114,7 @@ async def chat_completion(deployment_id: str, request: Request):
87114
response = await handle_exceptions(
88115
ChatCompletion().acreate(
89116
engine=upstream_deployment,
90-
api_key=dial_api_key,
117+
api_key=api_key,
91118
api_base=api_base,
92119
api_type="azure",
93120
api_version=api_version,
@@ -127,15 +154,15 @@ async def chat_completion(deployment_id: str, request: Request):
127154
async def embedding(deployment_id: str, request: Request):
128155
data = await parse_body(request)
129156

130-
dial_api_key = request.headers["X-UPSTREAM-KEY"]
157+
api_key = request.headers["X-UPSTREAM-KEY"]
131158
api_base, upstream_deployment = parse_upstream(
132159
request.headers["X-UPSTREAM-ENDPOINT"], ApiType.EMBEDDING
133160
)
134161

135162
return await handle_exceptions(
136163
Embedding().acreate(
137164
deployment_id=upstream_deployment,
138-
api_key=dial_api_key,
165+
api_key=api_key,
139166
api_base=api_base,
140167
api_type="azure",
141168
api_version=azure_api_version,

aidial_adapter_openai/images.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
from typing import Any, AsyncGenerator, Optional
2+
3+
import aiohttp
4+
from fastapi.responses import JSONResponse, Response, StreamingResponse
5+
6+
from aidial_adapter_openai.utils.exceptions import HTTPException
7+
from aidial_adapter_openai.utils.storage import FileStorage, upload_base64_file
8+
from aidial_adapter_openai.utils.streaming import (
9+
END_CHUNK,
10+
build_chunk,
11+
chunk_format,
12+
generate_id,
13+
)
14+
15+
# Fixed token usage attached to every image-generation response:
# no completion tokens, and the prompt is accounted as a single token.
IMG_USAGE = dict(
    completion_tokens=0,
    prompt_tokens=1,
    total_tokens=1,
)
20+
21+
22+
async def generate_image(api_url: str, api_key: str, user_prompt: str) -> Any:
    """Request a single image from the upstream image-generation endpoint.

    Returns the parsed JSON payload on HTTP 200; on any other status the
    upstream error payload is wrapped into a JSONResponse carrying the same
    status code (instead of raising).
    """
    request_body = {"prompt": user_prompt, "response_format": "b64_json"}
    async with aiohttp.ClientSession() as session:
        async with session.post(
            api_url,
            json=request_body,
            headers={"api-key": api_key},
        ) as response:
            payload = await response.json()
            if response.status != 200:
                return JSONResponse(content=payload, status_code=response.status)
            return payload
37+
38+
39+
def build_custom_content(base64_image: str, revised_prompt: str) -> Any:
    """Wrap the generated image and its revised prompt into the DIAL
    custom_content attachment structure."""
    prompt_attachment = {"title": "Revised prompt", "data": revised_prompt}
    image_attachment = {
        "title": "Image",
        "type": "image/png",
        "data": base64_image,
    }
    return {
        "custom_content": {
            "attachments": [prompt_attachment, image_attachment],
        }
    }
48+
49+
50+
async def generate_stream(
    id: str, created: str, custom_content: Any
) -> AsyncGenerator[Any, Any]:
    """Emit the image response as an SSE chat-completion stream.

    The sequence mirrors a regular streamed completion: assistant role,
    the custom content, a stop chunk with usage, then the [DONE] marker.
    """
    payloads = (
        build_chunk(id, None, {"role": "assistant"}, created, True),
        build_chunk(id, None, custom_content, created, True),
        build_chunk(id, "stop", {}, created, True, usage=IMG_USAGE),
    )
    for payload in payloads:
        yield chunk_format(payload)
    yield END_CHUNK
64+
65+
66+
def get_user_prompt(data: Any):
    """Return the prompt text taken from the last message of the request.

    The prompt is the "content" field of the final entry of
    data["messages"].

    Raises:
        HTTPException: 400 "invalid_request_error" when the messages list
            is missing or empty, the last message lacks a "content" key,
            or the content itself is empty.
    """
    if (
        "messages" not in data
        or len(data["messages"]) == 0
        or "content" not in data["messages"][-1]
        # Bug fix: the original checked `not data["messages"][-1]`, which is
        # always False once the "content" key exists (a non-empty dict is
        # truthy) — validate the content itself instead.
        or not data["messages"][-1]["content"]
    ):
        raise HTTPException(
            "Your request is invalid", 400, "invalid_request_error"
        )

    return data["messages"][-1]["content"]
78+
79+
80+
async def move_attachments_data_to_storage(
    custom_content: Any, file_storage: FileStorage
):
    """Upload inline base64 image attachments to the file storage and
    replace their "data" field with a "url" pointing into the storage
    (mutates custom_content in place)."""
    for attachment in custom_content["custom_content"]["attachments"]:
        is_inline_image = (
            "data" in attachment
            and "type" in attachment
            and attachment["type"].startswith("image/")
        )
        if not is_inline_image:
            continue

        file_metadata = await upload_base64_file(
            file_storage, attachment["data"], attachment["type"]
        )

        # Swap the inline payload for a link to the stored file.
        del attachment["data"]
        attachment["url"] = (
            file_metadata["path"] + "/" + file_metadata["name"]
        )
98+
99+
100+
async def text_to_image_chat_completion(
    data: Any,
    upstream_endpoint: str,
    api_key: str,
    is_stream: bool,
    file_storage: Optional[FileStorage],
) -> Response:
    """Serve a chat-completion request backed by a text-to-image model.

    Generates one image for the prompt of the last message, optionally
    offloads the base64 payload to DIAL file storage, and answers either
    as a plain JSON completion or as an SSE stream depending on is_stream.
    """
    # The chat-completion contract here exposes exactly one image per call.
    if data.get("n", 1) > 1:
        raise HTTPException(
            status_code=422,
            message="The deployment doesn't support n > 1",
            type="invalid_request_error",
        )

    user_prompt = get_user_prompt(data)
    api_url = upstream_endpoint + "?api-version=2023-12-01-preview"
    model_response = await generate_image(api_url, api_key, user_prompt)

    # generate_image returns a ready-made error response on upstream failure.
    if isinstance(model_response, JSONResponse):
        return model_response

    generated = model_response["data"][0]
    custom_content = build_custom_content(
        generated["b64_json"], generated["revised_prompt"]
    )

    if file_storage is not None:
        await move_attachments_data_to_storage(custom_content, file_storage)

    id = generate_id()
    created = model_response["created"]

    if is_stream:
        return StreamingResponse(
            generate_stream(id, created, custom_content),
            media_type="text/event-stream",
        )

    return JSONResponse(
        content=build_chunk(
            id,
            "stop",
            {**custom_content, "role": "assistant"},
            created,
            False,
            usage=IMG_USAGE,
        )
    )
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Deployment ids (lower-cased) that are served by the text-to-image flow.
_TEXT_TO_IMAGE_DEPLOYMENTS = frozenset({"dalle3"})


def is_text_to_image_deployment(deployment_id: str):
    """Return True when the deployment id belongs to a text-to-image model."""
    return deployment_id.lower() in _TEXT_TO_IMAGE_DEPLOYMENTS
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import base64
2+
import hashlib
3+
import io
4+
from typing import TypedDict
5+
6+
import aiohttp
7+
8+
from aidial_adapter_openai.utils.log_config import logger
9+
10+
11+
class FileMetadata(TypedDict):
    """Metadata describing a file stored in the DIAL file storage,
    as returned by the upload endpoint."""

    # Stored file name.
    name: str
    # Resource type reported by the storage — TODO confirm semantics.
    type: str
    # Directory path under which the file was stored.
    path: str
    # Size of the stored content in bytes.
    contentLength: int
    # MIME type of the stored content.
    contentType: str
17+
18+
19+
class FileStorage:
    """Thin async client for the DIAL file storage upload endpoint."""

    # Fully-qualified upload URL: {dial_url}/v1/files/{base_dir}.
    base_url: str
    # API key sent as the "api-key" header on every request.
    api_key: str

    def __init__(self, dial_url: str, base_dir: str, api_key: str):
        self.base_url = f"{dial_url}/v1/files/{base_dir}"
        self.api_key = api_key

    def auth_headers(self) -> dict[str, str]:
        """Headers authenticating requests against the storage."""
        return {"api-key": self.api_key}

    @staticmethod
    def to_form_data(
        filename: str, content_type: str, content: bytes
    ) -> "aiohttp.FormData":
        """Pack raw bytes into the multipart form expected by the upload API."""
        data = aiohttp.FormData()
        data.add_field(
            "file",
            io.BytesIO(content),
            filename=filename,
            content_type=content_type,
        )
        return data

    async def upload(
        self, filename: str, content_type: str, content: bytes
    ) -> "FileMetadata":
        """Upload content under filename and return the storage metadata.

        Raises:
            aiohttp.ClientResponseError: when the storage answers with an
                error status (via raise_for_status).
        """
        async with aiohttp.ClientSession() as session:
            data = FileStorage.to_form_data(filename, content_type, content)
            async with session.post(
                self.base_url,
                data=data,
                headers=self.auth_headers(),
            ) as response:
                response.raise_for_status()
                meta = await response.json()
                # Bug fix: log the actual file name instead of the
                # "(unknown)" placeholder so uploads are traceable.
                logger.debug(
                    f"Uploaded file: path={self.base_url}, file={filename}, metadata={meta}"
                )
                return meta
59+
60+
61+
def _hash_digest(string: str) -> str:
62+
return hashlib.sha256(string.encode()).hexdigest()
63+
64+
65+
async def upload_base64_file(
    storage: FileStorage, data: str, content_type: str
) -> FileMetadata:
    """Decode a base64 payload and store it under a content-derived name.

    The file name is the SHA-256 digest of the (still encoded) data, so
    identical payloads map to the same storage object.
    """
    decoded: bytes = base64.b64decode(data)
    return await storage.upload(_hash_digest(data), content_type, decoded)

aidial_adapter_openai/utils/streaming.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,33 @@ def chunk_format(data: str | Mapping[str, Any]):
1717
return "data: " + json.dumps(data, separators=(",", ":")) + "\n\n"
1818

1919

20+
def generate_id():
    """Produce a fresh chat-completion response id ("chatcmpl-<uuid4>")."""
    return f"chatcmpl-{uuid4()}"
22+
23+
24+
def build_chunk(
    id: str,
    finish_reason: Optional[str],
    delta: Any,
    created: str,
    is_stream,
    **extra
):
    """Assemble a single-choice chat-completion payload.

    The "object" field distinguishes streamed chunks from whole
    completions; any extra keyword arguments (e.g. usage=...) are merged
    into the top level of the returned mapping.
    """
    object_name = "chat.completion.chunk" if is_stream else "chat.completion"
    choice = {
        "index": 0,
        "finish_reason": finish_reason,
        "delta": delta,
    }
    chunk = {
        "id": id,
        "object": object_name,
        "created": created,
        "choices": [choice],
    }
    chunk.update(extra)
    return chunk
45+
46+
2047
END_CHUNK = chunk_format("[DONE]")
2148

2249

@@ -83,7 +110,7 @@ async def generate_stream(
83110

84111
yield chunk_format(
85112
{
86-
"id": "chatcmpl-" + str(uuid4()),
113+
"id": generate_id(),
87114
"object": "chat.completion.chunk",
88115
"created": str(int(time())),
89116
"model": deployment,

0 commit comments

Comments
 (0)