diff --git a/ragna/deploy/_api/core.py b/ragna/deploy/_api/core.py index 5346b048..50af2746 100644 --- a/ragna/deploy/_api/core.py +++ b/ragna/deploy/_api/core.py @@ -296,14 +296,22 @@ async def answer( ) core_chat = schema_to_core_chat(session, user=user, chat=chat) - core_answer = await core_chat.answer(prompt, stream=stream) + # smoke test to catch errors unrelated to streaming + # longer-term solution tracked by #409 + try: + core_answer = await core_chat.answer(prompt, stream=stream) + except Exception as e: + raise HTTPException( + status_code=500, + detail=str(e), + ) from e if stream: - - async def message_chunks() -> AsyncIterator[BaseModel]: + # smoke test to catch errors unrelated to streaming + # longer-term solution tracked by #409 + try: core_answer_stream = aiter(core_answer) content_chunk = await anext(core_answer_stream) - answer = schemas.Message( content=content_chunk, role=core_answer.role, @@ -312,6 +320,13 @@ async def message_chunks() -> AsyncIterator[BaseModel]: for source in core_answer.sources ], ) + except Exception as e: + raise HTTPException( + status_code=500, + detail=str(e), + ) from e + + async def message_chunks() -> AsyncIterator[BaseModel]: yield answer # Avoid sending the sources multiple times diff --git a/tests/deploy/api/test_streaming.py b/tests/deploy/api/test_streaming.py new file mode 100644 index 00000000..4d5e64da --- /dev/null +++ b/tests/deploy/api/test_streaming.py @@ -0,0 +1,74 @@ +import io +import json + +import pytest +from fastapi.testclient import TestClient + +from ragna.core import Assistant, SourceStorage +from ragna.deploy import Config +from ragna.deploy._api import app + +from .utils import authenticate + + +class FakeAssistant(Assistant): + def answer(self, prompt, sources): + if prompt == "Assistant Error": + raise Exception("Assistant Error") + + +class FakeSourceStorage(SourceStorage): + async def store(self, documents): + pass + + async def retrieve(self, documents, prompt): + if prompt == "SourceStorage Error": + raise Exception("SourceStorage Error") + + +@pytest.mark.parametrize("prompt", ["Assistant Error", "SourceStorage Error"]) +def test_internal_server_error_response(tmp_local_root, prompt): + config = Config( + local_root=tmp_local_root, + assistants=[FakeAssistant], + source_storages=[FakeSourceStorage], + ) + + with TestClient(app(config=config, ignore_unavailable_components=False)) as client: + authenticate(client) + + document_upload = ( + client.post("/document", json={"name": "fake.txt"}) + .raise_for_status() + .json() + ) + document = document_upload["document"] + parameters = document_upload["parameters"] + client.request( + parameters["method"], + parameters["url"], + data=parameters["data"], + files={"file": io.BytesIO(b"!")}, + ) + + chat_metadata = { + "name": "test-chat", + "source_storage": "FakeSourceStorage", + "assistant": "FakeAssistant", + "params": {}, + "documents": [document], + } + chat = client.post("/chats", json=chat_metadata).raise_for_status().json() + + _ = client.post(f"/chats/{chat['id']}/prepare").raise_for_status().json() + + with client.stream( + "POST", + f"/chats/{chat['id']}/answer", + json={"prompt": prompt, "stream": True}, + ) as response: + r = response.read() + + assert response.status_code == 500 + + assert json.loads(r.decode("utf-8"))["detail"] == prompt