diff --git a/aidial_adapter_openai/completions.py b/aidial_adapter_openai/completions.py
index 6fa182b..0c1789e 100644
--- a/aidial_adapter_openai/completions.py
+++ b/aidial_adapter_openai/completions.py
@@ -34,7 +34,7 @@ def convert_to_chat_completions_response(
             "content": sanitize_text(chunk.choices[0].text),
             "role": "assistant",
         },
-        created=str(chunk.created),
+        created=chunk.created,
         is_stream=is_stream,
         usage=chunk.usage.to_dict() if chunk.usage else None,
     )
diff --git a/aidial_adapter_openai/dalle3.py b/aidial_adapter_openai/dalle3.py
index 56e4e65..ab8125e 100644
--- a/aidial_adapter_openai/dalle3.py
+++ b/aidial_adapter_openai/dalle3.py
@@ -68,7 +68,7 @@ def build_custom_content(base64_image: str, revised_prompt: str) -> Any:
 
 
 async def generate_stream(
-    id: str, created: str, custom_content: Any
+    id: str, created: int, custom_content: Any
 ) -> AsyncIterator[dict]:
     yield build_chunk(id, None, {"role": "assistant"}, created, True)
     yield build_chunk(id, None, custom_content, created, True)
diff --git a/aidial_adapter_openai/utils/streaming.py b/aidial_adapter_openai/utils/streaming.py
index f7d4d30..b429a67 100644
--- a/aidial_adapter_openai/utils/streaming.py
+++ b/aidial_adapter_openai/utils/streaming.py
@@ -20,18 +20,22 @@
 )
 
 
-def generate_id():
+def generate_id() -> str:
     return "chatcmpl-" + str(uuid4())
 
 
+def generate_created() -> int:
+    return int(time())
+
+
 def build_chunk(
     id: str,
     finish_reason: Optional[str],
     message: Any,
-    created: str,
+    created: int,
     is_stream: bool,
     **extra,
-):
+) -> dict:
     message_key = "delta" if is_stream else "message"
     object_name = "chat.completion.chunk" if is_stream else "chat.completion"
 
@@ -120,7 +124,7 @@ def finalize_finish_chunk(chunk: dict) -> None:
 
     last_chunk = last_chunk or {}
     id = last_chunk.get("id") or generate_id()
-    created = last_chunk.get("created") or str(int(time()))
+    created = last_chunk.get("created") or generate_created()
     model = last_chunk.get("model") or deployment
 
     finish_chunk = build_chunk(
@@ -139,7 +143,7 @@ def finalize_finish_chunk(chunk: dict) -> None:
 
 def create_stage_chunk(name: str, content: str, stream: bool) -> dict:
     id = generate_id()
-    created = str(int(time()))
+    created = generate_created()
 
     stage = {
         "index": 0,
diff --git a/pyproject.toml b/pyproject.toml
index 4701515..972c658 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -51,9 +51,10 @@ flake8 = "6.0.0"
 nox = "^2023.4.22"
 
 [tool.pytest.ini_options]
-# muting warnings coming from opentelemetry package
+# muting warnings coming from opentelemetry and pkg_resources packages
filterwarnings = [
-    "ignore::DeprecationWarning:opentelemetry.instrumentation.dependencies"
+    "ignore::DeprecationWarning:opentelemetry.instrumentation.dependencies",
+    "ignore::DeprecationWarning:pkg_resources"
 ]
 
 [tool.pyright]
diff --git a/tests/test_errors.py b/tests/test_errors.py
index 600229a..08ef9ee 100644
--- a/tests/test_errors.py
+++ b/tests/test_errors.py
@@ -1,18 +1,27 @@
 import json
-from typing import Callable
+from typing import Any, Callable
 
 import httpx
 import pytest
 import respx
 from respx.types import SideEffectTypes
 
+from tests.utils.dictionary import exclude_keys
 from tests.utils.stream import OpenAIStream, single_choice_chunk
 
 
-def assert_equal(actual, expected):
+def assert_equal(actual: Any, expected: Any):
     assert actual == expected
 
 
+def assert_equal_no_dynamic_fields(actual: Any, expected: Any):
+    if isinstance(actual, dict) and isinstance(expected, dict):
+        keys = {"id", "created"}
+        assert exclude_keys(actual, keys) == exclude_keys(expected, keys)
+    else:
+        assert actual == expected
+
+
 def mock_response(
     status_code: int,
     content_type: str,
@@ -293,6 +302,45 @@ async def test_interrupted_stream(test_app: httpx.AsyncClient):
     expected_stream.assert_response_content(response, assert_equal)
 
 
+@respx.mock
+@pytest.mark.asyncio
+async def test_zero_chunk_stream(test_app: httpx.AsyncClient):
+    mock_stream = OpenAIStream()
+
+    respx.post(
+        "http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview"
+    ).respond(
+        status_code=200,
+        content_type="text/event-stream",
+        content=mock_stream.to_content(),
+    )
+
+    response = await test_app.post(
+        "/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview",
+        json={
+            "messages": [{"role": "user", "content": "Test content"}],
+            "stream": True,
+        },
+        headers={
+            "X-UPSTREAM-KEY": "TEST_API_KEY",
+            "X-UPSTREAM-ENDPOINT": "http://localhost:5001/openai/deployments/gpt-4/chat/completions",
+        },
+    )
+
+    assert response.status_code == 200
+
+    expected_final_chunk = single_choice_chunk(
+        delta={},
+        finish_reason="length",
+        usage={"prompt_tokens": 9, "completion_tokens": 0, "total_tokens": 9},
+    )
+
+    expected_stream = OpenAIStream(expected_final_chunk)
+    expected_stream.assert_response_content(
+        response, assert_equal_no_dynamic_fields
+    )
+
+
 @respx.mock
 @pytest.mark.asyncio
 async def test_incorrect_upstream_url(test_app: httpx.AsyncClient):
diff --git a/tests/utils/dictionary.py b/tests/utils/dictionary.py
new file mode 100644
index 0000000..e5f7b87
--- /dev/null
+++ b/tests/utils/dictionary.py
@@ -0,0 +1,5 @@
+from typing import Iterable
+
+
+def exclude_keys(d: dict, keys: Iterable[str]) -> dict:
+    return {k: v for k, v in d.items() if k not in keys}
diff --git a/tests/utils/stream.py b/tests/utils/stream.py
index 1e00a73..0b23b26 100644
--- a/tests/utils/stream.py
+++ b/tests/utils/stream.py
@@ -48,7 +48,7 @@ def assert_response_content(
 def chunk(
     *,
     id: str = "chatcmpl-test",
-    created: str = "1695940483",
+    created: int = 1695940483,
     model: str = "gpt-4",
     choices: List[dict],
     usage: dict | None = None,
@@ -68,7 +68,7 @@ def chunk(
 def single_choice_chunk(
     *,
     id: str = "chatcmpl-test",
-    created: str = "1695940483",
+    created: int = 1695940483,
     model: str = "gpt-4",
     finish_reason: str | None = None,
     delta: dict = {},