feat: avoid computing token usage if upstream model has reported it already (#138)
adubovik authored Aug 8, 2024
1 parent b462d1c commit c9a3fe5
Showing 7 changed files with 232 additions and 141 deletions.
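The heart of the change is in generate_stream (aidial_adapter_openai/utils/streaming.py, further down): the adapter now trusts a usage block reported by the upstream model and falls back to local tokenization only when that block is missing. A minimal sketch of that decision; the function name ensure_usage and the callables are stand-ins for illustration, the adapter's actual implementation is the finalize_finish_chunk helper shown in the streaming.py diff below:

from typing import Callable

def ensure_usage(
    finish_chunk: dict,
    total_content: str,
    get_prompt_tokens: Callable[[], int],  # stand-in for a lazy prompt-token counter
    tokenize: Callable[[str], int],        # stand-in for a completion tokenizer
) -> dict:
    # Keep the upstream-reported usage if it is present and non-empty;
    # compute it locally only otherwise.
    if not finish_chunk.get("usage"):
        completion_tokens = tokenize(total_content)
        prompt_tokens = get_prompt_tokens()  # evaluated lazily, only on this path
        finish_chunk["usage"] = {
            "completion_tokens": completion_tokens,
            "prompt_tokens": prompt_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }
    return finish_chunk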
2 changes: 1 addition & 1 deletion aidial_adapter_openai/completions.py
@@ -30,7 +30,7 @@ def convert_to_chat_completions_response(
converted_chunk = build_chunk(
id=chunk.id,
finish_reason=chunk.choices[0].finish_reason,
delta={
message={
"content": sanitize_text(chunk.choices[0].text),
"role": "assistant",
},
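The switch from delta= to message= follows the renamed parameter of build_chunk in aidial_adapter_openai/utils/streaming.py (changed later in this diff): the helper takes one message payload and stores it under "delta" for streaming chunks or "message" for regular completions. A small illustrative call with the renamed parameter; the values below are made up:

chunk = build_chunk(
    id="chatcmpl-test",
    finish_reason="stop",
    message={"role": "assistant", "content": "5"},
    created="1695940483",
    is_stream=False,  # non-streaming, so the payload lands under "message"
)
assert chunk["choices"][0]["message"]["content"] == "5"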
13 changes: 7 additions & 6 deletions aidial_adapter_openai/gpt.py
@@ -56,15 +56,16 @@ async def gpt_chat_completion(

if isinstance(response, AsyncStream):

prompt_tokens = tokenizer.calculate_prompt_tokens(data["messages"])
return StreamingResponse(
to_openai_sse_stream(
generate_stream(
prompt_tokens,
map_stream(chunk_to_dict, response),
tokenizer,
deployment_id,
discarded_messages,
get_prompt_tokens=lambda: tokenizer.calculate_prompt_tokens(
data["messages"]
),
tokenize=tokenizer.calculate_tokens,
deployment=deployment_id,
discarded_messages=discarded_messages,
stream=map_stream(chunk_to_dict, response),
),
),
media_type="text/event-stream",
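Note that the prompt-token count is no longer computed up front: the call site now passes a zero-argument lambda, so tokenizer.calculate_prompt_tokens runs only if the upstream stream finishes without its own usage block. A self-contained sketch of that lazy pattern, with a dummy counter standing in for the adapter's tiktoken-based tokenizer:

# Dummy stand-in for tokenizer.calculate_prompt_tokens, just to show that
# the thunk is never invoked when upstream usage is already available.
counted = False

def count_prompt_tokens() -> int:
    global counted
    counted = True
    return 9

get_prompt_tokens = lambda: count_prompt_tokens()

upstream_usage = {"prompt_tokens": 9, "completion_tokens": 1, "total_tokens": 10}
usage = upstream_usage if upstream_usage else {"prompt_tokens": get_prompt_tokens()}
assert counted is False  # upstream reported usage, so no local counting happened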
8 changes: 4 additions & 4 deletions aidial_adapter_openai/gpt4_multi_modal/chat_completion.py
@@ -232,14 +232,14 @@ def debug_print(chunk: T) -> T:
map_stream(
debug_print,
generate_stream(
get_prompt_tokens=lambda: estimated_prompt_tokens,
tokenize=tokenizer.calculate_tokens,
deployment=deployment,
discarded_messages=None,
stream=map_stream(
response_transformer,
parse_openai_sse_stream(response),
),
prompt_tokens=estimated_prompt_tokens,
tokenizer=tokenizer,
deployment=deployment,
discarded_messages=None,
),
)
),
104 changes: 50 additions & 54 deletions aidial_adapter_openai/utils/streaming.py
@@ -14,7 +14,6 @@
from aidial_adapter_openai.utils.errors import dial_exception_to_json_error
from aidial_adapter_openai.utils.log_config import logger
from aidial_adapter_openai.utils.sse_stream import to_openai_sse_stream
from aidial_adapter_openai.utils.tokens import Tokenizer

fix_streaming_issues_in_new_api_versions = get_env_bool(
"FIX_STREAMING_ISSUES_IN_NEW_API_VERSIONS", False
@@ -28,21 +27,22 @@ def generate_id():
def build_chunk(
id: str,
finish_reason: Optional[str],
delta: Any,
message: Any,
created: str,
is_stream,
is_stream: bool,
**extra,
):
choice_content_key = "delta" if is_stream else "message"
message_key = "delta" if is_stream else "message"
object_name = "chat.completion.chunk" if is_stream else "chat.completion"

return {
"id": id,
"object": "chat.completion.chunk" if is_stream else "chat.completion",
"object": object_name,
"created": created,
"choices": [
{
"index": 0,
choice_content_key: delta,
message_key: message,
"finish_reason": finish_reason,
}
],
Expand All @@ -51,42 +51,49 @@ def build_chunk(


async def generate_stream(
prompt_tokens: int,
stream: AsyncIterator[dict],
tokenizer: Tokenizer,
*,
get_prompt_tokens: Callable[[], int],
tokenize: Callable[[str], int],
deployment: str,
discarded_messages: Optional[list[int]],
stream: AsyncIterator[dict],
) -> AsyncIterator[dict]:

last_chunk, temp_chunk = None, None
stream_finished = False
total_content = ""

def finalize_finish_chunk(chunk: dict) -> None:
"""
Adding additional information to a chunk that has non-null finish_reason field.
"""

if not chunk.get("usage"):
completion_tokens = tokenize(total_content)
prompt_tokens = get_prompt_tokens()
chunk["usage"] = {
"completion_tokens": completion_tokens,
"prompt_tokens": prompt_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
if discarded_messages is not None:
chunk["statistics"] = {"discarded_messages": discarded_messages}

try:
total_content = ""
async for chunk in stream:
if len(chunk["choices"]) > 0:
if temp_chunk is not None:
chunk = merge(temp_chunk, chunk)
temp_chunk = None

choice = chunk["choices"][0]

content = (choice.get("delta") or {}).get("content") or ""
total_content += content
total_content += (choice.get("delta") or {}).get(
"content"
) or ""

if choice["finish_reason"] is not None:
stream_finished = True
completion_tokens = tokenizer.calculate_tokens(
total_content
)
chunk["usage"] = {
"completion_tokens": completion_tokens,
"prompt_tokens": prompt_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
if discarded_messages is not None:
chunk["statistics"] = {
"discarded_messages": discarded_messages
}
finalize_finish_chunk(chunk)

yield chunk
else:
Expand All @@ -106,39 +113,28 @@ async def generate_stream(
return

if not stream_finished:
if last_chunk is not None:
if last_chunk is None:
logger.warning("Received 0 chunks")
else:
logger.warning("Didn't receive chunk with the finish reason")

completion_tokens = tokenizer.calculate_tokens(total_content)
last_chunk["usage"] = {
"completion_tokens": completion_tokens,
"prompt_tokens": prompt_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
last_chunk["choices"][0]["delta"]["content"] = ""
last_chunk["choices"][0]["finish_reason"] = "length"
last_chunk = last_chunk or {}
id = last_chunk.get("id") or generate_id()
created = last_chunk.get("created") or str(int(time()))
model = last_chunk.get("model") or deployment

finish_chunk = build_chunk(
id=id,
created=created,
model=model,
is_stream=True,
message={},
finish_reason="length",
)

yield last_chunk
else:
logger.warning("Received 0 chunks")
finalize_finish_chunk(finish_chunk)

id = generate_id()
created = str(int(time()))
is_stream = True

yield build_chunk(
id,
"length",
{},
created,
is_stream,
model=deployment,
usage={
"completion_tokens": 0,
"prompt_tokens": prompt_tokens,
"total_tokens": prompt_tokens,
},
)
yield finish_chunk


def create_stage_chunk(name: str, content: str, stream: bool) -> dict:
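Taken together, the new keyword-only signature makes generate_stream easy to exercise in isolation. A hypothetical smoke test, not part of this commit, showing that an upstream-reported usage block is passed through untouched and the local tokenizer callbacks are never invoked:

import asyncio

from aidial_adapter_openai.utils.streaming import generate_stream

async def _demo() -> None:
    async def upstream():
        # Single chunk that already carries usage reported by the upstream model.
        yield {
            "choices": [
                {"index": 0, "delta": {"content": "hi"}, "finish_reason": "stop"}
            ],
            "usage": {"prompt_tokens": 3, "completion_tokens": 1, "total_tokens": 4},
        }

    def fail(*_):
        # Would blow up if the adapter tried to count tokens locally.
        raise AssertionError("local token counting should have been skipped")

    chunks = [
        chunk
        async for chunk in generate_stream(
            get_prompt_tokens=fail,
            tokenize=fail,
            deployment="gpt-4",
            discarded_messages=None,
            stream=upstream(),
        )
    ]
    assert chunks[-1]["usage"] == {
        "prompt_tokens": 3,
        "completion_tokens": 1,
        "total_tokens": 4,
    }

asyncio.run(_demo())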
112 changes: 78 additions & 34 deletions tests/test_errors.py
@@ -6,7 +6,7 @@
import respx
from respx.types import SideEffectTypes

from tests.utils.stream import OpenAIStream
from tests.utils.stream import OpenAIStream, single_choice_chunk


def assert_equal(actual, expected):
@@ -37,18 +37,9 @@ async def test_single_chunk_token_counting(test_app: httpx.AsyncClient):
# and passes it further to the upstream endpoint.

mock_stream = OpenAIStream(
{
"id": "chatcmpl-test",
"object": "chat.completion.chunk",
"created": 1695940483,
"choices": [
{
"index": 0,
"finish_reason": "stop",
"delta": {"role": "assistant", "content": "5"},
}
],
},
single_choice_chunk(
delta={"role": "assistant", "content": "5"}, finish_reason="stop"
),
)

respx.post(
@@ -191,28 +182,10 @@ async def test_missing_api_version(test_app: httpx.AsyncClient):

@respx.mock
@pytest.mark.asyncio
async def test_error_during_streaming(test_app: httpx.AsyncClient):
async def test_error_during_streaming_stopped(test_app: httpx.AsyncClient):
mock_stream = OpenAIStream(
{
"id": "chatcmpl-test",
"object": "chat.completion.chunk",
"created": 1695940483,
"model": "gpt-4",
"choices": [
{
"index": 0,
"finish_reason": "stop",
"delta": {"role": "assistant"},
}
],
"usage": None,
},
{
"error": {
"message": "Error test",
"type": "runtime_error",
}
},
single_choice_chunk(finish_reason="stop", delta={"role": "assistant"}),
{"error": {"message": "Error test", "type": "runtime_error"}},
)

respx.post(
@@ -249,6 +222,77 @@ async def test_error_during_streaming(test_app: httpx.AsyncClient):
)


@respx.mock
@pytest.mark.asyncio
async def test_error_during_streaming_unfinished(test_app: httpx.AsyncClient):
mock_stream = OpenAIStream(
single_choice_chunk(delta={"role": "assistant", "content": "hello "}),
{"error": {"message": "Error test", "type": "runtime_error"}},
)

respx.post(
"http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview"
).respond(
status_code=200,
content_type="text/event-stream",
content=mock_stream.to_content(),
)

response = await test_app.post(
"/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview",
json={
"messages": [{"role": "user", "content": "Test content"}],
"stream": True,
},
headers={
"X-UPSTREAM-KEY": "TEST_API_KEY",
"X-UPSTREAM-ENDPOINT": "http://localhost:5001/openai/deployments/gpt-4/chat/completions",
},
)

assert response.status_code == 200
mock_stream.assert_response_content(response, assert_equal)


@respx.mock
@pytest.mark.asyncio
async def test_interrupted_stream(test_app: httpx.AsyncClient):
mock_stream = OpenAIStream(
single_choice_chunk(delta={"role": "assistant", "content": "hello"}),
)

respx.post(
"http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview"
).respond(
status_code=200,
content_type="text/event-stream",
content=mock_stream.to_content(),
)

response = await test_app.post(
"/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview",
json={
"messages": [{"role": "user", "content": "Test content"}],
"stream": True,
},
headers={
"X-UPSTREAM-KEY": "TEST_API_KEY",
"X-UPSTREAM-ENDPOINT": "http://localhost:5001/openai/deployments/gpt-4/chat/completions",
},
)

assert response.status_code == 200

expected_final_chunk = single_choice_chunk(
delta={},
finish_reason="length",
usage={"completion_tokens": 1, "prompt_tokens": 9, "total_tokens": 10},
)

expected_stream = OpenAIStream(*mock_stream.chunks, expected_final_chunk)
expected_stream.assert_response_content(response, assert_equal)


@respx.mock
@pytest.mark.asyncio
async def test_incorrect_upstream_url(test_app: httpx.AsyncClient):
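The single_choice_chunk helper used throughout these tests lives in tests/utils/stream.py, one of the changed files not rendered on this page. Its exact definition is not visible here; judging only by the call sites above, it presumably builds a one-choice chat.completion.chunk, roughly along these lines (signature and defaults are guesses):

from typing import Any, Optional

# Hypothetical reconstruction of the helper from tests/utils/stream.py;
# the real signature and defaults may differ.
def single_choice_chunk(
    *,
    delta: Any = None,
    finish_reason: Optional[str] = None,
    usage: Optional[dict] = None,
) -> dict:
    return {
        "id": "chatcmpl-test",
        "object": "chat.completion.chunk",
        "created": 1695940483,
        "model": "gpt-4",
        "choices": [
            {
                "index": 0,
                "delta": delta if delta is not None else {},
                "finish_reason": finish_reason,
            }
        ],
        "usage": usage,
    }

The real helper, or the comparison done by OpenAIStream.assert_response_content, presumably also accounts for non-deterministic fields such as id and created, which this guess glosses over.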