
feat: add message truncation for multi modal (#150)

* Add truncation for multi-modal embeddings

* Remove redundant check

* Generalize prompt truncation

* Fix tests and an error in token counting

* Revert forgotten README change

* Rename tests to match the renamed function

* Minor function renaming

* Large refactoring

* Move transformations back to multi-modal messages; save image parameters for reuse in token calculation

* Minor refactor

* Fixes from PR comments

* Minor comment fix

* Make tests actually work; fix tests

roman-romanov-o authored Sep 23, 2024
1 parent 2dd4667 commit b32e7ef
Showing 17 changed files with 1,019 additions and 374 deletions.
9 changes: 6 additions & 3 deletions aidial_adapter_openai/app.py
@@ -44,7 +44,10 @@
)
from aidial_adapter_openai.utils.reflection import call_with_extra_body
from aidial_adapter_openai.utils.storage import create_file_storage
from aidial_adapter_openai.utils.tokens import Tokenizer
from aidial_adapter_openai.utils.tokenizer import (
MultiModalTokenizer,
PlainTextTokenizer,
)


@asynccontextmanager
@@ -165,9 +168,8 @@ async def chat_completion(deployment_id: str, request: Request):
)

openai_model_name = MODEL_ALIASES.get(deployment_id, deployment_id)
tokenizer = Tokenizer(model=openai_model_name)

if deployment_id in GPT4O_DEPLOYMENTS:
tokenizer = MultiModalTokenizer(openai_model_name)
storage = create_file_storage("images", request.headers)
return await handle_exceptions(
gpt4o_chat_completion(
@@ -182,6 +184,7 @@
)
)

tokenizer = PlainTextTokenizer(model=openai_model_name)
return await handle_exceptions(
gpt_chat_completion(
data,
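The refactored `aidial_adapter_openai/utils/tokenizer.py` itself is not among the hunks shown. Judging from the call sites here and in `gpt.py` below (`calculate_text_tokens`, `calculate_message_tokens`, `calculate_prompt_tokens`, `TOKENS_PER_REQUEST`), a minimal sketch of the plain-text tokenizer could look as follows; the class layout and the overhead constants are assumptions, not the actual implementation.

```python
# Sketch only: inferred from the call sites in this commit, not the real utils/tokenizer.py.
from typing import List

import tiktoken


class PlainTextTokenizer:
    # Assumed fixed per-request overhead of the chat completion format.
    TOKENS_PER_REQUEST = 3

    def __init__(self, model: str):
        self.model = model
        self.encoding = tiktoken.encoding_for_model(model)

    def calculate_text_tokens(self, text: str) -> int:
        return len(self.encoding.encode(text))

    def calculate_message_tokens(self, message: dict) -> int:
        # Assumed: fixed per-message overhead plus the tokens of every string field.
        tokens = 3
        for value in message.values():
            if isinstance(value, str):
                tokens += self.calculate_text_tokens(value)
        return tokens

    def calculate_prompt_tokens(self, messages: List[dict]) -> int:
        return self.TOKENS_PER_REQUEST + sum(
            self.calculate_message_tokens(m) for m in messages
        )
```

`MultiModalTokenizer` presumably exposes the same interface while also accounting for image tokens via the image parameters saved on each message (see the commit notes above).
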
41 changes: 31 additions & 10 deletions aidial_adapter_openai/gpt.py
@@ -1,3 +1,5 @@
from typing import List, Tuple, cast

from aidial_sdk.exceptions import InvalidRequestError
from fastapi.responses import StreamingResponse
from openai import AsyncStream
@@ -14,7 +16,24 @@
generate_stream,
map_stream,
)
from aidial_adapter_openai.utils.tokens import Tokenizer, truncate_prompt
from aidial_adapter_openai.utils.tokenizer import PlainTextTokenizer
from aidial_adapter_openai.utils.truncate_prompt import (
DiscardedMessages,
TruncatedTokens,
truncate_prompt,
)


def plain_text_truncate_prompt(
messages: List[dict], max_prompt_tokens: int, tokenizer: PlainTextTokenizer
) -> Tuple[List[dict], DiscardedMessages, TruncatedTokens]:
return truncate_prompt(
messages=messages,
message_tokens=tokenizer.calculate_message_tokens,
is_system_message=lambda message: message["role"] == "system",
max_prompt_tokens=max_prompt_tokens,
initial_prompt_tokens=tokenizer.TOKENS_PER_REQUEST,
)


async def gpt_chat_completion(
@@ -23,9 +42,10 @@ async def gpt_chat_completion(
upstream_endpoint: str,
creds: OpenAICreds,
api_version: str,
tokenizer: Tokenizer,
tokenizer: PlainTextTokenizer,
):
discarded_messages = None
prompt_tokens = None
if "max_prompt_tokens" in data:
max_prompt_tokens = data["max_prompt_tokens"]
if not isinstance(max_prompt_tokens, int):
@@ -38,27 +58,28 @@
)
del data["max_prompt_tokens"]

data["messages"], discarded_messages = truncate_prompt(
tokenizer, data["messages"], max_prompt_tokens
data["messages"], discarded_messages, prompt_tokens = (
plain_text_truncate_prompt(
messages=cast(List[dict], data["messages"]),
max_prompt_tokens=max_prompt_tokens,
tokenizer=tokenizer,
)
)

client = chat_completions_parser.parse(upstream_endpoint).get_client(
{**creds, "api_version": api_version}
)

response: AsyncStream[ChatCompletionChunk] | ChatCompletion = (
await call_with_extra_body(client.chat.completions.create, data)
)

if isinstance(response, AsyncStream):

return StreamingResponse(
to_openai_sse_stream(
generate_stream(
get_prompt_tokens=lambda: tokenizer.calculate_prompt_tokens(
data["messages"]
),
tokenize=tokenizer.calculate_tokens,
get_prompt_tokens=lambda: prompt_tokens
or tokenizer.calculate_prompt_tokens(data["messages"]),
tokenize=tokenizer.calculate_text_tokens,
deployment=deployment_id,
discarded_messages=discarded_messages,
stream=map_stream(chunk_to_dict, response),
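Both `plain_text_truncate_prompt` above and `multi_modal_truncate_prompt` in `chat_completion.py` below delegate to a generic `truncate_prompt` in the new `utils/truncate_prompt.py`, which is not part of this excerpt. A sketch consistent with the keyword arguments used in this commit might look like this; the discard policy (keep system messages, then the newest messages that still fit) and `DiscardedMessages` being a list of indices are assumptions.

```python
# Sketch only: the real helper lives in utils/truncate_prompt.py and is not shown in this diff.
from typing import Callable, List, Tuple, TypeVar

DiscardedMessages = List[int]  # assumed: indices of the dropped messages
TruncatedTokens = int          # assumed: prompt token count after truncation
_T = TypeVar("_T")


def truncate_prompt(
    messages: List[_T],
    message_tokens: Callable[[_T], int],
    is_system_message: Callable[[_T], bool],
    max_prompt_tokens: int,
    initial_prompt_tokens: int,
) -> Tuple[List[_T], DiscardedMessages, TruncatedTokens]:
    # System messages are always kept; their cost is part of the fixed budget.
    # (The real implementation presumably raises an error if even that exceeds the limit.)
    total = initial_prompt_tokens + sum(
        message_tokens(m) for m in messages if is_system_message(m)
    )
    kept = {i for i, m in enumerate(messages) if is_system_message(m)}

    # Keep the newest non-system messages that still fit within the budget.
    for i in range(len(messages) - 1, -1, -1):
        if i in kept:
            continue
        cost = message_tokens(messages[i])
        if total + cost > max_prompt_tokens:
            break
        total += cost
        kept.add(i)

    discarded = [i for i in range(len(messages)) if i not in kept]
    return [messages[i] for i in sorted(kept)], discarded, total
```

This shape would explain why the two wrappers differ only in how a message's tokens are computed and how a system message is recognized (plain dicts for text models, `MultiModalMessage.raw_message` for multi-modal ones).
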
@@ -1,16 +1,9 @@
import mimetypes
from typing import List, Optional, Set, Tuple, cast
from typing import Optional

from pydantic import BaseModel

from aidial_adapter_openai.gpt4_multi_modal.image_tokenizer import (
tokenize_image,
)
from aidial_adapter_openai.gpt4_multi_modal.messages import (
create_image_message,
create_text_message,
)
from aidial_adapter_openai.utils.image_data_url import ImageDataURL
from aidial_adapter_openai.utils.image import ImageDataURL
from aidial_adapter_openai.utils.log_config import logger
from aidial_adapter_openai.utils.storage import (
FileStorage,
@@ -104,76 +97,3 @@ def fail(message: str) -> ImageFail:
except Exception as e:
logger.error(f"Failed to download the image: {e}")
return fail("failed to download the attachment")


async def transform_message(
file_storage: Optional[FileStorage], message: dict
) -> Tuple[dict, int] | List[ImageFail]:
content = message.get("content", "")
custom_content = message.get("custom_content", {})
attachments = custom_content.get("attachments", [])

message = {k: v for k, v in message.items() if k != "custom_content"}

if len(attachments) == 0:
return message, 0

logger.debug(f"original attachments: {attachments}")

download_results: List[ImageDataURL | ImageFail] = [
await download_image(file_storage, attachment)
for attachment in attachments
]

logger.debug(f"download results: {download_results}")

errors: List[ImageFail] = [
res for res in download_results if isinstance(res, ImageFail)
]

if errors:
logger.error(f"download errors: {errors}")
return errors

image_urls: List[ImageDataURL] = cast(List[ImageDataURL], download_results)

image_tokens: List[int] = []
image_messages: List[dict] = []

for image_url in image_urls:
tokens, detail = tokenize_image(image_url, "auto")
image_tokens.append(tokens)
image_messages.append(create_image_message(image_url, detail))

total_image_tokens = sum(image_tokens)

logger.debug(f"image tokens: {image_tokens}")

sub_messages: List[dict] = [create_text_message(content)] + image_messages

return {**message, "content": sub_messages}, total_image_tokens


async def transform_messages(
file_storage: Optional[FileStorage], messages: List[dict]
) -> Tuple[List[dict], int] | str:
image_tokens = 0
new_messages: List[dict] = []
errors: Set[ImageFail] = set()

for message in messages:
result = await transform_message(file_storage, message)
if isinstance(result, list):
errors.update(result)
else:
new_message, tokens = result
new_messages.append(new_message)
image_tokens += tokens

if errors:
msg = "The following file attachments failed to process:"
for idx, error in enumerate(errors, start=1):
msg += f"\n{idx}. {error.name}: {error.message}"
return msg

return new_messages, image_tokens
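
The removed `transform_message`/`transform_messages` now live in `gpt4_multi_modal/transformation.py` and, per the chat completion hunks below, produce `MultiModalMessage` objects instead of plain dicts. Neither that module nor `utils/multi_modal_message.py` appears in this excerpt; based on the `message.raw_message` accesses and the commit note about saving image parameters for token calculation, a rough, assumed shape could be:

```python
# Sketch only: assumed shape of utils/multi_modal_message.py; field names are guesses.
from dataclasses import dataclass, field
from typing import List


@dataclass
class ImageMetadata:
    # Assumed: enough information to recompute the image's token cost during truncation.
    width: int
    height: int
    detail: str


@dataclass
class MultiModalMessage:
    # The OpenAI-style message dict sent upstream unchanged.
    raw_message: dict
    # Image parameters kept next to the message so MultiModalTokenizer can
    # include image tokens when counting and truncating the prompt.
    image_metadatas: List[ImageMetadata] = field(default_factory=list)
```
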
96 changes: 69 additions & 27 deletions aidial_adapter_openai/gpt4_multi_modal/chat_completion.py
@@ -6,24 +6,28 @@
Dict,
List,
Optional,
Tuple,
TypeVar,
cast,
)

import aiohttp
from aidial_sdk.exceptions import HTTPException as DialException
from aidial_sdk.exceptions import InvalidRequestError, RequestValidationError
from aidial_sdk.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, Response, StreamingResponse

from aidial_adapter_openai.gpt4_multi_modal.download import (
from aidial_adapter_openai.gpt4_multi_modal.attachment import (
SUPPORTED_FILE_EXTS,
transform_messages,
)
from aidial_adapter_openai.gpt4_multi_modal.gpt4_vision import (
convert_gpt4v_to_gpt4_chunk,
)
from aidial_adapter_openai.gpt4_multi_modal.transformation import (
transform_messages,
)
from aidial_adapter_openai.utils.auth import OpenAICreds, get_auth_headers
from aidial_adapter_openai.utils.log_config import logger
from aidial_adapter_openai.utils.multi_modal_message import MultiModalMessage
from aidial_adapter_openai.utils.sse_stream import (
parse_openai_sse_stream,
to_openai_sse_stream,
@@ -36,7 +40,12 @@
map_stream,
prepend_to_stream,
)
from aidial_adapter_openai.utils.tokens import Tokenizer
from aidial_adapter_openai.utils.tokenizer import MultiModalTokenizer
from aidial_adapter_openai.utils.truncate_prompt import (
DiscardedMessages,
TruncatedTokens,
truncate_prompt,
)

# The built-in default max_tokens is 16 tokens,
# which is too small for most image-to-text use cases.
@@ -111,6 +120,22 @@ async def predict_non_stream(
return await response.json()


def multi_modal_truncate_prompt(
messages: List[MultiModalMessage],
max_prompt_tokens: int,
initial_prompt_tokens: int,
tokenizer: MultiModalTokenizer,
) -> Tuple[List[MultiModalMessage], DiscardedMessages, TruncatedTokens]:
return truncate_prompt(
messages=messages,
message_tokens=tokenizer.calculate_message_tokens,
is_system_message=lambda message: message.raw_message["role"]
== "system",
max_prompt_tokens=max_prompt_tokens,
initial_prompt_tokens=initial_prompt_tokens,
)


async def gpt4o_chat_completion(
request: Any,
deployment: str,
Expand All @@ -119,7 +144,7 @@ async def gpt4o_chat_completion(
is_stream: bool,
file_storage: Optional[FileStorage],
api_version: str,
tokenizer: Tokenizer,
tokenizer: MultiModalTokenizer,
) -> Response:
return await chat_completion(
request,
@@ -152,7 +177,7 @@ async def gpt4_vision_chat_completion(
is_stream,
file_storage,
api_version,
Tokenizer("gpt-4"),
MultiModalTokenizer("gpt-4"),
convert_gpt4v_to_gpt4_chunk,
GPT4V_DEFAULT_MAX_TOKENS,
)
@@ -166,11 +191,10 @@
is_stream: bool,
file_storage: Optional[FileStorage],
api_version: str,
tokenizer: Tokenizer,
tokenizer: MultiModalTokenizer,
response_transformer: Callable[[dict], dict | None],
default_max_tokens: int | None,
default_max_tokens: Optional[int],
) -> Response:

if request.get("n", 1) > 1:
raise RequestValidationError("The deployment doesn't support n > 1")

@@ -180,26 +204,39 @@

api_url = f"{upstream_endpoint}?api-version={api_version}"

result = await transform_messages(file_storage, messages)

if isinstance(result, str):
logger.error(f"Failed to prepare request: {result}")

transform_result = await transform_messages(file_storage, messages)
if isinstance(transform_result, DialException):
logger.error(f"Failed to prepare request: {transform_result.message}")
chunk = create_stage_chunk("Usage", USAGE, is_stream)

exc = InvalidRequestError(message=result, display_message=result)

return create_response_from_chunk(chunk, exc, is_stream)

new_messages, prompt_image_tokens = result

prompt_text_tokens = tokenizer.calculate_prompt_tokens(messages)
estimated_prompt_tokens = prompt_text_tokens + prompt_image_tokens
return create_response_from_chunk(chunk, transform_result, is_stream)

multi_modal_messages = transform_result
discarded_messages = None
max_prompt_tokens = request.pop("max_prompt_tokens", None)
if max_prompt_tokens is not None:
multi_modal_messages, discarded_messages, estimated_prompt_tokens = (
multi_modal_truncate_prompt(
messages=multi_modal_messages,
max_prompt_tokens=max_prompt_tokens,
initial_prompt_tokens=tokenizer.TOKENS_PER_REQUEST,
tokenizer=tokenizer,
)
)
logger.debug(
f"prompt tokens after truncation: {estimated_prompt_tokens}"
)
else:
estimated_prompt_tokens = tokenizer.calculate_prompt_tokens(
multi_modal_messages
)
logger.debug(
f"prompt tokens without truncation: {estimated_prompt_tokens}"
)

request = {
**request,
"max_tokens": request.get("max_tokens") or default_max_tokens,
"messages": new_messages,
"messages": [m.raw_message for m in multi_modal_messages],
}

headers = get_auth_headers(creds)
@@ -221,9 +258,9 @@ def debug_print(chunk: T) -> T:
debug_print,
generate_stream(
get_prompt_tokens=lambda: estimated_prompt_tokens,
tokenize=tokenizer.calculate_tokens,
tokenize=tokenizer.calculate_text_tokens,
deployment=deployment,
discarded_messages=None,
discarded_messages=discarded_messages,
stream=map_stream(
response_transformer,
parse_openai_sse_stream(response),
@@ -249,14 +286,19 @@ def debug_print(chunk: T) -> T:
content = response["choices"][0]["message"].get("content") or ""
usage = response["usage"]

if discarded_messages:
response |= {
"statistics": {"discarded_messages": discarded_messages}
}

actual_prompt_tokens = usage["prompt_tokens"]
if actual_prompt_tokens != estimated_prompt_tokens:
logger.warning(
f"Estimated prompt tokens ({estimated_prompt_tokens}) don't match the actual ones ({actual_prompt_tokens})"
)

actual_completion_tokens = usage["completion_tokens"]
estimated_completion_tokens = tokenizer.calculate_tokens(content)
estimated_completion_tokens = tokenizer.calculate_text_tokens(content)
if actual_completion_tokens != estimated_completion_tokens:
logger.warning(
f"Estimated completion tokens ({estimated_completion_tokens}) don't match the actual ones ({actual_completion_tokens})"
(diffs for the remaining changed files are not shown)
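
Taken together, a client that sets `max_prompt_tokens` on a multi-modal deployment now gets its history truncated and a report of the dropped messages. A hypothetical request against a locally running adapter (URL, deployment name, API version and attachment format are placeholders, not taken from this diff):

```python
# Hypothetical usage sketch; endpoint, deployment, api-version and attachment URL are placeholders.
import requests

resp = requests.post(
    "http://localhost:5000/openai/deployments/gpt-4o/chat/completions",
    params={"api-version": "2024-02-01"},
    headers={"api-key": "dummy-key"},
    json={
        "stream": False,
        "max_prompt_tokens": 500,  # enables truncation in the adapter
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Describe the first image."},
            {"role": "assistant", "content": "It shows a cat."},
            {
                "role": "user",
                "content": "And this one?",
                "custom_content": {"attachments": [{"url": "files/bucket/dog.png"}]},
            },
        ],
    },
)

body = resp.json()
# If truncation dropped anything, the non-streaming branch above attaches:
#   body["statistics"]["discarded_messages"]   e.g. [1, 2]
```
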
