From c9a3fe5c2deb05483ff6279dcd2febc994f713f6 Mon Sep 17 00:00:00 2001 From: Anton Dubovik Date: Thu, 8 Aug 2024 13:12:19 +0100 Subject: [PATCH] feat: avoid computing token usage if upstream model has reported it already (#138) --- aidial_adapter_openai/completions.py | 2 +- aidial_adapter_openai/gpt.py | 13 +- .../gpt4_multi_modal/chat_completion.py | 8 +- aidial_adapter_openai/utils/streaming.py | 104 ++++++++-------- tests/test_errors.py | 112 ++++++++++++------ tests/test_streaming.py | 88 +++++++------- tests/utils/stream.py | 46 +++++++ 7 files changed, 232 insertions(+), 141 deletions(-) diff --git a/aidial_adapter_openai/completions.py b/aidial_adapter_openai/completions.py index b6b7704..6fa182b 100644 --- a/aidial_adapter_openai/completions.py +++ b/aidial_adapter_openai/completions.py @@ -30,7 +30,7 @@ def convert_to_chat_completions_response( converted_chunk = build_chunk( id=chunk.id, finish_reason=chunk.choices[0].finish_reason, - delta={ + message={ "content": sanitize_text(chunk.choices[0].text), "role": "assistant", }, diff --git a/aidial_adapter_openai/gpt.py b/aidial_adapter_openai/gpt.py index 5798653..43072f6 100644 --- a/aidial_adapter_openai/gpt.py +++ b/aidial_adapter_openai/gpt.py @@ -56,15 +56,16 @@ async def gpt_chat_completion( if isinstance(response, AsyncStream): - prompt_tokens = tokenizer.calculate_prompt_tokens(data["messages"]) return StreamingResponse( to_openai_sse_stream( generate_stream( - prompt_tokens, - map_stream(chunk_to_dict, response), - tokenizer, - deployment_id, - discarded_messages, + get_prompt_tokens=lambda: tokenizer.calculate_prompt_tokens( + data["messages"] + ), + tokenize=tokenizer.calculate_tokens, + deployment=deployment_id, + discarded_messages=discarded_messages, + stream=map_stream(chunk_to_dict, response), ), ), media_type="text/event-stream", diff --git a/aidial_adapter_openai/gpt4_multi_modal/chat_completion.py b/aidial_adapter_openai/gpt4_multi_modal/chat_completion.py index 554236a..081be7d 100644 --- a/aidial_adapter_openai/gpt4_multi_modal/chat_completion.py +++ b/aidial_adapter_openai/gpt4_multi_modal/chat_completion.py @@ -232,14 +232,14 @@ def debug_print(chunk: T) -> T: map_stream( debug_print, generate_stream( + get_prompt_tokens=lambda: estimated_prompt_tokens, + tokenize=tokenizer.calculate_tokens, + deployment=deployment, + discarded_messages=None, stream=map_stream( response_transformer, parse_openai_sse_stream(response), ), - prompt_tokens=estimated_prompt_tokens, - tokenizer=tokenizer, - deployment=deployment, - discarded_messages=None, ), ) ), diff --git a/aidial_adapter_openai/utils/streaming.py b/aidial_adapter_openai/utils/streaming.py index 2bbdffe..f7d4d30 100644 --- a/aidial_adapter_openai/utils/streaming.py +++ b/aidial_adapter_openai/utils/streaming.py @@ -14,7 +14,6 @@ from aidial_adapter_openai.utils.errors import dial_exception_to_json_error from aidial_adapter_openai.utils.log_config import logger from aidial_adapter_openai.utils.sse_stream import to_openai_sse_stream -from aidial_adapter_openai.utils.tokens import Tokenizer fix_streaming_issues_in_new_api_versions = get_env_bool( "FIX_STREAMING_ISSUES_IN_NEW_API_VERSIONS", False @@ -28,21 +27,22 @@ def generate_id(): def build_chunk( id: str, finish_reason: Optional[str], - delta: Any, + message: Any, created: str, - is_stream, + is_stream: bool, **extra, ): - choice_content_key = "delta" if is_stream else "message" + message_key = "delta" if is_stream else "message" + object_name = "chat.completion.chunk" if is_stream else "chat.completion" return 
{ "id": id, - "object": "chat.completion.chunk" if is_stream else "chat.completion", + "object": object_name, "created": created, "choices": [ { "index": 0, - choice_content_key: delta, + message_key: message, "finish_reason": finish_reason, } ], @@ -51,17 +51,35 @@ def build_chunk( async def generate_stream( - prompt_tokens: int, - stream: AsyncIterator[dict], - tokenizer: Tokenizer, + *, + get_prompt_tokens: Callable[[], int], + tokenize: Callable[[str], int], deployment: str, discarded_messages: Optional[list[int]], + stream: AsyncIterator[dict], ) -> AsyncIterator[dict]: + last_chunk, temp_chunk = None, None stream_finished = False + total_content = "" + + def finalize_finish_chunk(chunk: dict) -> None: + """ + Adding additional information to a chunk that has non-null finish_reason field. + """ + + if not chunk.get("usage"): + completion_tokens = tokenize(total_content) + prompt_tokens = get_prompt_tokens() + chunk["usage"] = { + "completion_tokens": completion_tokens, + "prompt_tokens": prompt_tokens, + "total_tokens": prompt_tokens + completion_tokens, + } + if discarded_messages is not None: + chunk["statistics"] = {"discarded_messages": discarded_messages} try: - total_content = "" async for chunk in stream: if len(chunk["choices"]) > 0: if temp_chunk is not None: @@ -69,24 +87,13 @@ async def generate_stream( temp_chunk = None choice = chunk["choices"][0] - - content = (choice.get("delta") or {}).get("content") or "" - total_content += content + total_content += (choice.get("delta") or {}).get( + "content" + ) or "" if choice["finish_reason"] is not None: stream_finished = True - completion_tokens = tokenizer.calculate_tokens( - total_content - ) - chunk["usage"] = { - "completion_tokens": completion_tokens, - "prompt_tokens": prompt_tokens, - "total_tokens": prompt_tokens + completion_tokens, - } - if discarded_messages is not None: - chunk["statistics"] = { - "discarded_messages": discarded_messages - } + finalize_finish_chunk(chunk) yield chunk else: @@ -106,39 +113,28 @@ async def generate_stream( return if not stream_finished: - if last_chunk is not None: + if last_chunk is None: + logger.warning("Received 0 chunks") + else: logger.warning("Didn't receive chunk with the finish reason") - completion_tokens = tokenizer.calculate_tokens(total_content) - last_chunk["usage"] = { - "completion_tokens": completion_tokens, - "prompt_tokens": prompt_tokens, - "total_tokens": prompt_tokens + completion_tokens, - } - last_chunk["choices"][0]["delta"]["content"] = "" - last_chunk["choices"][0]["finish_reason"] = "length" + last_chunk = last_chunk or {} + id = last_chunk.get("id") or generate_id() + created = last_chunk.get("created") or str(int(time())) + model = last_chunk.get("model") or deployment + + finish_chunk = build_chunk( + id=id, + created=created, + model=model, + is_stream=True, + message={}, + finish_reason="length", + ) - yield last_chunk - else: - logger.warning("Received 0 chunks") + finalize_finish_chunk(finish_chunk) - id = generate_id() - created = str(int(time())) - is_stream = True - - yield build_chunk( - id, - "length", - {}, - created, - is_stream, - model=deployment, - usage={ - "completion_tokens": 0, - "prompt_tokens": prompt_tokens, - "total_tokens": prompt_tokens, - }, - ) + yield finish_chunk def create_stage_chunk(name: str, content: str, stream: bool) -> dict: diff --git a/tests/test_errors.py b/tests/test_errors.py index 7840b02..600229a 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -6,7 +6,7 @@ import respx from respx.types import 
SideEffectTypes -from tests.utils.stream import OpenAIStream +from tests.utils.stream import OpenAIStream, single_choice_chunk def assert_equal(actual, expected): @@ -37,18 +37,9 @@ async def test_single_chunk_token_counting(test_app: httpx.AsyncClient): # and passes it further to the upstream endpoint. mock_stream = OpenAIStream( - { - "id": "chatcmpl-test", - "object": "chat.completion.chunk", - "created": 1695940483, - "choices": [ - { - "index": 0, - "finish_reason": "stop", - "delta": {"role": "assistant", "content": "5"}, - } - ], - }, + single_choice_chunk( + delta={"role": "assistant", "content": "5"}, finish_reason="stop" + ), ) respx.post( @@ -191,28 +182,10 @@ async def test_missing_api_version(test_app: httpx.AsyncClient): @respx.mock @pytest.mark.asyncio -async def test_error_during_streaming(test_app: httpx.AsyncClient): +async def test_error_during_streaming_stopped(test_app: httpx.AsyncClient): mock_stream = OpenAIStream( - { - "id": "chatcmpl-test", - "object": "chat.completion.chunk", - "created": 1695940483, - "model": "gpt-4", - "choices": [ - { - "index": 0, - "finish_reason": "stop", - "delta": {"role": "assistant"}, - } - ], - "usage": None, - }, - { - "error": { - "message": "Error test", - "type": "runtime_error", - } - }, + single_choice_chunk(finish_reason="stop", delta={"role": "assistant"}), + {"error": {"message": "Error test", "type": "runtime_error"}}, ) respx.post( @@ -249,6 +222,77 @@ async def test_error_during_streaming(test_app: httpx.AsyncClient): ) +@respx.mock +@pytest.mark.asyncio +async def test_error_during_streaming_unfinished(test_app: httpx.AsyncClient): + mock_stream = OpenAIStream( + single_choice_chunk(delta={"role": "assistant", "content": "hello "}), + {"error": {"message": "Error test", "type": "runtime_error"}}, + ) + + respx.post( + "http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview" + ).respond( + status_code=200, + content_type="text/event-stream", + content=mock_stream.to_content(), + ) + + response = await test_app.post( + "/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview", + json={ + "messages": [{"role": "user", "content": "Test content"}], + "stream": True, + }, + headers={ + "X-UPSTREAM-KEY": "TEST_API_KEY", + "X-UPSTREAM-ENDPOINT": "http://localhost:5001/openai/deployments/gpt-4/chat/completions", + }, + ) + + assert response.status_code == 200 + mock_stream.assert_response_content(response, assert_equal) + + +@respx.mock +@pytest.mark.asyncio +async def test_interrupted_stream(test_app: httpx.AsyncClient): + mock_stream = OpenAIStream( + single_choice_chunk(delta={"role": "assistant", "content": "hello"}), + ) + + respx.post( + "http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview" + ).respond( + status_code=200, + content_type="text/event-stream", + content=mock_stream.to_content(), + ) + + response = await test_app.post( + "/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview", + json={ + "messages": [{"role": "user", "content": "Test content"}], + "stream": True, + }, + headers={ + "X-UPSTREAM-KEY": "TEST_API_KEY", + "X-UPSTREAM-ENDPOINT": "http://localhost:5001/openai/deployments/gpt-4/chat/completions", + }, + ) + + assert response.status_code == 200 + + expected_final_chunk = single_choice_chunk( + delta={}, + finish_reason="length", + usage={"completion_tokens": 1, "prompt_tokens": 9, "total_tokens": 10}, + ) + + expected_stream = OpenAIStream(*mock_stream.chunks, expected_final_chunk) + 
expected_stream.assert_response_content(response, assert_equal) + + @respx.mock @pytest.mark.asyncio async def test_incorrect_upstream_url(test_app: httpx.AsyncClient): diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 41e8c74..256dc43 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -2,7 +2,7 @@ import pytest import respx -from tests.utils.stream import OpenAIStream +from tests.utils.stream import OpenAIStream, single_choice_chunk def assert_equal(actual, expected): @@ -11,48 +11,11 @@ def assert_equal(actual, expected): @respx.mock @pytest.mark.asyncio -async def test_streaming(test_app: httpx.AsyncClient): +async def test_streaming_computed_tokens(test_app: httpx.AsyncClient): mock_stream = OpenAIStream( - { - "id": "chatcmpl-test", - "object": "chat.completion.chunk", - "created": 1695940483, - "model": "gpt-4", - "choices": [ - { - "index": 0, - "finish_reason": None, - "delta": { - "role": "assistant", - }, - } - ], - "usage": None, - }, - { - "id": "chatcmpl-test", - "object": "chat.completion.chunk", - "created": 1695940483, - "model": "gpt-4", - "choices": [ - { - "index": 0, - "finish_reason": None, - "delta": { - "content": "Test content", - }, - } - ], - "usage": None, - }, - { - "id": "chatcmpl-test", - "object": "chat.completion.chunk", - "created": 1696245654, - "model": "gpt-4", - "choices": [{"index": 0, "finish_reason": "stop", "delta": {}}], - "usage": None, - }, + single_choice_chunk(delta={"role": "assistant"}), + single_choice_chunk(delta={"content": "Test content"}), + single_choice_chunk(delta={}, finish_reason="stop"), ) respx.post( @@ -87,3 +50,44 @@ async def test_streaming(test_app: httpx.AsyncClient): } }, ) + + +@respx.mock +@pytest.mark.asyncio +async def test_streaming_inherited_tokens(test_app: httpx.AsyncClient): + mock_stream = OpenAIStream( + single_choice_chunk(delta={"role": "assistant"}), + single_choice_chunk(delta={"content": "Test content"}), + single_choice_chunk( + delta={}, + finish_reason="stop", + usage={ + "completion_tokens": 111, + "prompt_tokens": 222, + "total_tokens": 333, + }, + ), + ) + + respx.post( + "http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-06-15" + ).respond( + status_code=200, + content=mock_stream.to_content(), + content_type="text/event-stream", + ) + + response = await test_app.post( + "/openai/deployments/gpt-4/chat/completions?api-version=2023-06-15", + json={ + "messages": [{"role": "user", "content": "Test content"}], + "stream": True, + }, + headers={ + "X-UPSTREAM-KEY": "TEST_API_KEY", + "X-UPSTREAM-ENDPOINT": "http://localhost:5001/openai/deployments/gpt-4/chat/completions", + }, + ) + + assert response.status_code == 200 + mock_stream.assert_response_content(response, assert_equal) diff --git a/tests/utils/stream.py b/tests/utils/stream.py index 2c046d3..1e00a73 100644 --- a/tests/utils/stream.py +++ b/tests/utils/stream.py @@ -43,3 +43,49 @@ def assert_response_content( assert False line_idx += 1 + + +def chunk( + *, + id: str = "chatcmpl-test", + created: str = "1695940483", + model: str = "gpt-4", + choices: List[dict], + usage: dict | None = None, + **kwargs, +) -> dict: + return { + "id": id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": choices, + "usage": usage, + **kwargs, + } + + +def single_choice_chunk( + *, + id: str = "chatcmpl-test", + created: str = "1695940483", + model: str = "gpt-4", + finish_reason: str | None = None, + delta: dict = {}, + usage: dict | None = None, + **kwargs, 
+) -> dict: + return chunk( + id=id, + created=created, + model=model, + choices=[ + { + "index": 0, + "finish_reason": finish_reason, + "delta": delta, + } + ], + usage=usage, + **kwargs, + )
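
Illustrative usage sketch (not part of the patch): after this change, generate_stream takes get_prompt_tokens and tokenize callbacks instead of a precomputed prompt_tokens value and a Tokenizer, so token counting is skipped entirely when the upstream final chunk already reports a "usage" block. A minimal, self-contained check of that behavior is sketched below under the assumption that the adapter package is importable; the fail_if_called callback, the len stand-in tokenizer, and the hard-coded chunk are illustration-only and do not appear in the patch.

import asyncio

from aidial_adapter_openai.utils.streaming import generate_stream


async def main() -> None:
    async def upstream():
        # The upstream model reports token usage itself in its final chunk.
        yield {
            "id": "chatcmpl-test",
            "object": "chat.completion.chunk",
            "created": "1695940483",
            "model": "gpt-4",
            "choices": [
                {"index": 0, "finish_reason": "stop", "delta": {"content": "5"}}
            ],
            "usage": {
                "completion_tokens": 1,
                "prompt_tokens": 9,
                "total_tokens": 10,
            },
        }

    def fail_if_called() -> int:
        # When usage is already reported, prompt tokens must never be recomputed.
        raise AssertionError("get_prompt_tokens should not be called")

    async for chunk in generate_stream(
        get_prompt_tokens=fail_if_called,
        tokenize=len,  # stand-in; real callers pass tokenizer.calculate_tokens
        deployment="gpt-4",
        discarded_messages=None,
        stream=upstream(),
    ):
        # The upstream-reported usage is passed through unchanged.
        assert chunk["usage"] == {
            "completion_tokens": 1,
            "prompt_tokens": 9,
            "total_tokens": 10,
        }


asyncio.run(main())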