Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: refactored and simplified unit tests #191

Merged
merged 10 commits into from
Nov 28, 2024
12 changes: 5 additions & 7 deletions tests/examples/test_echo.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
from fastapi.testclient import TestClient

from examples.echo.app import app

http_client = TestClient(app)
from tests.utils.client import create_test_client


def test_app():
client = create_test_client(app, name="echo")

content = "Hello world!"
attachment = {
"type": "image/png",
"url": "image-url",
"title": "Image",
}

response = http_client.post(
"/openai/deployments/echo/chat/completions?api-version=2023-03-15-preview",
headers={"Api-Key": "dial_api_key"},
response = client.post(
"chat/completions",
json={
"messages": [
{
Expand Down
12 changes: 5 additions & 7 deletions tests/examples/test_image_size.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
from fastapi.testclient import TestClient

from examples.image_size.app.main import app

http_client = TestClient(app)
from tests.utils.client import create_test_client


def test_app():
client = create_test_client(app, name="image-size")

attachment = {
"type": "image/png",
"data": "iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg==",
"title": "Image",
}

response = http_client.post(
"/openai/deployments/image-size/chat/completions?api-version=2023-03-15-preview",
headers={"Api-Key": "dial_api_key"},
response = client.post(
"chat/completions",
json={
"messages": [
{
Expand Down
11 changes: 5 additions & 6 deletions tests/examples/test_render_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,17 @@
from io import BytesIO
from typing import Tuple

from fastapi.testclient import TestClient
from PIL import Image

from examples.render_text.app.main import app

http_client = TestClient(app)
from tests.utils.client import create_test_client


def test_app():
response = http_client.post(
"/openai/deployments/render-text/chat/completions?api-version=2023-03-15-preview",
headers={"Api-Key": "dial_api_key"},
client = create_test_client(app, name="render-text")

response = client.post(
"chat/completions",
json={
"messages": [
{
Expand Down
15 changes: 2 additions & 13 deletions tests/test_cancellation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@

import pytest

from aidial_sdk.chat_completion.request import Request as ChatCompletionRequest
from aidial_sdk.chat_completion.response import (
Response as ChatCompletionResponse,
)
from aidial_sdk.pydantic_v1 import SecretStr
from aidial_sdk.utils.streaming import add_heartbeat
from tests.utils.constants import DUMMY_FASTAPI_REQUEST
from tests.utils.constants import DUMMY_DIAL_REQUEST


class Counter:
Expand All @@ -36,7 +34,6 @@ async def _wait_forever():
async def _wait(counter: Counter, secs: Optional[int] = None):
try:
if secs is None:
# wait forever
await _wait_forever()
else:
for _ in range(secs):
Expand Down Expand Up @@ -88,15 +85,7 @@ async def test_cancellation(
with_heartbeat: bool, chat_completion, expected_cancelled, expected_done
):

request = ChatCompletionRequest(
original_request=DUMMY_FASTAPI_REQUEST,
messages=[],
api_key_secret=SecretStr("api-key"),
deployment_id="test-app",
headers={},
)

response = ChatCompletionResponse(request)
response = ChatCompletionResponse(DUMMY_DIAL_REQUEST)

counter = Counter()
chat_completion = chat_completion(counter)
Expand Down
145 changes: 42 additions & 103 deletions tests/test_discarded_messages.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,28 @@
import json
from unittest.mock import Mock

import pytest
from starlette.testclient import TestClient

from aidial_sdk import DIALApp, HTTPException
from aidial_sdk import HTTPException
from aidial_sdk.chat_completion import ChatCompletion, Request, Response
from aidial_sdk.pydantic_v1 import SecretStr
from tests.utils.constants import DUMMY_FASTAPI_REQUEST
from tests.utils.chunks import check_sse_stream, create_single_choice_chunk
from tests.utils.client import create_app_client
from tests.utils.constants import DUMMY_DIAL_REQUEST

DISCARDED_MESSAGES = list(range(0, 12))


def test_discarded_messages_returned():
dial_app = DIALApp()
chat_completion = Mock(spec=ChatCompletion)

async def chat_completion_side_effect(_, res: Response) -> None:
with res.create_single_choice():
pass
res.set_discarded_messages(DISCARDED_MESSAGES)

chat_completion.chat_completion.side_effect = chat_completion_side_effect
dial_app.add_chat_completion("test_app", chat_completion)

test_app = TestClient(dial_app)

response = test_app.post(
"/openai/deployments/test_app/chat/completions",
json={
"messages": [{"role": "user", "content": "Test content"}],
},
headers={"Api-Key": "TEST_API_KEY"},
class _Impl(ChatCompletion):
async def chat_completion(
self, request: Request, response: Response
) -> None:
with response.create_single_choice():
pass
response.set_discarded_messages(DISCARDED_MESSAGES)

client = create_app_client(_Impl())

response = client.post(
"chat/completions",
json={"messages": [{"role": "user", "content": "Test"}]},
)

assert (
Expand All @@ -41,88 +32,43 @@ async def chat_completion_side_effect(_, res: Response) -> None:


def test_discarded_messages_returned_as_last_chunk_in_stream():
dial_app = DIALApp()
chat_completion = Mock(spec=ChatCompletion)
class _Impl(ChatCompletion):
async def chat_completion(
self, request: Request, response: Response
) -> None:
response.set_response_id("test_id")
response.set_created(0)

async def chat_completion_side_effect(_, res: Response) -> None:
res.set_response_id("test_id")
res.set_created(123)
with response.create_single_choice():
pass

with res.create_single_choice():
pass
response.set_discarded_messages(DISCARDED_MESSAGES)

res.set_discarded_messages(DISCARDED_MESSAGES)
client = create_app_client(_Impl())

chat_completion.chat_completion.side_effect = chat_completion_side_effect
dial_app.add_chat_completion("test_app", chat_completion)

test_app = TestClient(dial_app)

response = test_app.post(
"/openai/deployments/test_app/chat/completions",
response = client.post(
"chat/completions",
json={
"messages": [{"role": "user", "content": "Test content"}],
"stream": True,
},
headers={"Api-Key": "TEST_API_KEY"},
)

def parse_chunk(data: str):
return json.loads(data[len("data: ") :])

def identity(data: str):
return data

parsers = [
parse_chunk,
identity,
parse_chunk,
identity,
identity,
identity,
]
lines = [*response.iter_lines()]

assert len(lines) == len(parsers)
assert [parser(lines[i]) for i, parser in enumerate(parsers)] == [
{
"choices": [
{
"index": 0,
"finish_reason": None,
"delta": {"role": "assistant"},
}
],
"usage": None,
"id": "test_id",
"created": 123,
"object": "chat.completion.chunk",
},
"",
{
"choices": [{"index": 0, "finish_reason": "stop", "delta": {}}],
"usage": None,
"statistics": {"discarded_messages": DISCARDED_MESSAGES},
"id": "test_id",
"created": 123,
"object": "chat.completion.chunk",
},
"",
"data: [DONE]",
"",
]
check_sse_stream(
response.iter_lines(),
[
create_single_choice_chunk({"role": "assistant"}),
create_single_choice_chunk(
{},
finish_reason="stop",
statistics={"discarded_messages": DISCARDED_MESSAGES},
),
],
)


def test_discarded_messages_is_set_twice():
request = Request(
headers={},
original_request=DUMMY_FASTAPI_REQUEST,
api_key_secret=SecretStr("dummy_key"),
deployment_id="",
messages=[],
)

response = Response(request)
response = Response(DUMMY_DIAL_REQUEST)

with response.create_single_choice():
pass
Expand All @@ -134,14 +80,7 @@ def test_discarded_messages_is_set_twice():


def test_discarded_messages_is_set_before_choice():
request = Request(
headers={},
original_request=DUMMY_FASTAPI_REQUEST,
api_key_secret=SecretStr("dummy_key"),
deployment_id="",
messages=[],
)
response = Response(request)
response = Response(DUMMY_DIAL_REQUEST)

with pytest.raises(HTTPException):
response.set_discarded_messages(DISCARDED_MESSAGES)
Loading