Skip to content

Commit

Permalink
Add basic support for passing error messages from API to UI while str…
Browse files Browse the repository at this point in the history
…eaming
  • Loading branch information
nenb committed May 10, 2024
1 parent 3cef0f7 commit ec0f6ce
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 4 deletions.
23 changes: 19 additions & 4 deletions ragna/deploy/_api/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,14 +296,22 @@ async def answer(
)
core_chat = schema_to_core_chat(session, user=user, chat=chat)

core_answer = await core_chat.answer(prompt, stream=stream)
# smoke test to catch errors unrelated to streaming
# longer-term solution tracked by #409
try:
core_answer = await core_chat.answer(prompt, stream=stream)
except Exception as e:
raise HTTPException(
status_code=500,
detail=str(e),
) from e

if stream:

async def message_chunks() -> AsyncIterator[BaseModel]:
# smoke test to catch errors unrelated to streaming
# longer-term solution tracked by #409
try:
core_answer_stream = aiter(core_answer)
content_chunk = await anext(core_answer_stream)

answer = schemas.Message(
content=content_chunk,
role=core_answer.role,
Expand All @@ -312,6 +320,13 @@ async def message_chunks() -> AsyncIterator[BaseModel]:
for source in core_answer.sources
],
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=str(e),
) from e

async def message_chunks() -> AsyncIterator[BaseModel]:
yield answer

# Avoid sending the sources multiple times
Expand Down
74 changes: 74 additions & 0 deletions tests/deploy/api/test_streaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import io
import json

import pytest
from fastapi.testclient import TestClient

from ragna.core import Assistant, SourceStorage
from ragna.deploy import Config
from ragna.deploy._api import app

from .utils import authenticate


class FakeAssistant(Assistant):
def answer(self, prompt, sources):
if prompt == "Assistant Error":
raise Exception("Assistant Error")


class FakeSourceStorage(SourceStorage):
async def store(self, documents):
pass

async def retrieve(self, documents, prompt):
if prompt == "SourceStorage Error":
raise Exception("SourceStorage Error")


@pytest.mark.parametrize("prompt", ["Assistant Error", "SourceStorage Error"])
def test_internal_server_error_response(tmp_local_root, prompt):
config = Config(
local_root=tmp_local_root,
assistants=[FakeAssistant],
source_storages=[FakeSourceStorage],
)

with TestClient(app(config=config, ignore_unavailable_components=False)) as client:
authenticate(client)

document_upload = (
client.post("/document", json={"name": "fake.txt"})
.raise_for_status()
.json()
)
document = document_upload["document"]
parameters = document_upload["parameters"]
client.request(
parameters["method"],
parameters["url"],
data=parameters["data"],
files={"file": io.BytesIO(b"!")},
)

chat_metadata = {
"name": "test-chat",
"source_storage": "FakeSourceStorage",
"assistant": "FakeAssistant",
"params": {},
"documents": [document],
}
chat = client.post("/chats", json=chat_metadata).raise_for_status().json()

_ = client.post(f"/chats/{chat['id']}/prepare").raise_for_status().json()

with client.stream(
"POST",
f"/chats/{chat['id']}/answer",
json={"prompt": prompt, "stream": True},
) as response:
r = response.read()

assert response.status_code == 500

assert json.loads(r.decode("utf-8"))["detail"] == prompt

0 comments on commit ec0f6ce

Please sign in to comment.