Skip to content

Commit

Permalink
Add basic support for passing error messages from API to UI while str…
Browse files Browse the repository at this point in the history
…eaming
  • Loading branch information
nenb committed May 10, 2024
1 parent 3cef0f7 commit f3c9770
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 1 deletion.
11 changes: 10 additions & 1 deletion ragna/deploy/_api/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,16 @@ async def answer(
)
core_chat = schema_to_core_chat(session, user=user, chat=chat)

core_answer = await core_chat.answer(prompt, stream=stream)
# smoke test to catch errors unrelated to streaming
# longer-term solution tracked by #409
try:
core_answer = await core_chat.answer(prompt, stream=stream)
_ = await anext(aiter(core_answer))
except Exception as e:
raise HTTPException(
status_code=500,
detail=str(e),
) from e

if stream:

Expand Down
74 changes: 74 additions & 0 deletions tests/deploy/api/test_streaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import io
import json

import pytest
from fastapi.testclient import TestClient

from ragna.core import Assistant, SourceStorage
from ragna.deploy import Config
from ragna.deploy._api import app

from .utils import authenticate


class FakeAssistant(Assistant):
def answer(self, prompt, sources):
if prompt == "Assistant Error":
raise Exception("Assistant Error")


class FakeSourceStorage(SourceStorage):
async def store(self, documents):
pass

async def retrieve(self, documents, prompt):
if prompt == "SourceStorage Error":
raise Exception("SourceStorage Error")


@pytest.mark.parametrize("prompt", ["Assistant Error", "SourceStorage Error"])
def test_internal_server_error_response(tmp_local_root, prompt):
config = Config(
local_root=tmp_local_root,
assistants=[FakeAssistant],
source_storages=[FakeSourceStorage],
)

with TestClient(app(config=config, ignore_unavailable_components=False)) as client:
authenticate(client)

document_upload = (
client.post("/document", json={"name": "fake.txt"})
.raise_for_status()
.json()
)
document = document_upload["document"]
parameters = document_upload["parameters"]
client.request(
parameters["method"],
parameters["url"],
data=parameters["data"],
files={"file": io.BytesIO(b"!")},
)

chat_metadata = {
"name": "test-chat",
"source_storage": "FakeSourceStorage",
"assistant": "FakeAssistant",
"params": {},
"documents": [document],
}
chat = client.post("/chats", json=chat_metadata).raise_for_status().json()

_ = client.post(f"/chats/{chat['id']}/prepare").raise_for_status().json()

with client.stream(
"POST",
f"/chats/{chat['id']}/answer",
json={"prompt": prompt, "stream": True},
) as response:
r = response.read()

assert response.status_code == 500

assert json.loads(r.decode("utf-8"))["detail"] == prompt

0 comments on commit f3c9770

Please sign in to comment.