diff --git a/tests/examples/test_echo.py b/tests/examples/test_echo.py index bddfba6..88c1b07 100644 --- a/tests/examples/test_echo.py +++ b/tests/examples/test_echo.py @@ -1,11 +1,10 @@ -from fastapi.testclient import TestClient - from examples.echo.app import app - -http_client = TestClient(app) +from tests.utils.client import create_test_client def test_app(): + client = create_test_client(app, name="echo") + content = "Hello world!" attachment = { "type": "image/png", @@ -13,9 +12,8 @@ def test_app(): "title": "Image", } - response = http_client.post( - "/openai/deployments/echo/chat/completions?api-version=2023-03-15-preview", - headers={"Api-Key": "dial_api_key"}, + response = client.post( + "chat/completions", json={ "messages": [ { diff --git a/tests/examples/test_image_size.py b/tests/examples/test_image_size.py index fcde99e..5d15ed0 100644 --- a/tests/examples/test_image_size.py +++ b/tests/examples/test_image_size.py @@ -1,20 +1,18 @@ -from fastapi.testclient import TestClient - from examples.image_size.app.main import app - -http_client = TestClient(app) +from tests.utils.client import create_test_client def test_app(): + client = create_test_client(app, name="image-size") + attachment = { "type": "image/png", "data": "iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg==", "title": "Image", } - response = http_client.post( - "/openai/deployments/image-size/chat/completions?api-version=2023-03-15-preview", - headers={"Api-Key": "dial_api_key"}, + response = client.post( + "chat/completions", json={ "messages": [ { diff --git a/tests/examples/test_render_text.py b/tests/examples/test_render_text.py index 224ea98..c6656a8 100644 --- a/tests/examples/test_render_text.py +++ b/tests/examples/test_render_text.py @@ -2,18 +2,17 @@ from io import BytesIO from typing import Tuple -from fastapi.testclient import TestClient from PIL import Image from examples.render_text.app.main import app - -http_client = TestClient(app) +from tests.utils.client import create_test_client def test_app(): - response = http_client.post( - "/openai/deployments/render-text/chat/completions?api-version=2023-03-15-preview", - headers={"Api-Key": "dial_api_key"}, + client = create_test_client(app, name="render-text") + + response = client.post( + "chat/completions", json={ "messages": [ { diff --git a/tests/test_cancellation.py b/tests/test_cancellation.py index d062760..bbc8417 100644 --- a/tests/test_cancellation.py +++ b/tests/test_cancellation.py @@ -3,13 +3,11 @@ import pytest -from aidial_sdk.chat_completion.request import Request as ChatCompletionRequest from aidial_sdk.chat_completion.response import ( Response as ChatCompletionResponse, ) -from aidial_sdk.pydantic_v1 import SecretStr from aidial_sdk.utils.streaming import add_heartbeat -from tests.utils.constants import DUMMY_FASTAPI_REQUEST +from tests.utils.constants import DUMMY_DIAL_REQUEST class Counter: @@ -36,7 +34,6 @@ async def _wait_forever(): async def _wait(counter: Counter, secs: Optional[int] = None): try: if secs is None: - # wait forever await _wait_forever() else: for _ in range(secs): @@ -88,15 +85,7 @@ async def test_cancellation( with_heartbeat: bool, chat_completion, expected_cancelled, expected_done ): - request = ChatCompletionRequest( - original_request=DUMMY_FASTAPI_REQUEST, - messages=[], - api_key_secret=SecretStr("api-key"), - deployment_id="test-app", - headers={}, - ) - - response = ChatCompletionResponse(request) + response = ChatCompletionResponse(DUMMY_DIAL_REQUEST) counter = Counter() chat_completion = chat_completion(counter) diff --git a/tests/test_discarded_messages.py b/tests/test_discarded_messages.py index f9731cd..f902d7c 100644 --- a/tests/test_discarded_messages.py +++ b/tests/test_discarded_messages.py @@ -1,37 +1,28 @@ -import json -from unittest.mock import Mock - import pytest -from starlette.testclient import TestClient -from aidial_sdk import DIALApp, HTTPException +from aidial_sdk import HTTPException from aidial_sdk.chat_completion import ChatCompletion, Request, Response -from aidial_sdk.pydantic_v1 import SecretStr -from tests.utils.constants import DUMMY_FASTAPI_REQUEST +from tests.utils.chunks import check_sse_stream, create_single_choice_chunk +from tests.utils.client import create_app_client +from tests.utils.constants import DUMMY_DIAL_REQUEST DISCARDED_MESSAGES = list(range(0, 12)) def test_discarded_messages_returned(): - dial_app = DIALApp() - chat_completion = Mock(spec=ChatCompletion) - - async def chat_completion_side_effect(_, res: Response) -> None: - with res.create_single_choice(): - pass - res.set_discarded_messages(DISCARDED_MESSAGES) - - chat_completion.chat_completion.side_effect = chat_completion_side_effect - dial_app.add_chat_completion("test_app", chat_completion) - - test_app = TestClient(dial_app) - - response = test_app.post( - "/openai/deployments/test_app/chat/completions", - json={ - "messages": [{"role": "user", "content": "Test content"}], - }, - headers={"Api-Key": "TEST_API_KEY"}, + class _Impl(ChatCompletion): + async def chat_completion( + self, request: Request, response: Response + ) -> None: + with response.create_single_choice(): + pass + response.set_discarded_messages(DISCARDED_MESSAGES) + + client = create_app_client(_Impl()) + + response = client.post( + "chat/completions", + json={"messages": [{"role": "user", "content": "Test"}]}, ) assert ( @@ -41,88 +32,43 @@ async def chat_completion_side_effect(_, res: Response) -> None: def test_discarded_messages_returned_as_last_chunk_in_stream(): - dial_app = DIALApp() - chat_completion = Mock(spec=ChatCompletion) + class _Impl(ChatCompletion): + async def chat_completion( + self, request: Request, response: Response + ) -> None: + response.set_response_id("test_id") + response.set_created(0) - async def chat_completion_side_effect(_, res: Response) -> None: - res.set_response_id("test_id") - res.set_created(123) + with response.create_single_choice(): + pass - with res.create_single_choice(): - pass + response.set_discarded_messages(DISCARDED_MESSAGES) - res.set_discarded_messages(DISCARDED_MESSAGES) + client = create_app_client(_Impl()) - chat_completion.chat_completion.side_effect = chat_completion_side_effect - dial_app.add_chat_completion("test_app", chat_completion) - - test_app = TestClient(dial_app) - - response = test_app.post( - "/openai/deployments/test_app/chat/completions", + response = client.post( + "chat/completions", json={ "messages": [{"role": "user", "content": "Test content"}], "stream": True, }, - headers={"Api-Key": "TEST_API_KEY"}, ) - def parse_chunk(data: str): - return json.loads(data[len("data: ") :]) - - def identity(data: str): - return data - - parsers = [ - parse_chunk, - identity, - parse_chunk, - identity, - identity, - identity, - ] - lines = [*response.iter_lines()] - - assert len(lines) == len(parsers) - assert [parser(lines[i]) for i, parser in enumerate(parsers)] == [ - { - "choices": [ - { - "index": 0, - "finish_reason": None, - "delta": {"role": "assistant"}, - } - ], - "usage": None, - "id": "test_id", - "created": 123, - "object": "chat.completion.chunk", - }, - "", - { - "choices": [{"index": 0, "finish_reason": "stop", "delta": {}}], - "usage": None, - "statistics": {"discarded_messages": DISCARDED_MESSAGES}, - "id": "test_id", - "created": 123, - "object": "chat.completion.chunk", - }, - "", - "data: [DONE]", - "", - ] + check_sse_stream( + response.iter_lines(), + [ + create_single_choice_chunk({"role": "assistant"}), + create_single_choice_chunk( + {}, + finish_reason="stop", + statistics={"discarded_messages": DISCARDED_MESSAGES}, + ), + ], + ) def test_discarded_messages_is_set_twice(): - request = Request( - headers={}, - original_request=DUMMY_FASTAPI_REQUEST, - api_key_secret=SecretStr("dummy_key"), - deployment_id="", - messages=[], - ) - - response = Response(request) + response = Response(DUMMY_DIAL_REQUEST) with response.create_single_choice(): pass @@ -134,14 +80,7 @@ def test_discarded_messages_is_set_twice(): def test_discarded_messages_is_set_before_choice(): - request = Request( - headers={}, - original_request=DUMMY_FASTAPI_REQUEST, - api_key_secret=SecretStr("dummy_key"), - deployment_id="", - messages=[], - ) - response = Response(request) + response = Response(DUMMY_DIAL_REQUEST) with pytest.raises(HTTPException): response.set_discarded_messages(DISCARDED_MESSAGES) diff --git a/tests/test_errors.py b/tests/test_errors.py index c06969d..a1fa448 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -1,16 +1,15 @@ import dataclasses -import json from typing import Any, Dict, List import pytest -from starlette.testclient import TestClient -from aidial_sdk import DIALApp from tests.applications.broken import ( ImmediatelyBrokenApplication, RuntimeBrokenApplication, ) from tests.applications.noop import NoopApplication +from tests.utils.chunks import check_sse_stream, create_single_choice_chunk +from tests.utils.client import create_app_client DEFAULT_RUNTIME_ERROR = { "error": { @@ -103,13 +102,10 @@ class ErrorTestCase: @pytest.mark.parametrize("test_case", error_testcases) def test_error(test_case: ErrorTestCase): - dial_app = DIALApp() - dial_app.add_chat_completion("test_app", ImmediatelyBrokenApplication()) + client = create_app_client(ImmediatelyBrokenApplication()) - test_app = TestClient(dial_app) - - response = test_app.post( - "/openai/deployments/test_app/chat/completions", + response = client.post( + "chat/completions", json={ "messages": [{"role": "user", "content": test_case.content}], "stream": False, @@ -126,13 +122,10 @@ def test_error(test_case: ErrorTestCase): @pytest.mark.parametrize("test_case", error_testcases) def test_streaming_error(test_case: ErrorTestCase): - dial_app = DIALApp() - dial_app.add_chat_completion("test_app", ImmediatelyBrokenApplication()) - - test_app = TestClient(dial_app) + client = create_app_client(ImmediatelyBrokenApplication()) - response = test_app.post( - "/openai/deployments/test_app/chat/completions", + response = client.post( + "chat/completions", json={ "messages": [{"role": "user", "content": test_case.content}], "stream": True, @@ -146,78 +139,32 @@ def test_streaming_error(test_case: ErrorTestCase): @pytest.mark.parametrize("test_case", error_testcases) def test_runtime_streaming_error(test_case: ErrorTestCase): - dial_app = DIALApp() - dial_app.add_chat_completion("test_app", RuntimeBrokenApplication()) - - test_app = TestClient(dial_app) + client = create_app_client(RuntimeBrokenApplication()) - response = test_app.post( - "/openai/deployments/test_app/chat/completions", + response = client.post( + "chat/completions", json={ "messages": [{"role": "user", "content": test_case.content}], "stream": True, }, - headers={"Api-Key": "TEST_API_KEY"}, ) - for index, value in enumerate(response.iter_lines()): - if index % 2: - assert value == "" - continue - - assert value.startswith("data: ") - data = value[6:] - - if index == 0: - assert json.loads(data) == { - "choices": [ - { - "index": 0, - "finish_reason": None, - "delta": {"role": "assistant"}, - } - ], - "usage": None, - "id": "test_id", - "created": 0, - "object": "chat.completion.chunk", - } - elif index == 2: - assert json.loads(data) == { - "choices": [ - { - "index": 0, - "finish_reason": None, - "delta": {"content": "Test content"}, - } - ], - "usage": None, - "id": "test_id", - "created": 0, - "object": "chat.completion.chunk", - } - elif index == 4: - assert json.loads(data) == { - "choices": [{"index": 0, "finish_reason": "stop", "delta": {}}], - "usage": None, - "id": "test_id", - "created": 0, - "object": "chat.completion.chunk", - } - elif index == 6: - assert json.loads(data) == test_case.response_error - elif index == 8: - assert data == "[DONE]" + check_sse_stream( + response.iter_lines(), + [ + create_single_choice_chunk({"role": "assistant"}), + create_single_choice_chunk({"content": "Test content"}), + create_single_choice_chunk({}, "stop"), + test_case.response_error, + ], + ) def test_no_api_key(): - dial_app = DIALApp() - dial_app.add_chat_completion("test_app", NoopApplication()) - - test_app = TestClient(dial_app) + client = create_app_client(NoopApplication(), headers={}) - response = test_app.post( - "/openai/deployments/test_app/chat/completions", + response = client.post( + "chat/completions", json={ "messages": [{"role": "user", "content": "test"}], "stream": False, diff --git a/tests/test_heartbeat.py b/tests/test_heartbeat.py index 45db538..3d5cedc 100644 --- a/tests/test_heartbeat.py +++ b/tests/test_heartbeat.py @@ -25,57 +25,36 @@ """ import asyncio -import itertools -import json from contextlib import contextmanager -from typing import Generator, Iterator, List, Optional, Union +from typing import List, Optional, Union from unittest.mock import patch import pytest from pydantic import BaseModel -from starlette.testclient import TestClient -from aidial_sdk import DIALApp +from aidial_sdk.application import DIALApp from aidial_sdk.utils.streaming import add_heartbeat as original_add_heartbeat from tests.applications.idle import IdleApplication - -ExpectedStream = List[Union[str, dict]] +from tests.utils.chunks import check_sse_stream, create_single_choice_chunk +from tests.utils.client import create_test_client BEAT = ": heartbeat" -DONE = "data: [DONE]" - -ERROR = "data: " + json.dumps( - { - "error": { - "message": "Error during processing the request", - "type": "runtime_error", - "code": "500", - } - }, - separators=(",", ":"), -) - -def create_choice( - *, finish_reason: Optional[str] = None, delta: Optional[dict] = {} -) -> dict: - return { - "choices": [ - {"index": 0, "finish_reason": finish_reason, "delta": delta} - ], - "usage": None, - "id": "test_id", - "created": 0, - "object": "chat.completion.chunk", +ERROR = { + "error": { + "message": "Error during processing the request", + "type": "runtime_error", + "code": "500", } +} -CHOICE_OPEN = create_choice(delta={"role": "assistant"}) -CHOICE_CLOSE = create_choice(finish_reason="stop") +CHOICE_OPEN = create_single_choice_chunk(delta={"role": "assistant"}) +CHOICE_CLOSE = create_single_choice_chunk(finish_reason="stop") def content(content: str): - return create_choice(delta={"content": content}) + return create_single_choice_chunk(delta={"content": content}) @contextmanager @@ -89,34 +68,13 @@ def _updated_add_heartbeat(*args, **kwargs): yield mock -def match_sse_stream(expected: ExpectedStream, actual: Iterator[str]): - - def _add_newlines( - stream: ExpectedStream, - ) -> Generator[Union[str, dict], None, None]: - for line in stream: - yield line - yield "" - - for expected_item, actual_line in itertools.zip_longest( - _add_newlines(expected), actual - ): - if isinstance(expected_item, dict): - assert actual_line[: len("data:")] == "data:" - actual_line = actual_line[len("data:") :] - actual_obj = json.loads(actual_line) - assert actual_obj == expected_item - else: - assert actual_line == expected_item - - class TestCase(BaseModel): __test__ = False intervals: List[float] throw_exception: bool heartbeat_interval: Optional[float] - expected: ExpectedStream + expected: List[Union[str, dict]] @pytest.mark.parametrize( @@ -131,7 +89,6 @@ class TestCase(BaseModel): CHOICE_OPEN, content("1"), CHOICE_CLOSE, - DONE, ], ), TestCase( @@ -145,7 +102,6 @@ class TestCase(BaseModel): BEAT, content("2"), CHOICE_CLOSE, - DONE, ], ), TestCase( @@ -163,7 +119,6 @@ class TestCase(BaseModel): BEAT, content("4"), CHOICE_CLOSE, - DONE, ], ), TestCase( @@ -178,7 +133,6 @@ class TestCase(BaseModel): CHOICE_OPEN, content("1"), CHOICE_CLOSE, - DONE, ], ), TestCase( @@ -192,7 +146,6 @@ class TestCase(BaseModel): content("3"), content("4"), CHOICE_CLOSE, - DONE, ], ), TestCase( @@ -203,7 +156,6 @@ class TestCase(BaseModel): CHOICE_OPEN, content("1"), CHOICE_CLOSE, - DONE, ], ), TestCase( @@ -216,7 +168,6 @@ class TestCase(BaseModel): content("1"), CHOICE_CLOSE, ERROR, - DONE, ], ), ], @@ -229,11 +180,9 @@ def inc_beat_counter(): beats += 1 with mock_add_heartbeat(heartbeat_callback=inc_beat_counter): - app_name = "test-app" - - app = DIALApp() - app.add_chat_completion( - app_name, + name = "test-deployment-name" + app = DIALApp().add_chat_completion( + name, IdleApplication( intervals=test_case.intervals, throw_exception=test_case.throw_exception, @@ -241,18 +190,17 @@ def inc_beat_counter(): heartbeat_interval=test_case.heartbeat_interval, ) - client = TestClient(app) + client = create_test_client(app, name=name) response = client.post( - url=f"/openai/deployments/{app_name}/chat/completions", + url="chat/completions", json={ "messages": [{"role": "user", "content": "hello"}], "stream": True, }, - headers={"Api-Key": "TEST_API_KEY"}, ) - match_sse_stream(test_case.expected, response.iter_lines()) + check_sse_stream(response.iter_lines(), test_case.expected) expected_beats = test_case.expected.count(BEAT) assert beats == expected_beats diff --git a/tests/test_max_prompt_tokens.py b/tests/test_max_prompt_tokens.py index 102b828..8f827c6 100644 --- a/tests/test_max_prompt_tokens.py +++ b/tests/test_max_prompt_tokens.py @@ -3,7 +3,7 @@ def test_max_prompt_tokens_is_set(): validate_chat_completion( - input_request={ + request={ "messages": [{"role": "user", "content": "Test content"}], "max_prompt_tokens": 15, }, @@ -13,7 +13,7 @@ def test_max_prompt_tokens_is_set(): def test_max_prompt_tokens_is_unset(): validate_chat_completion( - input_request={ + request={ "messages": [{"role": "user", "content": "Test content"}], }, request_validator=lambda r: not r.max_prompt_tokens, diff --git a/tests/test_request_tools_parsing.py b/tests/test_request_tools_parsing.py index e2bfbdb..2f90e92 100644 --- a/tests/test_request_tools_parsing.py +++ b/tests/test_request_tools_parsing.py @@ -113,6 +113,6 @@ def _request_validator(r: Request): assert isinstance(tool, StaticTool) validate_chat_completion( - input_request=mock_data, + request=mock_data, request_validator=_request_validator, ) diff --git a/tests/test_single_choice.py b/tests/test_single_choice.py index 5acdc22..4ffed78 100644 --- a/tests/test_single_choice.py +++ b/tests/test_single_choice.py @@ -1,24 +1,17 @@ -import json - -from starlette.testclient import TestClient - -from aidial_sdk import DIALApp from tests.applications.single_choice import SingleChoiceApplication +from tests.utils.chunks import check_sse_stream, create_single_choice_chunk +from tests.utils.client import create_app_client def test_single_choice(): - dial_app = DIALApp() - dial_app.add_chat_completion("test_app", SingleChoiceApplication()) - - test_app = TestClient(dial_app) + client = create_app_client(SingleChoiceApplication()) - response = test_app.post( - "/openai/deployments/test_app/chat/completions", + response = client.post( + "chat/completions", json={ "messages": [{"role": "user", "content": "Test content"}], "stream": False, }, - headers={"Api-Key": "TEST_API_KEY"}, ) assert response.status_code == 200 and response.json() == { @@ -40,63 +33,21 @@ def test_single_choice(): def test_single_choice_streaming(): - dial_app = DIALApp() - dial_app.add_chat_completion("test_app", SingleChoiceApplication()) + client = create_app_client(SingleChoiceApplication()) - test_app = TestClient(dial_app) - - response = test_app.post( - "/openai/deployments/test_app/chat/completions", + response = client.post( + "chat/completions", json={ "messages": [{"role": "user", "content": "Test content"}], "stream": True, }, - headers={"Api-Key": "TEST_API_KEY"}, ) - for index, value in enumerate(response.iter_lines()): - if index % 2: - assert value == "" - continue - - assert value.startswith("data: ") - data = value[6:] - - if index == 0: - assert json.loads(data) == { - "choices": [ - { - "index": 0, - "finish_reason": None, - "delta": {"role": "assistant"}, - } - ], - "usage": None, - "id": "test_id", - "created": 0, - "object": "chat.completion.chunk", - } - elif index == 2: - assert json.loads(data) == { - "choices": [ - { - "index": 0, - "finish_reason": None, - "delta": {"content": "Test response content"}, - } - ], - "usage": None, - "id": "test_id", - "created": 0, - "object": "chat.completion.chunk", - } - elif index == 4: - assert json.loads(data) == { - "choices": [{"index": 0, "finish_reason": "stop", "delta": {}}], - "usage": None, - "id": "test_id", - "created": 0, - "object": "chat.completion.chunk", - } - elif index == 6: - assert data == "[DONE]" + check_sse_stream( + response.iter_lines(), + [ + create_single_choice_chunk({"role": "assistant"}), + create_single_choice_chunk({"content": "Test response content"}), + create_single_choice_chunk({}, "stop"), + ], + ) diff --git a/tests/utils/chat_completion_validation.py b/tests/utils/chat_completion_validation.py index 327279f..5fcd147 100644 --- a/tests/utils/chat_completion_validation.py +++ b/tests/utils/chat_completion_validation.py @@ -1,25 +1,12 @@ -from fastapi.testclient import TestClient - -from aidial_sdk.application import DIALApp from tests.applications.validator import RequestValidator, ValidatorApplication +from tests.utils.client import create_app_client def validate_chat_completion( - input_request: dict, - request_validator: RequestValidator, + request: dict, request_validator: RequestValidator ) -> None: - dial_app = DIALApp() - dial_app.add_chat_completion( - "test_app", - ValidatorApplication( - request_validator=request_validator, - ), + client = create_app_client( + ValidatorApplication(request_validator=request_validator) ) - test_app = TestClient(dial_app) - - test_app.post( - "/openai/deployments/test_app/chat/completions", - json=input_request, - headers={"Api-Key": "TEST_API_KEY"}, - ) + client.post("chat/completions", json=request) diff --git a/tests/utils/chunks.py b/tests/utils/chunks.py index 1501b49..eb13c95 100644 --- a/tests/utils/chunks.py +++ b/tests/utils/chunks.py @@ -1,4 +1,6 @@ -from typing import Literal, Optional +import itertools +import json +from typing import Iterable, Literal, Optional, Union def create_chunk( @@ -23,6 +25,25 @@ def create_chunk( } +def create_single_choice_chunk( + delta: dict = {}, finish_reason: Optional[str] = None, **kwargs +): + return { + "choices": [ + { + "index": 0, + "finish_reason": finish_reason, + "delta": delta, + } + ], + "usage": None, + "id": "test_id", + "created": 0, + "object": "chat.completion.chunk", + **kwargs, + } + + def create_tool_call_chunk( idx: int, *, @@ -43,3 +64,47 @@ def create_tool_call_chunk( ] } ) + + +def _check_sse_line(actual: str, expected: Union[str, dict]): + if isinstance(expected, str): + assert ( + actual == expected + ), f"actual line != expected line: {actual!r} != {expected!r}" + return + + assert actual.startswith("data: "), f"Invalid data SSE entry: {actual!r}" + actual = actual[len("data: ") :] + + try: + actual_dict = json.loads(actual) + except json.JSONDecodeError: + raise AssertionError(f"Invalid JSON in data SSE entry: {actual!r}") + assert ( + actual_dict == expected + ), f"actual json != expected json: {actual_dict!r} != {expected!r}" + + +ExpectedSSEStream = Iterable[Union[str, dict]] + + +def check_sse_stream( + actual: Iterable[str], expected: ExpectedSSEStream +) -> bool: + expected = itertools.chain(expected, ["data: [DONE]"]) + expected = itertools.chain.from_iterable((line, "") for line in expected) + + sentinel = object() + for a_line, e_obj in itertools.zip_longest( + actual, expected, fillvalue=sentinel + ): + assert ( + a_line is not sentinel + ), "The list of actual values is shorter than the list of expected values" + assert ( + e_obj is not sentinel + ), "The list of expected values is shorter than the list of actual values" + + _check_sse_line(a_line, e_obj) # type: ignore + + return True diff --git a/tests/utils/client.py b/tests/utils/client.py new file mode 100644 index 0000000..9b22143 --- /dev/null +++ b/tests/utils/client.py @@ -0,0 +1,31 @@ +from typing import Dict + +import httpx +from fastapi import FastAPI +from starlette.testclient import TestClient + +from aidial_sdk import DIALApp +from aidial_sdk.chat_completion.base import ChatCompletion + + +def create_app_client( + chat_completion: ChatCompletion, + *, + name: str = "test-deployment-name", + headers: Dict[str, str] = {"api-key": "TEST_API_KEY"}, +) -> httpx.Client: + app = DIALApp().add_chat_completion(name, chat_completion) + return create_test_client(app, name=name, headers=headers) + + +def create_test_client( + app: FastAPI, + *, + name: str = "test-deployment-name", + headers: Dict[str, str] = {"api-key": "TEST_API_KEY"}, +) -> httpx.Client: + return TestClient( + app=app, + headers=headers, + base_url=f"http://testserver/openai/deployments/{name}", + ) diff --git a/tests/utils/constants.py b/tests/utils/constants.py index 4190174..dad59c5 100644 --- a/tests/utils/constants.py +++ b/tests/utils/constants.py @@ -1,3 +1,14 @@ import fastapi -DUMMY_FASTAPI_REQUEST = fastapi.Request({"type": "http"}) +from aidial_sdk.chat_completion import Request +from aidial_sdk.pydantic_v1 import SecretStr + +_DUMMY_FASTAPI_REQUEST = fastapi.Request({"type": "http"}) + +DUMMY_DIAL_REQUEST = Request( + headers={}, + original_request=_DUMMY_FASTAPI_REQUEST, + api_key_secret=SecretStr("dummy_key"), + deployment_id="", + messages=[], +) diff --git a/tests/utils/endpoint_test.py b/tests/utils/endpoint_test.py index 9f2a5f8..eb94f7f 100644 --- a/tests/utils/endpoint_test.py +++ b/tests/utils/endpoint_test.py @@ -37,9 +37,9 @@ def __init__( def run_endpoint_test(testcase: TestCase): - test_app = TestClient(testcase.app) + client = TestClient(testcase.app) - actual_response = test_app.post( + actual_response = client.post( f"/openai/deployments/{testcase.deployment}/{testcase.endpoint}", json=testcase.request_body, headers={"Api-Key": "TEST_API_KEY", **testcase.request_headers},