Skip to content

Commit

Permalink
chore: refactored and simplified unit tests (#191)
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik authored Nov 28, 2024
1 parent d110cde commit ce3173a
Show file tree
Hide file tree
Showing 15 changed files with 237 additions and 374 deletions.
12 changes: 5 additions & 7 deletions tests/examples/test_echo.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
from fastapi.testclient import TestClient

from examples.echo.app import app

http_client = TestClient(app)
from tests.utils.client import create_test_client


def test_app():
client = create_test_client(app, name="echo")

content = "Hello world!"
attachment = {
"type": "image/png",
"url": "image-url",
"title": "Image",
}

response = http_client.post(
"/openai/deployments/echo/chat/completions?api-version=2023-03-15-preview",
headers={"Api-Key": "dial_api_key"},
response = client.post(
"chat/completions",
json={
"messages": [
{
Expand Down
12 changes: 5 additions & 7 deletions tests/examples/test_image_size.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
from fastapi.testclient import TestClient

from examples.image_size.app.main import app

http_client = TestClient(app)
from tests.utils.client import create_test_client


def test_app():
client = create_test_client(app, name="image-size")

attachment = {
"type": "image/png",
"data": "iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg==",
"title": "Image",
}

response = http_client.post(
"/openai/deployments/image-size/chat/completions?api-version=2023-03-15-preview",
headers={"Api-Key": "dial_api_key"},
response = client.post(
"chat/completions",
json={
"messages": [
{
Expand Down
11 changes: 5 additions & 6 deletions tests/examples/test_render_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,17 @@
from io import BytesIO
from typing import Tuple

from fastapi.testclient import TestClient
from PIL import Image

from examples.render_text.app.main import app

http_client = TestClient(app)
from tests.utils.client import create_test_client


def test_app():
response = http_client.post(
"/openai/deployments/render-text/chat/completions?api-version=2023-03-15-preview",
headers={"Api-Key": "dial_api_key"},
client = create_test_client(app, name="render-text")

response = client.post(
"chat/completions",
json={
"messages": [
{
Expand Down
15 changes: 2 additions & 13 deletions tests/test_cancellation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@

import pytest

from aidial_sdk.chat_completion.request import Request as ChatCompletionRequest
from aidial_sdk.chat_completion.response import (
Response as ChatCompletionResponse,
)
from aidial_sdk.pydantic_v1 import SecretStr
from aidial_sdk.utils.streaming import add_heartbeat
from tests.utils.constants import DUMMY_FASTAPI_REQUEST
from tests.utils.constants import DUMMY_DIAL_REQUEST


class Counter:
Expand All @@ -36,7 +34,6 @@ async def _wait_forever():
async def _wait(counter: Counter, secs: Optional[int] = None):
try:
if secs is None:
# wait forever
await _wait_forever()
else:
for _ in range(secs):
Expand Down Expand Up @@ -88,15 +85,7 @@ async def test_cancellation(
with_heartbeat: bool, chat_completion, expected_cancelled, expected_done
):

request = ChatCompletionRequest(
original_request=DUMMY_FASTAPI_REQUEST,
messages=[],
api_key_secret=SecretStr("api-key"),
deployment_id="test-app",
headers={},
)

response = ChatCompletionResponse(request)
response = ChatCompletionResponse(DUMMY_DIAL_REQUEST)

counter = Counter()
chat_completion = chat_completion(counter)
Expand Down
145 changes: 42 additions & 103 deletions tests/test_discarded_messages.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,28 @@
import json
from unittest.mock import Mock

import pytest
from starlette.testclient import TestClient

from aidial_sdk import DIALApp, HTTPException
from aidial_sdk import HTTPException
from aidial_sdk.chat_completion import ChatCompletion, Request, Response
from aidial_sdk.pydantic_v1 import SecretStr
from tests.utils.constants import DUMMY_FASTAPI_REQUEST
from tests.utils.chunks import check_sse_stream, create_single_choice_chunk
from tests.utils.client import create_app_client
from tests.utils.constants import DUMMY_DIAL_REQUEST

DISCARDED_MESSAGES = list(range(0, 12))


def test_discarded_messages_returned():
dial_app = DIALApp()
chat_completion = Mock(spec=ChatCompletion)

async def chat_completion_side_effect(_, res: Response) -> None:
with res.create_single_choice():
pass
res.set_discarded_messages(DISCARDED_MESSAGES)

chat_completion.chat_completion.side_effect = chat_completion_side_effect
dial_app.add_chat_completion("test_app", chat_completion)

test_app = TestClient(dial_app)

response = test_app.post(
"/openai/deployments/test_app/chat/completions",
json={
"messages": [{"role": "user", "content": "Test content"}],
},
headers={"Api-Key": "TEST_API_KEY"},
class _Impl(ChatCompletion):
async def chat_completion(
self, request: Request, response: Response
) -> None:
with response.create_single_choice():
pass
response.set_discarded_messages(DISCARDED_MESSAGES)

client = create_app_client(_Impl())

response = client.post(
"chat/completions",
json={"messages": [{"role": "user", "content": "Test"}]},
)

assert (
Expand All @@ -41,88 +32,43 @@ async def chat_completion_side_effect(_, res: Response) -> None:


def test_discarded_messages_returned_as_last_chunk_in_stream():
dial_app = DIALApp()
chat_completion = Mock(spec=ChatCompletion)
class _Impl(ChatCompletion):
async def chat_completion(
self, request: Request, response: Response
) -> None:
response.set_response_id("test_id")
response.set_created(0)

async def chat_completion_side_effect(_, res: Response) -> None:
res.set_response_id("test_id")
res.set_created(123)
with response.create_single_choice():
pass

with res.create_single_choice():
pass
response.set_discarded_messages(DISCARDED_MESSAGES)

res.set_discarded_messages(DISCARDED_MESSAGES)
client = create_app_client(_Impl())

chat_completion.chat_completion.side_effect = chat_completion_side_effect
dial_app.add_chat_completion("test_app", chat_completion)

test_app = TestClient(dial_app)

response = test_app.post(
"/openai/deployments/test_app/chat/completions",
response = client.post(
"chat/completions",
json={
"messages": [{"role": "user", "content": "Test content"}],
"stream": True,
},
headers={"Api-Key": "TEST_API_KEY"},
)

def parse_chunk(data: str):
return json.loads(data[len("data: ") :])

def identity(data: str):
return data

parsers = [
parse_chunk,
identity,
parse_chunk,
identity,
identity,
identity,
]
lines = [*response.iter_lines()]

assert len(lines) == len(parsers)
assert [parser(lines[i]) for i, parser in enumerate(parsers)] == [
{
"choices": [
{
"index": 0,
"finish_reason": None,
"delta": {"role": "assistant"},
}
],
"usage": None,
"id": "test_id",
"created": 123,
"object": "chat.completion.chunk",
},
"",
{
"choices": [{"index": 0, "finish_reason": "stop", "delta": {}}],
"usage": None,
"statistics": {"discarded_messages": DISCARDED_MESSAGES},
"id": "test_id",
"created": 123,
"object": "chat.completion.chunk",
},
"",
"data: [DONE]",
"",
]
check_sse_stream(
response.iter_lines(),
[
create_single_choice_chunk({"role": "assistant"}),
create_single_choice_chunk(
{},
finish_reason="stop",
statistics={"discarded_messages": DISCARDED_MESSAGES},
),
],
)


def test_discarded_messages_is_set_twice():
request = Request(
headers={},
original_request=DUMMY_FASTAPI_REQUEST,
api_key_secret=SecretStr("dummy_key"),
deployment_id="",
messages=[],
)

response = Response(request)
response = Response(DUMMY_DIAL_REQUEST)

with response.create_single_choice():
pass
Expand All @@ -134,14 +80,7 @@ def test_discarded_messages_is_set_twice():


def test_discarded_messages_is_set_before_choice():
request = Request(
headers={},
original_request=DUMMY_FASTAPI_REQUEST,
api_key_secret=SecretStr("dummy_key"),
deployment_id="",
messages=[],
)
response = Response(request)
response = Response(DUMMY_DIAL_REQUEST)

with pytest.raises(HTTPException):
response.set_discarded_messages(DISCARDED_MESSAGES)
Loading

0 comments on commit ce3173a

Please sign in to comment.