diff --git a/aidial_adapter_openai/gpt.py b/aidial_adapter_openai/gpt.py
index 5c141fb..5d6610d 100644
--- a/aidial_adapter_openai/gpt.py
+++ b/aidial_adapter_openai/gpt.py
@@ -23,19 +23,22 @@
 def plain_text_truncate_prompt(
-    messages: List[dict], max_prompt_tokens: int, tokenizer: PlainTextTokenizer
+    request: dict,
+    messages: List[dict],
+    max_prompt_tokens: int,
+    tokenizer: PlainTextTokenizer,
 ) -> Tuple[List[dict], DiscardedMessages, TruncatedTokens]:
     return truncate_prompt(
         messages=messages,
-        message_tokens=tokenizer.calculate_message_tokens,
+        message_tokens=tokenizer.tokenize_request_message,
         is_system_message=lambda message: message["role"] == "system",
         max_prompt_tokens=max_prompt_tokens,
-        initial_prompt_tokens=tokenizer.TOKENS_PER_REQUEST,
+        initial_prompt_tokens=tokenizer.tokenize_request(request, []),
     )


 async def gpt_chat_completion(
-    data: dict,
+    request: dict,
     deployment_id: str,
     upstream_endpoint: str,
     creds: OpenAICreds,
@@ -43,9 +46,9 @@ async def gpt_chat_completion(
     tokenizer: PlainTextTokenizer,
 ):
     discarded_messages = None
-    prompt_tokens = None
-    if "max_prompt_tokens" in data:
-        max_prompt_tokens = data["max_prompt_tokens"]
+    estimated_prompt_tokens = None
+    if "max_prompt_tokens" in request:
+        max_prompt_tokens = request["max_prompt_tokens"]
         if not isinstance(max_prompt_tokens, int):
             raise InvalidRequestError(
                 f"'{max_prompt_tokens}' is not of type 'integer' - 'max_prompt_tokens'",
@@ -54,11 +57,12 @@
             raise InvalidRequestError(
                 f"'{max_prompt_tokens}' is less than the minimum of 1 - 'max_prompt_tokens'",
             )
-        del data["max_prompt_tokens"]
+        del request["max_prompt_tokens"]

-        data["messages"], discarded_messages, prompt_tokens = (
+        request["messages"], discarded_messages, estimated_prompt_tokens = (
             plain_text_truncate_prompt(
-                messages=cast(List[dict], data["messages"]),
+                request=request,
+                messages=cast(List[dict], request["messages"]),
                 max_prompt_tokens=max_prompt_tokens,
                 tokenizer=tokenizer,
             )
@@ -68,14 +72,14 @@ async def gpt_chat_completion(
         {**creds, "api_version": api_version}
     )
     response: AsyncStream[ChatCompletionChunk] | ChatCompletion = (
-        await call_with_extra_body(client.chat.completions.create, data)
+        await call_with_extra_body(client.chat.completions.create, request)
     )

     if isinstance(response, AsyncIterator):
         return generate_stream(
-            get_prompt_tokens=lambda: prompt_tokens
-            or tokenizer.calculate_prompt_tokens(data["messages"]),
-            tokenize=tokenizer.calculate_text_tokens,
+            get_prompt_tokens=lambda: estimated_prompt_tokens
+            or tokenizer.tokenize_request(request, request["messages"]),
+            tokenize_response=tokenizer.tokenize_response,
             deployment=deployment_id,
             discarded_messages=discarded_messages,
             stream=map_stream(chunk_to_dict, response),
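With this change the whole request body, not just the message list, reaches the tokenizer, so declared tools and functions are counted in the initial prompt-token estimate. A minimal sketch of the updated entry point; the messages, the tool and the token budget below are invented for illustration:

```python
# Illustrative only: messages, tools and the budget are made up; the call shape
# follows the new plain_text_truncate_prompt signature above.
from aidial_adapter_openai.gpt import plain_text_truncate_prompt
from aidial_adapter_openai.utils.tokenizer import PlainTextTokenizer

tokenizer = PlainTextTokenizer(model="gpt-4")

request = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarize the release notes."},
    ],
    # Declared tools now contribute to initial_prompt_tokens via tokenize_request.
    "tools": [
        {
            "type": "function",
            "function": {"name": "search", "parameters": {"type": "object"}},
        }
    ],
}

messages, discarded, used_tokens = plain_text_truncate_prompt(
    request=request,
    messages=request["messages"],
    max_prompt_tokens=256,
    tokenizer=tokenizer,
)
```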
diff --git a/aidial_adapter_openai/gpt4_multi_modal/chat_completion.py b/aidial_adapter_openai/gpt4_multi_modal/chat_completion.py
index b6be728..26d82c7 100644
--- a/aidial_adapter_openai/gpt4_multi_modal/chat_completion.py
+++ b/aidial_adapter_openai/gpt4_multi_modal/chat_completion.py
@@ -25,6 +25,9 @@
     ResourceProcessor,
 )
 from aidial_adapter_openai.utils.auth import OpenAICreds, get_auth_headers
+from aidial_adapter_openai.utils.chat_completion_response import (
+    ChatCompletionBlock,
+)
 from aidial_adapter_openai.utils.log_config import logger
 from aidial_adapter_openai.utils.multi_modal_message import MultiModalMessage
 from aidial_adapter_openai.utils.sse_stream import parse_openai_sse_stream
@@ -116,18 +119,18 @@ async def predict_non_stream(
 def multi_modal_truncate_prompt(
+    request: dict,
     messages: List[MultiModalMessage],
     max_prompt_tokens: int,
-    initial_prompt_tokens: int,
     tokenizer: MultiModalTokenizer,
 ) -> Tuple[List[MultiModalMessage], DiscardedMessages, TruncatedTokens]:
     return truncate_prompt(
         messages=messages,
-        message_tokens=tokenizer.calculate_message_tokens,
+        message_tokens=tokenizer.tokenize_request_message,
         is_system_message=lambda message: message.raw_message["role"] == "system",
         max_prompt_tokens=max_prompt_tokens,
-        initial_prompt_tokens=initial_prompt_tokens,
+        initial_prompt_tokens=tokenizer.tokenize_request(request, []),
     )


@@ -215,9 +218,9 @@ async def chat_completion(
     if max_prompt_tokens is not None:
         multi_modal_messages, discarded_messages, estimated_prompt_tokens = (
             multi_modal_truncate_prompt(
+                request=request,
                 messages=multi_modal_messages,
                 max_prompt_tokens=max_prompt_tokens,
-                initial_prompt_tokens=tokenizer.TOKENS_PER_REQUEST,
                 tokenizer=tokenizer,
             )
         )
@@ -225,8 +228,8 @@
            f"prompt tokens after truncation: {estimated_prompt_tokens}"
         )
     else:
-        estimated_prompt_tokens = tokenizer.calculate_prompt_tokens(
-            multi_modal_messages
+        estimated_prompt_tokens = tokenizer.tokenize_request(
+            request, multi_modal_messages
         )
         logger.debug(
             f"prompt tokens without truncation: {estimated_prompt_tokens}"
         )
@@ -255,7 +258,7 @@ def debug_print(chunk: T) -> T:
             debug_print,
             generate_stream(
                 get_prompt_tokens=lambda: estimated_prompt_tokens,
-                tokenize=tokenizer.calculate_text_tokens,
+                tokenize_response=tokenizer.tokenize_response,
                 deployment=deployment,
                 discarded_messages=discarded_messages,
                 stream=map_stream(
@@ -277,25 +280,25 @@ def debug_print(chunk: T) -> T:
             type="invalid_response_error",
         )

-    content = response["choices"][0]["message"].get("content") or ""
-    usage = response["usage"]
-
     if discarded_messages:
         response |= {
             "statistics": {"discarded_messages": discarded_messages}
         }

-    actual_prompt_tokens = usage["prompt_tokens"]
-    if actual_prompt_tokens != estimated_prompt_tokens:
-        logger.warning(
-            f"Estimated prompt tokens ({estimated_prompt_tokens}) don't match the actual ones ({actual_prompt_tokens})"
-        )
+    if usage := response.get("usage"):
+        actual_prompt_tokens = usage["prompt_tokens"]
+        if actual_prompt_tokens != estimated_prompt_tokens:
+            logger.warning(
+                f"Estimated prompt tokens ({estimated_prompt_tokens}) don't match the actual ones ({actual_prompt_tokens})"
+            )

-    actual_completion_tokens = usage["completion_tokens"]
-    estimated_completion_tokens = tokenizer.calculate_text_tokens(content)
-    if actual_completion_tokens != estimated_completion_tokens:
-        logger.warning(
-            f"Estimated completion tokens ({estimated_completion_tokens}) don't match the actual ones ({actual_completion_tokens})"
+        actual_completion_tokens = usage["completion_tokens"]
+        estimated_completion_tokens = tokenizer.tokenize_response(
+            ChatCompletionBlock(resp=response)
         )
+        if actual_completion_tokens != estimated_completion_tokens:
+            logger.warning(
+                f"Estimated completion tokens ({estimated_completion_tokens}) don't match the actual ones ({actual_completion_tokens})"
+            )

     return response
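In the non-streaming branch above, the completion-token estimate now comes from wrapping the raw response dict in ChatCompletionBlock and handing it to the tokenizer, and the whole comparison is skipped when the upstream returns no usage. An illustrative sketch with a fabricated response (payload and token numbers are made up):

```python
# Illustrative only: the response payload and its usage numbers are fabricated.
from aidial_adapter_openai.utils.chat_completion_response import ChatCompletionBlock
from aidial_adapter_openai.utils.tokenizer import PlainTextTokenizer

tokenizer = PlainTextTokenizer(model="gpt-4")

response = {
    "choices": [
        {
            "index": 0,
            "finish_reason": "stop",
            "message": {"role": "assistant", "content": "Hello there!"},
        }
    ],
    "usage": {"prompt_tokens": 9, "completion_tokens": 3, "total_tokens": 12},
}

estimated_completion_tokens = tokenizer.tokenize_response(
    ChatCompletionBlock(resp=response)
)
# The adapter only logs a warning when this estimate and usage["completion_tokens"]
# disagree; a missing "usage" key skips the comparison entirely.
```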
diff --git a/aidial_adapter_openai/utils/chat_completion_response.py b/aidial_adapter_openai/utils/chat_completion_response.py
new file mode 100644
index 0000000..e3d7336
--- /dev/null
+++ b/aidial_adapter_openai/utils/chat_completion_response.py
@@ -0,0 +1,51 @@
+from typing import Any, Iterable, Literal, Self
+
+from aidial_sdk.utils.merge_chunks import merge_chat_completion_chunks
+from pydantic import BaseModel
+
+
+class ChatCompletionResponse(BaseModel):
+    message_key: Literal["delta", "message"]
+    resp: dict = {}
+
+    @property
+    def usage(self) -> Any | None:
+        return self.resp.get("usage")
+
+    @property
+    def is_empty(self) -> bool:
+        return self.resp == {}
+
+    @property
+    def finish_reasons(self) -> Iterable[Any]:
+        for choice in self.resp.get("choices") or []:
+            if (reason := choice.get("finish_reason")) is not None:
+                yield reason
+
+    @property
+    def has_finish_reason(self) -> bool:
+        return len(list(self.finish_reasons)) > 0
+
+    @property
+    def messages(self) -> Iterable[Any]:
+        for choice in self.resp.get("choices") or []:
+            if (message := choice.get(self.message_key)) is not None:
+                yield message
+
+    @property
+    def has_messages(self) -> bool:
+        return len(list(self.messages)) > 0
+
+
+class ChatCompletionBlock(ChatCompletionResponse):
+    def __init__(self, **kwargs):
+        super().__init__(message_key="message", **kwargs)
+
+
+class ChatCompletionStreamingChunk(ChatCompletionResponse):
+    def __init__(self, **kwargs):
+        super().__init__(message_key="delta", **kwargs)
+
+    def merge(self, chunk: dict) -> Self:
+        self.resp = merge_chat_completion_chunks(self.resp, chunk)
+        return self
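A usage sketch for the streaming wrapper introduced above: chunks are folded into one snapshot with the SDK's merge helper and then queried through the convenience properties. The chunks below are fabricated and trimmed to the fields that matter here:

```python
# Illustrative only: minimal fabricated chunks, not real upstream payloads.
from aidial_adapter_openai.utils.chat_completion_response import (
    ChatCompletionStreamingChunk,
)

snapshot = ChatCompletionStreamingChunk()
snapshot.merge(
    {"choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hel"}}]}
)
snapshot.merge(
    {"choices": [{"index": 0, "delta": {"content": "lo"}, "finish_reason": "stop"}]}
)

assert not snapshot.is_empty
assert snapshot.has_messages
assert snapshot.has_finish_reason
assert snapshot.usage is None  # no usage chunk was received
```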
diff --git a/aidial_adapter_openai/utils/merge_chunks.py b/aidial_adapter_openai/utils/merge_chunks.py
deleted file mode 100644
index fc5119c..0000000
--- a/aidial_adapter_openai/utils/merge_chunks.py
+++ /dev/null
@@ -1,29 +0,0 @@
-from typing import TypeVar
-
-from aidial_sdk.utils.merge_chunks import merge
-
-_Chunk = TypeVar("_Chunk", bound=dict)
-
-
-def merge_chunks(*chunks: _Chunk) -> _Chunk:
-    """
-    The recursive merging procedure that avoids merging top-level atomic fields
-    (e.g. "id", "created", "model", "object", "system_fingerprint") and
-    instead chooses an _override_ merging strategy for such fields.
-    Non-atomic fields (e.g. "choice", "usage") are merged following
-    the standard recursive merging procedure.
-    """
-
-    assert len(chunks) > 0, "At least one chunk must be provided"
-
-    target = chunks[0]
-
-    for chunk in chunks[1:]:
-        source = chunk.copy()
-        for key, value in list(source.items()):
-            if not isinstance(value, (list, dict)) and value is not None:
-                target[key] = value
-                del source[key]
-        target = merge(target, source)
-
-    return target
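The local helper above is deleted because aidial-sdk now ships merge_chat_completion_chunks, which is assumed here to follow the same policy the removed docstring describes: top-level atomic fields are overridden, everything else is merged recursively. A sketch under that assumption, with fabricated chunks:

```python
# Sketch under the assumption that the SDK helper keeps the deleted helper's
# semantics; the chunk payloads are fabricated.
from aidial_sdk.utils.merge_chunks import merge_chat_completion_chunks

first = {
    "id": "chatcmpl-1",
    "created": 1700000000,
    "choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hel"}}],
}
second = {
    "id": "chatcmpl-1",
    "created": 1700000001,
    "choices": [{"index": 0, "delta": {"content": "lo"}}],
}

merged = merge_chat_completion_chunks(first, second)
# Atomic top-level fields ("id", "created") take the latest value, while the
# "choices" deltas are merged recursively into a single assistant message.
```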
diff --git a/aidial_adapter_openai/utils/streaming.py b/aidial_adapter_openai/utils/streaming.py
index 038df0b..d677ef7 100644
--- a/aidial_adapter_openai/utils/streaming.py
+++ b/aidial_adapter_openai/utils/streaming.py
@@ -1,17 +1,21 @@
 import logging
 from time import time
-from typing import Any, AsyncIterator, Callable, Iterable, Optional, TypeVar
+from typing import Any, AsyncIterator, Callable, Optional, TypeVar
 from uuid import uuid4

 from aidial_sdk.exceptions import HTTPException as DialException
+from aidial_sdk.utils.merge_chunks import merge_chat_completion_chunks
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from openai import APIError, APIStatusError
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
 from pydantic import BaseModel

 from aidial_adapter_openai.env import get_eliminate_empty_choices
+from aidial_adapter_openai.utils.chat_completion_response import (
+    ChatCompletionResponse,
+    ChatCompletionStreamingChunk,
+)
 from aidial_adapter_openai.utils.log_config import logger
-from aidial_adapter_openai.utils.merge_chunks import merge_chunks
 from aidial_adapter_openai.utils.sse_stream import to_openai_sse_stream

 ELIMINATE_EMPTY_CHOICES = get_eliminate_empty_choices()
@@ -54,13 +58,13 @@ def build_chunk(

 async def generate_stream(
     *,
     get_prompt_tokens: Callable[[], int],
-    tokenize: Callable[[str], int],
+    tokenize_response: Callable[[ChatCompletionResponse], int],
     deployment: str,
     discarded_messages: Optional[list[int]],
     stream: AsyncIterator[dict],
 ) -> AsyncIterator[dict]:
-    noop_chunk = build_chunk(
+    empty_chunk = build_chunk(
         id=generate_id(),
         created=generate_created(),
         model=deployment,
@@ -69,10 +73,11 @@ async def generate_stream(
         finish_reason=None,
     )

-    def set_usage(chunk: dict | None, completions: Iterable[str]) -> dict:
-        chunk = chunk or noop_chunk
-        completion_tokens = sum(map(tokenize, completions))
+    def set_usage(chunk: dict | None, resp: ChatCompletionResponse) -> dict:
+        completion_tokens = tokenize_response(resp)
         prompt_tokens = get_prompt_tokens()
+
+        chunk = chunk or empty_chunk
         chunk["usage"] = {
             "completion_tokens": completion_tokens,
             "prompt_tokens": prompt_tokens,
@@ -81,44 +86,32 @@ def set_usage(chunk: dict | None, completions: Iterable[str]) -> dict:
         return chunk

     def set_finish_reason(chunk: dict | None, finish_reason: str) -> dict:
-        chunk = chunk or noop_chunk
+        chunk = chunk or empty_chunk
         chunk["choices"] = chunk.get("choices") or [{"index": 0, "delta": {}}]
         chunk["choices"][0]["finish_reason"] = finish_reason
         return chunk

     def set_discarded_messages(chunk: dict | None, indices: list[int]) -> dict:
-        chunk = chunk or noop_chunk
+        chunk = chunk or empty_chunk
         chunk["statistics"] = {"discarded_messages": indices}
         return chunk

-    n_chunks = 0
     last_chunk = None
     buffer_chunk = None
+    response_snapshot = ChatCompletionStreamingChunk()

-    completions: dict[int, str] = {}
-    found_finish_reason = False
-    found_usage = False
     error = None

     try:
         async for chunk in stream:
-            n_chunks += 1
+            response_snapshot.merge(chunk)

             if buffer_chunk is not None:
-                chunk = merge_chunks(buffer_chunk, chunk)
+                chunk = merge_chat_completion_chunks(chunk, buffer_chunk)
                 buffer_chunk = None

             choices = chunk.get("choices") or []

-            for choice in choices:
-                index = choice["index"]
-                content = (choice.get("delta") or {}).get("content") or ""
-
-                completions[index] = completions.get(index, "") + content
-                found_finish_reason |= bool(choice.get("finish_reason"))
-
-            found_usage |= bool(chunk.get("usage"))
-
             # Azure OpenAI returns an empty list of choices as a first chunk
             # when content filtering is enabled for a corresponding deployment.
             # The safety rating of the request is reported in this first chunk.
@@ -141,25 +134,29 @@ def set_discarded_messages(chunk: dict | None, indices: list[int]) -> dict:
         ).json_error()

     if last_chunk is not None and buffer_chunk is not None:
-        last_chunk = merge_chunks(buffer_chunk, last_chunk)
+        last_chunk = merge_chat_completion_chunks(last_chunk, buffer_chunk)

     if discarded_messages is not None:
         last_chunk = set_discarded_messages(last_chunk, discarded_messages)

-    if not found_usage and (not error or completions):
-        last_chunk = set_usage(last_chunk, completions.values())
+    if response_snapshot.usage is None and (
+        not error or response_snapshot.has_messages
+    ):
+        last_chunk = set_usage(last_chunk, response_snapshot)

     if not error:
-        if n_chunks == 0:
+        has_finish_reason = response_snapshot.has_finish_reason
+
+        if response_snapshot.is_empty:
             logger.warning("Received 0 chunks")
-        elif not found_finish_reason:
+        elif not has_finish_reason:
             logger.warning("Didn't receive chunk with the finish reason")

-        if not found_finish_reason:
+        if not has_finish_reason:
             last_chunk = set_finish_reason(last_chunk, "length")

-        if not found_usage:
-            last_chunk = set_usage(last_chunk, completions.values())
+        if response_snapshot.usage is None:
+            last_chunk = set_usage(last_chunk, response_snapshot)

     if last_chunk:
         yield last_chunk
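Taken together, generate_stream now derives completion tokens and the fallback finish_reason/usage from the merged snapshot instead of hand-rolled counters. A hedged driver sketch; the upstream chunks, the prompt-token count and the deployment name are all made up:

```python
# Illustrative driver only: fake upstream stream, fabricated token counts.
import asyncio
from typing import AsyncIterator

from aidial_adapter_openai.utils.streaming import generate_stream
from aidial_adapter_openai.utils.tokenizer import PlainTextTokenizer

tokenizer = PlainTextTokenizer(model="gpt-4")


async def fake_upstream() -> AsyncIterator[dict]:
    # No usage and no finish_reason, so generate_stream synthesizes both
    # on the final chunk it yields.
    yield {"choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hi"}}]}


async def main() -> None:
    async for chunk in generate_stream(
        get_prompt_tokens=lambda: 9,
        tokenize_response=tokenizer.tokenize_response,
        deployment="gpt-4",
        discarded_messages=None,
        stream=fake_upstream(),
    ):
        print(chunk)


asyncio.run(main())
```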
diff --git a/aidial_adapter_openai/utils/tokenizer.py b/aidial_adapter_openai/utils/tokenizer.py
index 539fcd7..3af0e94 100644
--- a/aidial_adapter_openai/utils/tokenizer.py
+++ b/aidial_adapter_openai/utils/tokenizer.py
@@ -2,12 +2,16 @@
 Implemented based on the official recipe:
 https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken
 """
+import json
 from abc import abstractmethod
 from typing import Any, Callable, Generic, List, TypeVar

 from aidial_sdk.exceptions import InternalServerError
 from tiktoken import Encoding, encoding_for_model

+from aidial_adapter_openai.utils.chat_completion_response import (
+    ChatCompletionResponse,
+)
 from aidial_adapter_openai.utils.image_tokenizer import ImageTokenizer
 from aidial_adapter_openai.utils.multi_modal_message import MultiModalMessage

@@ -15,6 +19,10 @@


 class BaseTokenizer(Generic[MessageType]):
+    """
+    Tokenizer for chat completion requests and responses.
+    """
+
     model: str
     encoding: Encoding
     TOKENS_PER_REQUEST = 3
@@ -30,11 +38,39 @@ def __init__(self, model: str) -> None:
                 "or declare it as a model which doesn't require tokenization through tiktoken.",
             ) from e

-    def calculate_text_tokens(self, text: str) -> int:
+    def tokenize_text(self, text: str) -> int:
         return len(self.encoding.encode(text))

+    def tokenize_response(self, resp: ChatCompletionResponse) -> int:
+        return sum(map(self._tokenize_response_message, resp.messages))
+
+    def _tokenize_object(self, obj: Any) -> int:
+        if not obj:
+            return 0
+
+        # OpenAI doesn't reveal tokenization algorithm for tools calls and function calls.
+        # An approximation is used instead - token count in the string repr of the objects.
+        text = (
+            obj
+            if isinstance(obj, str)
+            else json.dumps(obj, separators=(",", ":"))
+        )
+        return self.tokenize_text(text)
+
+    def _tokenize_response_message(self, message: Any) -> int:
+
+        tokens = 0
+
+        for key in ["content", "refusal", "function"]:
+            tokens += self._tokenize_object(message.get(key))
+
+        for tool_call in message.get("tool_calls") or []:
+            tokens += self._tokenize_object(tool_call.get("function"))
+
+        return tokens
+
     @property
-    def tokens_per_message(self) -> int:
+    def _tokens_per_request_message(self) -> int:
         """
         Tokens, that are counter for each message, regardless of its content
         """
@@ -43,7 +79,7 @@ def tokens_per_message(self) -> int:
         return 3

     @property
-    def tokens_per_name(self) -> int:
+    def _tokens_per_request_message_name(self) -> int:
         """
         Tokens, that are counter for "name" field in message, if it's present
         """
@@ -51,23 +87,25 @@ def tokens_per_name(self) -> int:
             return -1
         return 1

-    def calculate_request_prompt_tokens(self, messages_tokens: int):
-        """
-        Amount of tokens, that will be counted by API
-        is greater than actual sum of tokens of all messages
-        """
-        return self.TOKENS_PER_REQUEST + messages_tokens
+    def tokenize_request(
+        self, original_request: dict, messages: List[MessageType]
+    ) -> int:
+        tokens = self.TOKENS_PER_REQUEST

-    def calculate_prompt_tokens(self, messages: List[MessageType]) -> int:
-        return self.calculate_request_prompt_tokens(
-            messages_tokens=sum(map(self.calculate_message_tokens, messages))
-        )
+        if original_request.get("function_call") != "none":
+            for func in original_request.get("function") or []:
+                tokens += self._tokenize_object(func)
+
+        if original_request.get("tool_choice") != "none":
+            for tool in original_request.get("tools") or []:
+                tokens += self._tokenize_object(tool.get("function"))

-    def available_message_tokens(self, max_prompt_tokens: int):
-        return max_prompt_tokens - self.TOKENS_PER_REQUEST
+        tokens += sum(map(self.tokenize_request_message, messages))
+
+        return tokens

     @abstractmethod
-    def calculate_message_tokens(self, message: MessageType) -> int:
+    def tokenize_request_message(self, message: MessageType) -> int:
         pass

@@ -122,11 +160,11 @@ def _handle_custom_content_part(self, content_part: Any):
             f"Use MultiModalTokenizer for messages with images"
         )

-    def calculate_message_tokens(self, message: dict) -> int:
-        return self.tokens_per_message + _process_raw_message(
+    def tokenize_request_message(self, message: dict) -> int:
+        return self._tokens_per_request_message + _process_raw_message(
             raw_message=message,
-            tokens_per_name=self.tokens_per_name,
-            calculate_text_tokens=self.calculate_text_tokens,
+            tokens_per_name=self._tokens_per_request_message_name,
+            calculate_text_tokens=self.tokenize_text,
             handle_custom_content_part=self._handle_custom_content_part,
         )

@@ -138,14 +176,14 @@ def __init__(self, model: str, image_tokenizer: ImageTokenizer):
         super().__init__(model)
         self.image_tokenizer = image_tokenizer

-    def calculate_message_tokens(self, message: MultiModalMessage) -> int:
-        tokens = self.tokens_per_message
+    def tokenize_request_message(self, message: MultiModalMessage) -> int:
+        tokens = self._tokens_per_request_message
         raw_message = message.raw_message
         tokens += _process_raw_message(
             raw_message=raw_message,
-            tokens_per_name=self.tokens_per_name,
-            calculate_text_tokens=self.calculate_text_tokens,
+            tokens_per_name=self._tokens_per_request_message_name,
+            calculate_text_tokens=self.tokenize_text,
             handle_custom_content_part=lambda content_part: None,
         )
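tokenize_request now layers three parts: the per-request constant, an approximate cost for declared functions/tools (skipped when function_call or tool_choice is "none"), and the per-message counts. A short example; the request content is invented and the result is an estimate by design:

```python
# Illustrative only: the request and messages are fabricated; counts are estimates.
from aidial_adapter_openai.utils.tokenizer import PlainTextTokenizer

tokenizer = PlainTextTokenizer(model="gpt-4")

request = {
    "tool_choice": "auto",
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Return the current weather for a city.",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                },
            },
        }
    ],
}
messages = [{"role": "user", "content": "What's the weather in Lisbon?"}]

# TOKENS_PER_REQUEST + approximate tool cost + per-message cost.
prompt_estimate = tokenizer.tokenize_request(request, messages)
```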
diff --git a/poetry.lock b/poetry.lock
index 069dc3c..2094e73 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2,13 +2,13 @@

 [[package]]
 name = "aidial-sdk"
-version = "0.13.0"
+version = "0.15.0"
 description = "Framework to create applications and model adapters for AI DIAL"
 optional = false
 python-versions = "<4.0,>=3.8.1"
 files = [
-    {file = "aidial_sdk-0.13.0-py3-none-any.whl", hash = "sha256:35784f12367e43f4540d67bab7b18315832e313517e02e969068d7ff2de3d69e"},
-    {file = "aidial_sdk-0.13.0.tar.gz", hash = "sha256:c895c22d95d1c1954e170ebda3f5010e80cd47ed8b7225d375d1da01f67962e5"},
+    {file = "aidial_sdk-0.15.0-py3-none-any.whl", hash = "sha256:7b9b3e5ec9688be2919dcd7dd0312aac807dc7917393ee5f846332713ad2e26a"},
+    {file = "aidial_sdk-0.15.0.tar.gz", hash = "sha256:6b47bb36e8c795300e0d4b61308c6a2f86b59abb97905390a02789b343460720"},
 ]

 [package.dependencies]
@@ -2568,4 +2568,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.11,<3.13"
-content-hash = "fe943703ce6ad346420d88035573a9334758f2259a1301ec4867a9bd9b16390f"
+content-hash = "307c14e21b2fc8b1598cd8b903489c48c6d6367a44bfb6e764d02a45bc6dd9fb"
diff --git a/pyproject.toml b/pyproject.toml
index 3b5df36..ba3e9d0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,7 +31,7 @@ aiohttp = "^3.10.11"
 numpy = "^1.26.0"
 pillow = "^10.3.0"
 azure-identity = "^1.16.1"
-aidial-sdk = {version = "^0.13.0", extras = ["telemetry"]}
+aidial-sdk = {version = "^0.15.0", extras = ["telemetry"]}

 [tool.poetry.group.test.dependencies]
 pytest = "7.4.0"
diff --git a/tests/test_discard_messages.py b/tests/test_discard_messages.py
index aa98e71..742bb57 100644
--- a/tests/test_discard_messages.py
+++ b/tests/test_discard_messages.py
@@ -141,7 +141,7 @@ def test_discarded_messages_without_error(
 ):
     tokenizer = PlainTextTokenizer(model="gpt-4")
     truncated_messages, discarded_messages, _used_tokens = (
-        plain_text_truncate_prompt(messages, max_prompt_tokens, tokenizer)
+        plain_text_truncate_prompt({}, messages, max_prompt_tokens, tokenizer)
     )
     assert (truncated_messages, discarded_messages) == response
@@ -157,5 +157,5 @@ def test_discarded_messages_with_error(
     tokenizer = PlainTextTokenizer(model="gpt-4")

     with pytest.raises(DialException) as e_info:
-        plain_text_truncate_prompt(messages, max_prompt_tokens, tokenizer)
+        plain_text_truncate_prompt({}, messages, max_prompt_tokens, tokenizer)

     assert e_info.value.message == error_message
diff --git a/tests/test_multimodal_truncate.py b/tests/test_multimodal_truncate.py
index 031ae3c..3db888c 100644
--- a/tests/test_multimodal_truncate.py
+++ b/tests/test_multimodal_truncate.py
@@ -45,7 +45,7 @@ def test_multimodal_truncate_with_system_and_last_user_error():
         ),
     ]
     with pytest.raises(TruncatePromptSystemAndLastUserError):
-        multi_modal_truncate_prompt(transformations, 15, 0, tokenizer)
+        multi_modal_truncate_prompt({}, transformations, 15, tokenizer)


 def test_multimodal_truncate_with_system_error():
@@ -57,7 +57,7 @@ def test_multimodal_truncate_with_system_error():
         ),
     ]
     with pytest.raises(TruncatePromptSystemError):
-        multi_modal_truncate_prompt(transformations, 9, 3, tokenizer)
+        multi_modal_truncate_prompt({}, transformations, 9, tokenizer)


 @pytest.mark.parametrize(
@@ -194,9 +194,9 @@ def test_multimodal_truncate(
 ):
     truncated, actual_discarded_messages, actual_used_tokens = (
         multi_modal_truncate_prompt(
+            {},
             transformations,
             max_prompt_tokens,
-            initial_prompt_tokens=3,
             tokenizer=tokenizer,
         )
     )