Commit

fix: reusing usage reported by the upstream OpenAI via include_usage option (#151)
adubovik authored Oct 1, 2024
1 parent 6ef02ca commit 6ec0a79
Showing 6 changed files with 249 additions and 67 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -68,7 +68,7 @@ Copy `.env.example` to `.env` and customize it for your environment:
|AZURE_OPEN_AI_SCOPE|https://cognitiveservices.azure.com/.default|Provided scope of access token to Azure OpenAI services|
|API_VERSIONS_MAPPING|`{}`|The mapping of versions API for requests to Azure OpenAI API. Example: `{"2023-03-15-preview": "2023-05-15", "": "2024-02-15-preview"}`. An empty key sets the default api version for the case when the user didn't pass it in the request|
|DALLE3_AZURE_API_VERSION|2024-02-01|The version API for requests to Azure DALL-E-3 API|
|FIX_STREAMING_ISSUES_IN_NEW_API_VERSIONS|False|Fixes issue with receiving the first chunk with an empty list of choices|
|FIX_STREAMING_ISSUES_IN_NEW_API_VERSIONS|False|When enabled, the response stream is guaranteed to exclude chunks with an empty list of choices. This is useful when a DIAL client doesn't support such chunks. An empty list of choices can be generated by Azure OpenAI in at least two cases: (1) when the **Content filter** is not disabled, Azure includes [prompt filter results](https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/content-filter?tabs=warning%2Cuser-prompt%2Cpython-new#prompt-annotation-message) in the first chunk with an empty list of choices; (2) when `stream_options.include_usage` is enabled, the last chunk contains usage data and an empty list of choices.|
|CORE_API_VERSION||Supported value `0.6` to work with the old version of the file api|
|MISTRAL_DEPLOYMENTS|``|Comma-separated list of deployments that support Mistral Large Azure API. Example: `mistral-large-azure,mistral-large`|
|DATABRICKS_DEPLOYMENTS|``|Comma-separated list of Databricks chat completion deployments. Example: `databricks-dbrx-instruct,databricks-mixtral-8x7b-instruct,databricks-llama-2-70b-chat`|
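For illustration, the two kinds of empty-choices chunks described in the updated option description might look roughly as follows. This is a sketch with made-up field values, not literal Azure OpenAI output.

```python
# Sketch of the two empty-choices chunk shapes described above.
# Field values are illustrative, not literal Azure OpenAI output.

# (1) Content filter enabled: the first chunk reports prompt filter results
#     and carries an empty list of choices.
prompt_filter_chunk = {
    "id": "",
    "created": 0,
    "model": "",
    "object": "",
    "choices": [],
    "prompt_filter_results": [
        {
            "prompt_index": 0,
            "content_filter_results": {"hate": {"filtered": False, "severity": "safe"}},
        }
    ],
}

# (2) stream_options.include_usage enabled: the last chunk reports usage
#     and carries an empty list of choices.
usage_chunk = {
    "id": "chatcmpl-123",
    "created": 1727740800,
    "model": "gpt-4o",
    "object": "chat.completion.chunk",
    "choices": [],
    "usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21},
}
```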
29 changes: 29 additions & 0 deletions aidial_adapter_openai/utils/merge_chunks.py
@@ -0,0 +1,29 @@
from typing import TypeVar

from aidial_sdk.utils.merge_chunks import merge

_Chunk = TypeVar("_Chunk", bound=dict)


def merge_chunks(*chunks: _Chunk) -> _Chunk:
    """
    The recursive merging procedure that avoids merging top-level atomic fields
    (e.g. "id", "created", "model", "object", "system_fingerprint") and
    instead chooses an _override_ merging strategy for such fields.
    Non-atomic fields (e.g. "choices", "usage") are merged following
    the standard recursive merging procedure.
    """

    assert len(chunks) > 0, "At least one chunk must be provided"

    target = chunks[0]

    for chunk in chunks[1:]:
        source = chunk.copy()
        for key, value in list(source.items()):
            if not isinstance(value, (list, dict)) and value is not None:
                target[key] = value
                del source[key]
        target = merge(target, source)

    return target
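A minimal usage sketch of this helper follows; the chunk contents are hypothetical, and the exact result of the recursive part is determined by `aidial_sdk.utils.merge_chunks.merge`.

```python
# Hypothetical usage of merge_chunks; the chunks below are made up.
from aidial_adapter_openai.utils.merge_chunks import merge_chunks

content_chunk = {
    "id": "chatcmpl-1",
    "created": 1,
    "choices": [{"index": 0, "delta": {"content": "Hello"}}],
}
usage_chunk = {
    "id": "chatcmpl-1",
    "created": 1,
    "choices": [],
    "usage": {"prompt_tokens": 9, "completion_tokens": 1, "total_tokens": 10},
}

merged = merge_chunks(content_chunk, usage_chunk)
# Top-level atomic fields ("id", "created") are taken from the later chunk
# (override), while "choices" and "usage" are merged by the recursive
# procedure provided by aidial_sdk.
```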
144 changes: 86 additions & 58 deletions aidial_adapter_openai/utils/streaming.py
@@ -1,19 +1,19 @@
import logging
from time import time
from typing import Any, AsyncIterator, Callable, Optional, TypeVar
from typing import Any, AsyncIterator, Callable, Iterable, Optional, TypeVar
from uuid import uuid4

from aidial_sdk.exceptions import HTTPException as DialException
from aidial_sdk.utils.merge_chunks import merge
from fastapi.responses import JSONResponse, Response, StreamingResponse
from openai import APIError, APIStatusError
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk

from aidial_adapter_openai.utils.env import get_env_bool
from aidial_adapter_openai.utils.log_config import logger
from aidial_adapter_openai.utils.merge_chunks import merge_chunks
from aidial_adapter_openai.utils.sse_stream import to_openai_sse_stream

fix_streaming_issues_in_new_api_versions = get_env_bool(
ELIMINATE_EMPTY_CHOICES = get_env_bool(
    "FIX_STREAMING_ISSUES_IN_NEW_API_VERSIONS", False
)

@@ -61,84 +61,112 @@ async def generate_stream(
    stream: AsyncIterator[dict],
) -> AsyncIterator[dict]:

    last_chunk, temp_chunk = None, None
    stream_finished = False
    total_content = ""

    def finalize_finish_chunk(chunk: dict) -> None:
        """
        Adding additional information to a chunk that has non-null finish_reason field.
        """

        if not chunk.get("usage"):
            completion_tokens = tokenize(total_content)
            prompt_tokens = get_prompt_tokens()
            chunk["usage"] = {
                "completion_tokens": completion_tokens,
                "prompt_tokens": prompt_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            }
        if discarded_messages is not None:
            chunk["statistics"] = {"discarded_messages": discarded_messages}
    noop_chunk = build_chunk(
        id=generate_id(),
        created=generate_created(),
        model=deployment,
        is_stream=True,
        message={},
        finish_reason=None,
    )

    def set_usage(chunk: dict | None, completions: Iterable[str]) -> dict:
        chunk = chunk or noop_chunk
        completion_tokens = sum(map(tokenize, completions))
        prompt_tokens = get_prompt_tokens()
        chunk["usage"] = {
            "completion_tokens": completion_tokens,
            "prompt_tokens": prompt_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }
        return chunk

    def set_finish_reason(chunk: dict | None, finish_reason: str) -> dict:
        chunk = chunk or noop_chunk
        chunk["choices"] = chunk.get("choices") or [{"index": 0, "delta": {}}]
        chunk["choices"][0]["finish_reason"] = finish_reason
        return chunk

    def set_discarded_messages(chunk: dict | None, indices: list[int]) -> dict:
        chunk = chunk or noop_chunk
        chunk["statistics"] = {"discarded_messages": indices}
        return chunk

    n_chunks = 0
    last_chunk = None
    buffer_chunk = None

    completions: dict[int, str] = {}
    found_finish_reason = False
    found_usage = False
    error = None

    try:
        async for chunk in stream:
            if len(chunk["choices"]) > 0:
                if temp_chunk is not None:
                    chunk = merge(temp_chunk, chunk)
                    temp_chunk = None
            n_chunks += 1

            if buffer_chunk is not None:
                chunk = merge_chunks(buffer_chunk, chunk)
                buffer_chunk = None

            choices = chunk.get("choices") or []

                choice = chunk["choices"][0]
                total_content += (choice.get("delta") or {}).get(
                    "content"
                ) or ""
            for choice in choices:
                index = choice["index"]
                content = (choice.get("delta") or {}).get("content") or ""

                if choice["finish_reason"] is not None:
                    stream_finished = True
                    finalize_finish_chunk(chunk)
                completions[index] = completions.get(index, "") + content
                found_finish_reason |= bool(choice.get("finish_reason"))

                yield chunk
            found_usage |= bool(chunk.get("usage"))

            # Azure OpenAI returns an empty list of choices as a first chunk
            # when content filtering is enabled for a corresponding deployment.
            # The safety rating of the request is reported in this first chunk.
            # Here we withhold such a chunk and merge it later with a follow-up chunk.
            if len(choices) == 0 and ELIMINATE_EMPTY_CHOICES:
                buffer_chunk = chunk
            else:
                if fix_streaming_issues_in_new_api_versions:
                    temp_chunk = chunk
                else:
                    yield chunk
                if last_chunk is not None:
                    yield last_chunk
                last_chunk = chunk

            last_chunk = chunk
    except APIError as e:
        status_code = e.status_code if isinstance(e, APIStatusError) else 500
        yield DialException(
        error = DialException(
            status_code=status_code,
            message=e.message,
            type=e.type,
            param=e.param,
            code=e.code,
        ).json_error()
        return

    if not stream_finished:
        if last_chunk is None:
    if last_chunk is not None and buffer_chunk is not None:
        last_chunk = merge_chunks(buffer_chunk, last_chunk)

    if discarded_messages is not None:
        last_chunk = set_discarded_messages(last_chunk, discarded_messages)

    if not found_usage and (not error or completions):
        last_chunk = set_usage(last_chunk, completions.values())

    if not error:
        if n_chunks == 0:
            logger.warning("Received 0 chunks")
        else:
        elif not found_finish_reason:
            logger.warning("Didn't receive chunk with the finish reason")

        last_chunk = last_chunk or {}
        id = last_chunk.get("id") or generate_id()
        created = last_chunk.get("created") or generate_created()
        model = last_chunk.get("model") or deployment
        if not found_finish_reason:
            last_chunk = set_finish_reason(last_chunk, "length")

        finish_chunk = build_chunk(
            id=id,
            created=created,
            model=model,
            is_stream=True,
            message={},
            finish_reason="length",
        )
        if not found_usage:
            last_chunk = set_usage(last_chunk, completions.values())

        finalize_finish_chunk(finish_chunk)
    if last_chunk:
        yield last_chunk

        yield finish_chunk
    if error:
        yield error


def create_stage_chunk(name: str, content: str, stream: bool) -> dict:
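For context, the upstream usage that `generate_stream` now reuses originates from the `stream_options.include_usage` flag of the OpenAI chat completions API. A minimal client-side sketch is shown below; the endpoint, API key, and model name are placeholders.

```python
# Minimal sketch of requesting upstream usage via stream_options.include_usage
# with the openai Python client; endpoint, API key and model name are placeholders.
from openai import AzureOpenAI

client = AzureOpenAI(
    azure_endpoint="https://my-resource.openai.azure.com",
    api_key="...",
    api_version="2024-02-15-preview",
)

stream = client.chat.completions.create(
    model="gpt-4o",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
    stream_options={"include_usage": True},
)

for chunk in stream:
    # Intermediate chunks carry deltas in `choices`; the final chunk carries
    # `usage` together with an empty list of choices.
    if chunk.usage:
        print(chunk.usage)
```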
11 changes: 11 additions & 0 deletions tests/conftest.py
@@ -1,10 +1,21 @@
from unittest.mock import patch

import httpx
import pytest
import pytest_asyncio
from httpx import ASGITransport

from aidial_adapter_openai.app import app


@pytest.fixture
def eliminate_empty_choices():
    with patch(
        "aidial_adapter_openai.utils.streaming.ELIMINATE_EMPTY_CHOICES", True
    ):
        yield


@pytest_asyncio.fixture
async def test_app():
    async with httpx.AsyncClient(
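A hypothetical sketch of how a test might request this fixture by name; the request path, query, and payload are illustrative only.

```python
# Hypothetical sketch: requesting `eliminate_empty_choices` by name patches
# ELIMINATE_EMPTY_CHOICES to True for the duration of this test.
import httpx
import pytest


@pytest.mark.asyncio
async def test_stream_without_empty_choices(
    eliminate_empty_choices, test_app: httpx.AsyncClient
):
    # The deployment name, api-version and payload below are illustrative only.
    response = await test_app.post(
        "/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview",
        json={"messages": [{"role": "user", "content": "hi"}], "stream": True},
    )
    assert response.status_code == 200
```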
28 changes: 21 additions & 7 deletions tests/test_errors.py
@@ -277,7 +277,17 @@ async def test_error_during_streaming_unfinished(test_app: httpx.AsyncClient):
)

    assert response.status_code == 200
    mock_stream.assert_response_content(response, assert_equal)
    mock_stream.assert_response_content(
        response,
        assert_equal,
        usages={
            0: {
                "completion_tokens": 2,
                "prompt_tokens": 9,
                "total_tokens": 11,
            }
        },
    )


@respx.mock
@@ -309,13 +319,17 @@ async def test_interrupted_stream(test_app: httpx.AsyncClient):

    assert response.status_code == 200

    expected_final_chunk = single_choice_chunk(
        delta={},
        finish_reason="length",
        usage={"completion_tokens": 1, "prompt_tokens": 9, "total_tokens": 10},
    expected_stream = OpenAIStream(
        single_choice_chunk(
            delta={"role": "assistant", "content": "hello"},
            finish_reason="length",
            usage={
                "completion_tokens": 1,
                "prompt_tokens": 9,
                "total_tokens": 10,
            },
        )
    )

    expected_stream = OpenAIStream(*mock_stream.chunks, expected_final_chunk)
    expected_stream.assert_response_content(response, assert_equal)

