Skip to content

Commit

Permalink
fix: changed type of created field in chunk from str to int (#141)
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik authored Aug 9, 2024
1 parent c9a3fe5 commit b64c4cc
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 13 deletions.
2 changes: 1 addition & 1 deletion aidial_adapter_openai/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def convert_to_chat_completions_response(
"content": sanitize_text(chunk.choices[0].text),
"role": "assistant",
},
created=str(chunk.created),
created=chunk.created,
is_stream=is_stream,
usage=chunk.usage.to_dict() if chunk.usage else None,
)
Expand Down
2 changes: 1 addition & 1 deletion aidial_adapter_openai/dalle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def build_custom_content(base64_image: str, revised_prompt: str) -> Any:


async def generate_stream(
id: str, created: str, custom_content: Any
id: str, created: int, custom_content: Any
) -> AsyncIterator[dict]:
yield build_chunk(id, None, {"role": "assistant"}, created, True)
yield build_chunk(id, None, custom_content, created, True)
Expand Down
14 changes: 9 additions & 5 deletions aidial_adapter_openai/utils/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,22 @@
)


def generate_id():
def generate_id() -> str:
    """Produce a fresh chat-completion identifier in the OpenAI format.

    Returns a string of the form ``chatcmpl-<uuid4>``, unique per call.
    """
    return f"chatcmpl-{uuid4()}"


def generate_created() -> int:
    """Return the current Unix timestamp, truncated to whole seconds.

    Matches the integer ``created`` field of OpenAI chat-completion chunks.
    """
    now = time()
    return int(now)


def build_chunk(
id: str,
finish_reason: Optional[str],
message: Any,
created: str,
created: int,
is_stream: bool,
**extra,
):
) -> dict:
message_key = "delta" if is_stream else "message"
object_name = "chat.completion.chunk" if is_stream else "chat.completion"

Expand Down Expand Up @@ -120,7 +124,7 @@ def finalize_finish_chunk(chunk: dict) -> None:

last_chunk = last_chunk or {}
id = last_chunk.get("id") or generate_id()
created = last_chunk.get("created") or str(int(time()))
created = last_chunk.get("created") or generate_created()
model = last_chunk.get("model") or deployment

finish_chunk = build_chunk(
Expand All @@ -139,7 +143,7 @@ def finalize_finish_chunk(chunk: dict) -> None:

def create_stage_chunk(name: str, content: str, stream: bool) -> dict:
id = generate_id()
created = str(int(time()))
created = generate_created()

stage = {
"index": 0,
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ flake8 = "6.0.0"
nox = "^2023.4.22"

[tool.pytest.ini_options]
# muting warnings coming from opentelemetry package
# muting warnings coming from opentelemetry and pkg_resources packages
filterwarnings = [
"ignore::DeprecationWarning:opentelemetry.instrumentation.dependencies"
"ignore::DeprecationWarning:opentelemetry.instrumentation.dependencies",
"ignore::DeprecationWarning:pkg_resources"
]

[tool.pyright]
Expand Down
52 changes: 50 additions & 2 deletions tests/test_errors.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
import json
from typing import Callable
from typing import Any, Callable

import httpx
import pytest
import respx
from respx.types import SideEffectTypes

from tests.utils.dictionary import exclude_keys
from tests.utils.stream import OpenAIStream, single_choice_chunk


def assert_equal(actual, expected):
def assert_equal(actual: Any, expected: Any) -> None:
    """Assert that *actual* compares equal to *expected*.

    Thin wrapper so it can be passed as a comparison callback
    (e.g. to ``OpenAIStream.assert_response_content``).
    """
    assert actual == expected


def assert_equal_no_dynamic_fields(actual: Any, expected: Any):
    """Assert equality while ignoring per-response dynamic fields.

    For a pair of dicts, the "id" and "created" keys are excluded before
    comparing, since both are generated fresh for every response. Any
    other pair of values is compared directly.
    """
    if not (isinstance(actual, dict) and isinstance(expected, dict)):
        assert actual == expected
        return
    dynamic_fields = {"id", "created"}
    assert exclude_keys(actual, dynamic_fields) == exclude_keys(
        expected, dynamic_fields
    )


def mock_response(
status_code: int,
content_type: str,
Expand Down Expand Up @@ -293,6 +302,45 @@ async def test_interrupted_stream(test_app: httpx.AsyncClient):
expected_stream.assert_response_content(response, assert_equal)


@respx.mock
@pytest.mark.asyncio
async def test_zero_chunk_stream(test_app: httpx.AsyncClient):
    """A completely empty upstream SSE stream must still produce a final
    chunk carrying the finish reason and the adapter-computed token usage.

    The "id" and "created" fields of that chunk are generated by the
    adapter, so the comparison ignores them.
    """
    upstream_stream = OpenAIStream()

    respx.post(
        "http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview"
    ).respond(
        status_code=200,
        content_type="text/event-stream",
        content=upstream_stream.to_content(),
    )

    response = await test_app.post(
        "/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview",
        json={
            "messages": [{"role": "user", "content": "Test content"}],
            "stream": True,
        },
        headers={
            "X-UPSTREAM-KEY": "TEST_API_KEY",
            "X-UPSTREAM-ENDPOINT": "http://localhost:5001/openai/deployments/gpt-4/chat/completions",
        },
    )

    assert response.status_code == 200

    expected_stream = OpenAIStream(
        single_choice_chunk(
            delta={},
            finish_reason="length",
            usage={"prompt_tokens": 9, "completion_tokens": 0, "total_tokens": 9},
        )
    )
    expected_stream.assert_response_content(
        response, assert_equal_no_dynamic_fields
    )


@respx.mock
@pytest.mark.asyncio
async def test_incorrect_upstream_url(test_app: httpx.AsyncClient):
Expand Down
5 changes: 5 additions & 0 deletions tests/utils/dictionary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from typing import Iterable


def exclude_keys(d: dict, keys: Iterable[str]) -> dict:
    """Return a shallow copy of *d* with every key listed in *keys* removed.

    Keys in *keys* that are absent from *d* are silently ignored; the
    insertion order of the remaining entries is preserved.
    """
    result = dict(d)
    for key in keys:
        result.pop(key, None)
    return result
4 changes: 2 additions & 2 deletions tests/utils/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def assert_response_content(
def chunk(
*,
id: str = "chatcmpl-test",
created: str = "1695940483",
created: int = 1695940483,
model: str = "gpt-4",
choices: List[dict],
usage: dict | None = None,
Expand All @@ -68,7 +68,7 @@ def chunk(
def single_choice_chunk(
*,
id: str = "chatcmpl-test",
created: str = "1695940483",
created: int = 1695940483,
model: str = "gpt-4",
finish_reason: str | None = None,
delta: dict = {},
Expand Down

0 comments on commit b64c4cc

Please sign in to comment.