fix: failing with error for incorrect GPT4V request (#131)
adubovik authored Jul 9, 2024
1 parent 98815a7 commit 980caf4
Showing 10 changed files with 148 additions and 122 deletions.
10 changes: 2 additions & 8 deletions aidial_adapter_openai/app.py
@@ -4,7 +4,6 @@
from aidial_sdk.exceptions import HTTPException as DialException
from aidial_sdk.telemetry.init import init_telemetry
from aidial_sdk.telemetry.types import TelemetryConfig
from aidial_sdk.utils.errors import json_error
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response
from openai import APIConnectionError, APIStatusError, APITimeoutError
@@ -35,6 +34,7 @@
chat_completion as mistral_chat_completion,
)
from aidial_adapter_openai.utils.auth import get_credentials
from aidial_adapter_openai.utils.errors import dial_exception_to_json_error
from aidial_adapter_openai.utils.http_client import get_http_client
from aidial_adapter_openai.utils.log_config import configure_loggers, logger
from aidial_adapter_openai.utils.parsers import (
@@ -222,13 +222,7 @@ async def embedding(deployment_id: str, request: Request):
def exception_handler(request: Request, exc: DialException):
return JSONResponse(
status_code=exc.status_code,
content=json_error(
message=exc.message,
type=exc.type,
param=exc.param,
code=exc.code,
display_message=exc.display_message,
),
content=dial_exception_to_json_error(exc),
)


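Note, not part of the commit: the handler above feeds a DialException into the new helper. A minimal sketch of how such a handler is typically registered with FastAPI; the decorator line is an assumption, since the registration sits outside this hunk.

# Illustrative sketch only; mirrors the handler shown in the hunk above.
from aidial_sdk.exceptions import HTTPException as DialException
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

from aidial_adapter_openai.utils.errors import dial_exception_to_json_error

app = FastAPI()


@app.exception_handler(DialException)  # assumed wiring; not visible in the diff
def exception_handler(request: Request, exc: DialException):
    return JSONResponse(
        status_code=exc.status_code,
        content=dial_exception_to_json_error(exc),
    )
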
28 changes: 8 additions & 20 deletions aidial_adapter_openai/dalle3.py
@@ -1,18 +1,14 @@
from typing import Any, AsyncGenerator, Optional
from typing import Any, AsyncIterator, Optional

import aiohttp
from aidial_sdk.exceptions import HTTPException as DialException
from aidial_sdk.utils.errors import json_error
from fastapi.responses import JSONResponse, Response, StreamingResponse

from aidial_adapter_openai.utils.auth import OpenAICreds, get_auth_headers
from aidial_adapter_openai.utils.sse_stream import END_CHUNK
from aidial_adapter_openai.utils.sse_stream import to_openai_sse_stream
from aidial_adapter_openai.utils.storage import FileStorage
from aidial_adapter_openai.utils.streaming import (
build_chunk,
format_chunk,
generate_id,
)
from aidial_adapter_openai.utils.streaming import build_chunk, generate_id

IMG_USAGE = {
"prompt_tokens": 0,
@@ -73,18 +69,10 @@ def build_custom_content(base64_image: str, revised_prompt: str) -> Any:

async def generate_stream(
id: str, created: str, custom_content: Any
) -> AsyncGenerator[Any, Any]:
yield format_chunk(
build_chunk(id, None, {"role": "assistant"}, created, True)
)

yield format_chunk(build_chunk(id, None, custom_content, created, True))

yield format_chunk(
build_chunk(id, "stop", {}, created, True, usage=IMG_USAGE)
)

yield END_CHUNK
) -> AsyncIterator[dict]:
yield build_chunk(id, None, {"role": "assistant"}, created, True)
yield build_chunk(id, None, custom_content, created, True)
yield build_chunk(id, "stop", {}, created, True, usage=IMG_USAGE)


def get_user_prompt(data: Any):
@@ -166,6 +154,6 @@ async def chat_completion(
)
else:
return StreamingResponse(
generate_stream(id, created, custom_content),
to_openai_sse_stream(generate_stream(id, created, custom_content)),
media_type="text/event-stream",
)
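For context (an assumption, not shown in this commit): generate_stream now yields plain dict chunks, and to_openai_sse_stream is expected to take over the SSE framing that format_chunk and END_CHUNK used to provide. A hedged sketch of what such a wrapper could look like:

# Hypothetical sketch of an SSE wrapper; the real to_openai_sse_stream in
# aidial_adapter_openai.utils.sse_stream may differ in details.
import json
from typing import AsyncIterator


async def to_openai_sse_stream(
    chunks: AsyncIterator[dict],
) -> AsyncIterator[str]:
    async for chunk in chunks:
        # Each chunk becomes one "data: ..." SSE event.
        yield f"data: {json.dumps(chunk)}\n\n"
    # The stream is terminated with the OpenAI-style sentinel.
    yield "data: [DONE]\n\n"
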
14 changes: 5 additions & 9 deletions aidial_adapter_openai/env.py
@@ -5,19 +5,15 @@
from aidial_adapter_openai.utils.parsers import parse_deployment_list

MODEL_ALIASES: Dict[str, str] = json.loads(os.getenv("MODEL_ALIASES", "{}"))
DALLE3_DEPLOYMENTS = parse_deployment_list(
os.getenv("DALLE3_DEPLOYMENTS") or ""
)
DALLE3_DEPLOYMENTS = parse_deployment_list(os.getenv("DALLE3_DEPLOYMENTS"))
GPT4_VISION_DEPLOYMENTS = parse_deployment_list(
os.getenv("GPT4_VISION_DEPLOYMENTS") or ""
)
MISTRAL_DEPLOYMENTS = parse_deployment_list(
os.getenv("MISTRAL_DEPLOYMENTS") or ""
os.getenv("GPT4_VISION_DEPLOYMENTS")
)
MISTRAL_DEPLOYMENTS = parse_deployment_list(os.getenv("MISTRAL_DEPLOYMENTS"))
DATABRICKS_DEPLOYMENTS = parse_deployment_list(
os.getenv("DATABRICKS_DEPLOYMENTS") or ""
os.getenv("DATABRICKS_DEPLOYMENTS")
)
GPT4O_DEPLOYMENTS = parse_deployment_list(os.getenv("GPT4O_DEPLOYMENTS") or "")
GPT4O_DEPLOYMENTS = parse_deployment_list(os.getenv("GPT4O_DEPLOYMENTS"))
API_VERSIONS_MAPPING: Dict[str, str] = json.loads(
os.getenv("API_VERSIONS_MAPPING", "{}")
)
26 changes: 13 additions & 13 deletions aidial_adapter_openai/gpt4_multi_modal/chat_completion.py
@@ -29,7 +29,8 @@
)
from aidial_adapter_openai.utils.storage import FileStorage
from aidial_adapter_openai.utils.streaming import (
create_error_response,
create_response_from_chunk,
create_stage_chunk,
generate_stream,
map_stream,
prepend_to_stream,
@@ -189,19 +190,18 @@ async def chat_completion(
result = await transform_messages(file_storage, messages)

if isinstance(result, str):
logger.debug(f"Failed to prepare request: {result}")
logger.error(f"Failed to prepare request: {result}")

if file_storage is not None:
# Report user-level error if the request came from the chat
error_message = result + "\n\n" + USAGE
return create_error_response(error_message, is_stream)
else:
# Throw an error if the request came from the API
raise DialException(
status_code=400,
message=result,
type="invalid_request_error",
)
chunk = create_stage_chunk("Usage", USAGE, is_stream)

exc = DialException(
status_code=400,
message=result,
display_message=result,
type="invalid_request_error",
)

return create_response_from_chunk(chunk, exc, is_stream)

new_messages, prompt_image_tokens = result

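The net effect of this hunk: instead of branching on file_storage, the adapter always builds a "Usage" stage chunk plus a 400 DialException whose display_message mirrors the error text, and create_response_from_chunk chooses the representation that fits the request. The outline below is a guess at those semantics; the real helper lives in aidial_adapter_openai.utils.streaming and may differ.

# Hypothetical outline of create_response_from_chunk; for illustration only.
import json
from typing import AsyncIterator

from aidial_sdk.exceptions import HTTPException as DialException
from fastapi.responses import JSONResponse, Response, StreamingResponse

from aidial_adapter_openai.utils.errors import dial_exception_to_json_error


def create_response_from_chunk(
    chunk: dict, exc: DialException, stream: bool
) -> Response:
    if not stream:
        # API clients get the structured 400 straight away.
        return JSONResponse(
            status_code=exc.status_code,
            content=dial_exception_to_json_error(exc),
        )

    async def generator() -> AsyncIterator[str]:
        # Chat clients first see the stage chunk (the usage guidance),
        # then the error payload, framed as SSE events.
        yield f"data: {json.dumps(chunk)}\n\n"
        yield f"data: {json.dumps(dial_exception_to_json_error(exc))}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(generator(), media_type="text/event-stream")
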
91 changes: 55 additions & 36 deletions aidial_adapter_openai/gpt4_multi_modal/download.py
@@ -1,5 +1,7 @@
import mimetypes
from typing import Dict, List, Optional, Tuple, cast
from typing import List, Optional, Set, Tuple, cast

from pydantic import BaseModel

from aidial_adapter_openai.gpt4_multi_modal.image_tokenizer import (
tokenize_image,
@@ -14,7 +16,6 @@
FileStorage,
download_file_as_base64,
)
from aidial_adapter_openai.utils.text import format_ordinal

# Officially supported image types by GPT-4 Vision, GPT-4o
SUPPORTED_IMAGE_TYPES = ["image/jpeg", "image/png", "image/webp", "image/gif"]
@@ -23,30 +24,55 @@

def guess_attachment_type(attachment: dict) -> Optional[str]:
type = attachment.get("type")
if type is None:
return None

if "octet-stream" in type:
if type is None or "octet-stream" in type:
# It's an arbitrary binary file. Trying to guess the type from the URL.
url = attachment.get("url")
if url is not None:
url_type = mimetypes.guess_type(url)[0]
if url_type is not None:
return url_type
return None

return type


class ImageFail(BaseModel):
class Config:
frozen = True

name: str
message: str


async def get_attachment_name(
file_storage: Optional[FileStorage], attachment: dict
) -> str:
if "data" in attachment:
return attachment.get("title") or "data attachment"

if "url" in attachment:
attachment_link = attachment["url"]
if file_storage is not None:
return await file_storage.get_human_readable_name(attachment_link)
return attachment_link

return "invalid attachment"


async def download_image(
file_storage: Optional[FileStorage], attachment: dict
) -> ImageDataURL | str:
) -> ImageDataURL | ImageFail:
name = await get_attachment_name(file_storage, attachment)

def fail(message: str) -> ImageFail:
return ImageFail(name=name, message=message)

try:
type = guess_attachment_type(attachment)
if type is None:
return "Can't derive media type of the attachment"
return fail("can't derive media type of the attachment")
elif type not in SUPPORTED_IMAGE_TYPES:
return f"The attachment isn't one of the supported types: {type}"
return fail("the attachment is not one of the supported types")

if "data" in attachment:
return ImageDataURL(type=type, data=attachment["data"])
@@ -56,12 +82,11 @@ async def download_image(

image_url = ImageDataURL.from_data_url(attachment_link)
if image_url is not None:
if image_url.type in SUPPORTED_IMAGE_TYPES:
return image_url
else:
return (
"The image attachment isn't one of the supported types"
if image_url.type not in SUPPORTED_IMAGE_TYPES:
return fail(
"the attachment is not one of the supported types"
)
return image_url

if file_storage is not None:
url = file_storage.attachment_link_to_url(attachment_link)
@@ -71,16 +96,16 @@

return ImageDataURL(type=type, data=data)

return "Invalid attachment"
return fail("invalid attachment")

except Exception as e:
logger.debug(f"Failed to download image: {e}")
return "Failed to download image"
logger.error(f"Failed to download the image: {e}")
return fail("failed to download the attachment")


async def transform_message(
file_storage: Optional[FileStorage], message: dict
) -> Tuple[dict, int] | List[Tuple[int, str]]:
) -> Tuple[dict, int] | List[ImageFail]:
content = message.get("content", "")
custom_content = message.get("custom_content", {})
attachments = custom_content.get("attachments", [])
@@ -92,21 +117,19 @@ async def transform_message(

logger.debug(f"original attachments: {attachments}")

download_results: List[ImageDataURL | str] = [
download_results: List[ImageDataURL | ImageFail] = [
await download_image(file_storage, attachment)
for attachment in attachments
]

logger.debug(f"download results: {download_results}")

errors: List[Tuple[int, str]] = [
(idx, result)
for idx, result in enumerate(download_results)
if isinstance(result, str)
errors: List[ImageFail] = [
res for res in download_results if isinstance(res, ImageFail)
]

if len(errors) > 0:
logger.debug(f"download errors: {errors}")
if errors:
logger.error(f"download errors: {errors}")
return errors

image_urls: List[ImageDataURL] = cast(List[ImageDataURL], download_results)
@@ -131,27 +154,23 @@
async def transform_messages(
file_storage: Optional[FileStorage], messages: List[dict]
) -> Tuple[List[dict], int] | str:
new_messages: List[dict] = []
image_tokens = 0
new_messages: List[dict] = []
errors: Set[ImageFail] = set()

errors: Dict[int, List[Tuple[int, str]]] = {}

n = len(messages)
for idx, message in enumerate(messages):
for message in messages:
result = await transform_message(file_storage, message)
if isinstance(result, list):
errors[n - idx] = result
errors.update(result)
else:
new_message, tokens = result
new_messages.append(new_message)
image_tokens += tokens

if errors:
msg = "Some of the image attachments failed to download:"
for i, error in errors.items():
msg += f"\n- {format_ordinal(i)} message from end:"
for j, err in error:
msg += f"\n - {format_ordinal(j + 1)} attachment: {err}"
msg = "The following file attachments failed to process:"
for idx, error in enumerate(errors, start=1):
msg += f"\n{idx}. {error.name}: {error.message}"
return msg

return new_messages, image_tokens
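One detail worth noting (commentary, not part of the diff): Config.frozen = True makes pydantic generate __hash__ for ImageFail, which is what allows transform_messages above to collect failures into a Set[ImageFail] and drop duplicates. A minimal, self-contained illustration:

# Standalone illustration of why ImageFail is declared frozen.
from pydantic import BaseModel


class ImageFail(BaseModel):
    class Config:
        frozen = True  # immutable and hashable, so usable in a set

    name: str
    message: str


errors = {
    ImageFail(name="cat.bmp", message="failed to download the attachment"),
    ImageFail(name="cat.bmp", message="failed to download the attachment"),
}
assert len(errors) == 1  # identical failures are reported only once
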
12 changes: 12 additions & 0 deletions aidial_adapter_openai/utils/errors.py
@@ -0,0 +1,12 @@
from aidial_sdk.exceptions import HTTPException as DialException
from aidial_sdk.utils.errors import json_error


def dial_exception_to_json_error(exc: DialException) -> dict:
return json_error(
message=exc.message,
type=exc.type,
param=exc.param,
code=exc.code,
display_message=exc.display_message,
)
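A small usage sketch of the new helper, mirroring the call sites added elsewhere in this commit; the DialException keyword arguments are the ones the GPT-4V path passes, and the exact JSON layout is whatever aidial_sdk's json_error produces.

# Sketch: turning a DialException into a response body, as app.py now does.
from aidial_sdk.exceptions import HTTPException as DialException
from fastapi.responses import JSONResponse

from aidial_adapter_openai.utils.errors import dial_exception_to_json_error

exc = DialException(
    status_code=400,
    message="Can't derive media type of the attachment",
    display_message="Can't derive media type of the attachment",
    type="invalid_request_error",
)

response = JSONResponse(
    status_code=exc.status_code,
    content=dial_exception_to_json_error(exc),
)
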
11 changes: 3 additions & 8 deletions aidial_adapter_openai/utils/parsers.py
@@ -94,9 +94,7 @@ def parse(
completions_parser = CompletionsParser()


async def parse_body(
request: Request,
) -> Dict[str, Any]:
async def parse_body(request: Request) -> Dict[str, Any]:
try:
data = await request.json()
except JSONDecodeError as e:
@@ -114,8 +112,5 @@ async def parse_body(
return data


def parse_deployment_list(deployments: str) -> List[str]:
if deployments is None:
return []

return list(map(str.strip, deployments.split(",")))
def parse_deployment_list(deployments: str | None) -> List[str]:
return list(map(str.strip, (deployments or "").split(",")))
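The reworked parse_deployment_list accepts the raw, possibly unset environment value, which is what lets env.py drop its explicit empty-string fallbacks. A standalone check of the behaviour, with the function body copied from the hunk above and illustrative deployment names:

# Copied one-liner plus illustrative inputs; deployment names are made up.
from typing import List


def parse_deployment_list(deployments: str | None) -> List[str]:
    return list(map(str.strip, (deployments or "").split(",")))


assert parse_deployment_list("dalle-3, dall-e-3") == ["dalle-3", "dall-e-3"]
assert parse_deployment_list(None) == [""]  # unset variable behaves like "" did before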