fix: supported approximate tokenization of tools and functions #170

Merged: 11 commits, Nov 25, 2024
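The change in a nutshell: token estimates now account for the tool and function definitions attached to a request, not only its messages. As a rough, hedged illustration of the idea (a sketch with assumed names, not the adapter's actual tokenizer), the request-level overhead can be approximated by serializing each definition and counting its tokens:

# Hedged sketch only: approximate the request-level token overhead contributed
# by tool/function definitions. The adapter's real tokenizers may serialize
# these differently, so treat the numbers as estimates.
import json

import tiktoken


def approximate_tools_tokens(request: dict, encoding_name: str = "cl100k_base") -> int:
    encoding = tiktoken.get_encoding(encoding_name)
    total = 0
    # "tools" (newer API) and "functions" (legacy API) both add to the prompt.
    for definition in (request.get("tools") or []) + (request.get("functions") or []):
        total += len(encoding.encode(json.dumps(definition)))
    return total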
32 changes: 18 additions & 14 deletions aidial_adapter_openai/gpt.py
@@ -23,29 +23,32 @@


 def plain_text_truncate_prompt(
-    messages: List[dict], max_prompt_tokens: int, tokenizer: PlainTextTokenizer
+    request: dict,
+    messages: List[dict],
+    max_prompt_tokens: int,
+    tokenizer: PlainTextTokenizer,
 ) -> Tuple[List[dict], DiscardedMessages, TruncatedTokens]:
     return truncate_prompt(
         messages=messages,
-        message_tokens=tokenizer.calculate_message_tokens,
+        message_tokens=tokenizer.tokenize_request_message,
         is_system_message=lambda message: message["role"] == "system",
         max_prompt_tokens=max_prompt_tokens,
-        initial_prompt_tokens=tokenizer.TOKENS_PER_REQUEST,
+        initial_prompt_tokens=tokenizer.tokenize_request(request, []),
     )


 async def gpt_chat_completion(
-    data: dict,
+    request: dict,
     deployment_id: str,
     upstream_endpoint: str,
     creds: OpenAICreds,
     api_version: str,
     tokenizer: PlainTextTokenizer,
 ):
     discarded_messages = None
-    prompt_tokens = None
-    if "max_prompt_tokens" in data:
-        max_prompt_tokens = data["max_prompt_tokens"]
+    estimated_prompt_tokens = None
+    if "max_prompt_tokens" in request:
+        max_prompt_tokens = request["max_prompt_tokens"]
         if not isinstance(max_prompt_tokens, int):
             raise InvalidRequestError(
                 f"'{max_prompt_tokens}' is not of type 'integer' - 'max_prompt_tokens'",
@@ -54,11 +57,12 @@ async def gpt_chat_completion(
             raise InvalidRequestError(
                 f"'{max_prompt_tokens}' is less than the minimum of 1 - 'max_prompt_tokens'",
             )
-        del data["max_prompt_tokens"]
+        del request["max_prompt_tokens"]

-        data["messages"], discarded_messages, prompt_tokens = (
+        request["messages"], discarded_messages, estimated_prompt_tokens = (
             plain_text_truncate_prompt(
-                messages=cast(List[dict], data["messages"]),
+                request=request,
+                messages=cast(List[dict], request["messages"]),
                 max_prompt_tokens=max_prompt_tokens,
                 tokenizer=tokenizer,
             )
@@ -68,14 +72,14 @@ async def gpt_chat_completion(
         {**creds, "api_version": api_version}
     )
     response: AsyncStream[ChatCompletionChunk] | ChatCompletion = (
-        await call_with_extra_body(client.chat.completions.create, data)
+        await call_with_extra_body(client.chat.completions.create, request)
    )

     if isinstance(response, AsyncIterator):
         return generate_stream(
-            get_prompt_tokens=lambda: prompt_tokens
-            or tokenizer.calculate_prompt_tokens(data["messages"]),
-            tokenize=tokenizer.calculate_text_tokens,
+            get_prompt_tokens=lambda: estimated_prompt_tokens
+            or tokenizer.tokenize_request(request, request["messages"]),
+            tokenize_response=tokenizer.tokenize_response,
             deployment=deployment_id,
             discarded_messages=discarded_messages,
             stream=map_stream(chunk_to_dict, response),
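For context, truncate_prompt appears to start from the request-level baseline (initial_prompt_tokens) and add per-message costs until max_prompt_tokens is reached, preferring to keep system messages and the newest history. The sketch below spells out that assumed contract; it is inferred from the call sites above, not taken from the library:

# Hedged sketch of the truncation contract implied by the call sites above
# (assumed behavior, not the library's code): keep system messages, then keep
# the newest non-system messages that still fit into max_prompt_tokens.
from typing import Callable, List, Tuple


def sketch_truncate_prompt(
    messages: List[dict],
    message_tokens: Callable[[dict], int],
    is_system_message: Callable[[dict], bool],
    max_prompt_tokens: int,
    initial_prompt_tokens: int,
) -> Tuple[List[dict], List[int], int]:
    # Start from the per-request overhead plus all system messages.
    used = initial_prompt_tokens + sum(
        message_tokens(m) for m in messages if is_system_message(m)
    )
    kept: List[int] = []
    # Walk from the newest message backwards, keeping what still fits.
    for index in range(len(messages) - 1, -1, -1):
        if is_system_message(messages[index]):
            continue
        cost = message_tokens(messages[index])
        if used + cost > max_prompt_tokens:
            break
        used += cost
        kept.append(index)

    # The real implementation likely raises if nothing fits; this sketch just drops.
    discarded = [
        i
        for i, m in enumerate(messages)
        if not is_system_message(m) and i not in kept
    ]
    truncated = [m for i, m in enumerate(messages) if i not in discarded]
    return truncated, discarded, used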
43 changes: 23 additions & 20 deletions aidial_adapter_openai/gpt4_multi_modal/chat_completion.py
@@ -25,6 +25,9 @@
     ResourceProcessor,
 )
 from aidial_adapter_openai.utils.auth import OpenAICreds, get_auth_headers
+from aidial_adapter_openai.utils.chat_completion_response import (
+    ChatCompletionBlock,
+)
 from aidial_adapter_openai.utils.log_config import logger
 from aidial_adapter_openai.utils.multi_modal_message import MultiModalMessage
 from aidial_adapter_openai.utils.sse_stream import parse_openai_sse_stream
@@ -116,18 +119,18 @@ async def predict_non_stream(


 def multi_modal_truncate_prompt(
+    request: dict,
     messages: List[MultiModalMessage],
     max_prompt_tokens: int,
-    initial_prompt_tokens: int,
     tokenizer: MultiModalTokenizer,
 ) -> Tuple[List[MultiModalMessage], DiscardedMessages, TruncatedTokens]:
     return truncate_prompt(
         messages=messages,
-        message_tokens=tokenizer.calculate_message_tokens,
+        message_tokens=tokenizer.tokenize_request_message,
         is_system_message=lambda message: message.raw_message["role"]
         == "system",
         max_prompt_tokens=max_prompt_tokens,
-        initial_prompt_tokens=initial_prompt_tokens,
+        initial_prompt_tokens=tokenizer.tokenize_request(request, []),
     )


@@ -215,18 +218,18 @@ async def chat_completion(
     if max_prompt_tokens is not None:
         multi_modal_messages, discarded_messages, estimated_prompt_tokens = (
             multi_modal_truncate_prompt(
+                request=request,
                 messages=multi_modal_messages,
                 max_prompt_tokens=max_prompt_tokens,
-                initial_prompt_tokens=tokenizer.TOKENS_PER_REQUEST,
                 tokenizer=tokenizer,
             )
         )
         logger.debug(
             f"prompt tokens after truncation: {estimated_prompt_tokens}"
         )
     else:
-        estimated_prompt_tokens = tokenizer.calculate_prompt_tokens(
-            multi_modal_messages
+        estimated_prompt_tokens = tokenizer.tokenize_request(
+            request, multi_modal_messages
         )
         logger.debug(
             f"prompt tokens without truncation: {estimated_prompt_tokens}"
@@ -255,7 +258,7 @@ def debug_print(chunk: T) -> T:
            debug_print,
            generate_stream(
                get_prompt_tokens=lambda: estimated_prompt_tokens,
-                tokenize=tokenizer.calculate_text_tokens,
+                tokenize_response=tokenizer.tokenize_response,
                deployment=deployment,
                discarded_messages=discarded_messages,
                stream=map_stream(
@@ -277,25 +280,25 @@ def debug_print(chunk: T) -> T:
                type="invalid_response_error",
            )

-        content = response["choices"][0]["message"].get("content") or ""
-        usage = response["usage"]
-
        if discarded_messages:
            response |= {
                "statistics": {"discarded_messages": discarded_messages}
            }

-        actual_prompt_tokens = usage["prompt_tokens"]
-        if actual_prompt_tokens != estimated_prompt_tokens:
-            logger.warning(
-                f"Estimated prompt tokens ({estimated_prompt_tokens}) don't match the actual ones ({actual_prompt_tokens})"
-            )
+        if usage := response.get("usage"):
+            actual_prompt_tokens = usage["prompt_tokens"]
+            if actual_prompt_tokens != estimated_prompt_tokens:
+                logger.warning(
+                    f"Estimated prompt tokens ({estimated_prompt_tokens}) don't match the actual ones ({actual_prompt_tokens})"
+                )

-        actual_completion_tokens = usage["completion_tokens"]
-        estimated_completion_tokens = tokenizer.calculate_text_tokens(content)
-        if actual_completion_tokens != estimated_completion_tokens:
-            logger.warning(
-                f"Estimated completion tokens ({estimated_completion_tokens}) don't match the actual ones ({actual_completion_tokens})"
-            )
+            actual_completion_tokens = usage["completion_tokens"]
+            estimated_completion_tokens = tokenizer.tokenize_response(
+                ChatCompletionBlock(resp=response)
+            )
+            if actual_completion_tokens != estimated_completion_tokens:
+                logger.warning(
+                    f"Estimated completion tokens ({estimated_completion_tokens}) don't match the actual ones ({actual_completion_tokens})"
+                )

        return response
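The completion-token estimate now runs over the whole response message via tokenize_response(ChatCompletionBlock(resp=response)) instead of only the text content, so tool and function calls in the reply are counted as well. A hedged sketch of what such a response-side estimate might look like (assumed helper, not the adapter's implementation):

# Hedged sketch (assumed behavior, not the adapter's tokenizer): estimate
# completion tokens from every choice message, counting both plain text
# content and serialized tool/function calls.
import json

import tiktoken


def sketch_tokenize_response(response: dict, encoding_name: str = "cl100k_base") -> int:
    encoding = tiktoken.get_encoding(encoding_name)
    total = 0
    for choice in response.get("choices") or []:
        message = choice.get("message") or {}
        if content := message.get("content"):
            total += len(encoding.encode(content))
        for tool_call in message.get("tool_calls") or []:
            # Tool calls carry a function name and JSON-encoded arguments.
            total += len(encoding.encode(json.dumps(tool_call.get("function", {}))))
        if function_call := message.get("function_call"):
            total += len(encoding.encode(json.dumps(function_call)))
    return total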
51 changes: 51 additions & 0 deletions aidial_adapter_openai/utils/chat_completion_response.py
@@ -0,0 +1,51 @@
from typing import Any, Iterable, Literal, Self

from aidial_sdk.utils.merge_chunks import merge_chat_completion_chunks
from pydantic import BaseModel


class ChatCompletionResponse(BaseModel):
    message_key: Literal["delta", "message"]
    resp: dict = {}

    @property
    def usage(self) -> Any | None:
        return self.resp.get("usage")

    @property
    def is_empty(self) -> bool:
        return self.resp == {}

    @property
    def finish_reasons(self) -> Iterable[Any]:
        for choice in self.resp.get("choices") or []:
            if (reason := choice.get("finish_reason")) is not None:
                yield reason

    @property
    def has_finish_reason(self) -> bool:
        return len(list(self.finish_reasons)) > 0

    @property
    def messages(self) -> Iterable[Any]:
        for choice in self.resp.get("choices") or []:
            if (message := choice.get(self.message_key)) is not None:
                yield message

    @property
    def has_messages(self) -> bool:
        return len(list(self.messages)) > 0


class ChatCompletionBlock(ChatCompletionResponse):
    def __init__(self, **kwargs):
        super().__init__(message_key="message", **kwargs)


class ChatCompletionStreamingChunk(ChatCompletionResponse):
    def __init__(self, **kwargs):
        super().__init__(message_key="delta", **kwargs)

    def merge(self, chunk: dict) -> Self:
        self.resp = merge_chat_completion_chunks(self.resp, chunk)
        return self
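A hypothetical usage sketch of the helpers above (the chunk payloads are illustrative only; the exact merge semantics come from aidial_sdk's merge_chat_completion_chunks):

# Assumes ChatCompletionStreamingChunk and ChatCompletionBlock from the file above.
accumulated = ChatCompletionStreamingChunk()
for chunk in (
    {"choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hel"}}]},
    {"choices": [{"index": 0, "delta": {"content": "lo"}, "finish_reason": "stop"}]},
):
    accumulated.merge(chunk)

print(accumulated.has_messages)       # expected True once any delta has been merged
print(accumulated.has_finish_reason)  # expected True after the final chunk

# Non-streaming responses are wrapped as a block, e.g. before tokenize_response:
block = ChatCompletionBlock(
    resp={"choices": [{"message": {"role": "assistant", "content": "Hello"}}]}
)
print(list(block.messages))  # [{'role': 'assistant', 'content': 'Hello'}]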
29 changes: 0 additions & 29 deletions aidial_adapter_openai/utils/merge_chunks.py

This file was deleted.
