diff --git a/tests/deploy/api/test_e2e.py b/tests/deploy/api/test_e2e.py index 278c776cc..e19cb2cfb 100644 --- a/tests/deploy/api/test_e2e.py +++ b/tests/deploy/api/test_e2e.py @@ -4,25 +4,45 @@ import pytest from fastapi.testclient import TestClient +from ragna.assistants import RagnaDemoAssistant from ragna.deploy import Config from ragna.deploy._api import app from .utils import authenticate -@pytest.mark.parametrize("stream_answer", [True, False]) -def test_e2e(tmp_local_root, stream_answer): - config = Config(local_root=tmp_local_root) - check_api(config, stream_answer=stream_answer) +class TestAssistant(RagnaDemoAssistant): + @property + def max_input_size(self) -> int: + return 0 + + def answer(self, prompt, sources, *, multiple_answer_chunks: bool): + content = next(super().answer(prompt, sources)) + + if multiple_answer_chunks: + for chunk in content.split(" "): + yield f"{chunk} " + else: + yield content -def check_api(config, *, stream_answer): +@pytest.mark.parametrize("multiple_answer_chunks", [True, False]) +@pytest.mark.parametrize("stream_answer", [True, False]) +def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer): + config = Config(local_root=tmp_local_root, assistants=[TestAssistant]) + document_root = config.local_root / "documents" document_root.mkdir() document_path = document_root / "test.txt" with open(document_path, "w") as file: file.write("!\n") + # Reset starlette_sse AppStatus for each run + # See https://github.com/sysid/sse-starlette/issues/59 + from sse_starlette.sse import AppStatus + + AppStatus.should_exit_event = None + with TestClient(app(config=config, ignore_unavailable_components=False)) as client: authenticate(client) @@ -66,7 +86,7 @@ def check_api(config, *, stream_answer): "name": "test-chat", "source_storage": source_storage, "assistant": assistant, - "params": {}, + "params": {"multiple_answer_chunks": multiple_answer_chunks}, "documents": [document], } chat = client.post("/chats", json=chat_metadata).raise_for_status().json() @@ -96,10 +116,10 @@ def check_api(config, *, stream_answer): json={"prompt": prompt, "stream": True}, ) as event_source: for sse in event_source.iter_sse(): - chunk = json.loads(sse.data) - chunks.append(chunk["content"]) - message = chunk - message["content"] = "".join(chunks) + chunks.append(json.loads(sse.data)) + message = chunks[0] + assert all(chunk["sources"] is None for chunk in chunks[1:]) + message["content"] = "".join(chunk["content"] for chunk in chunks) else: message = ( client.post(f"/chats/{chat['id']}/answer", json={"prompt": prompt})