Skip to content

Commit

Permalink
improve error handling for builtin assistants (#350)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored Mar 11, 2024
1 parent 329a1a0 commit 1043241
Show file tree
Hide file tree
Showing 12 changed files with 76 additions and 17 deletions.
1 change: 1 addition & 0 deletions environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ dependencies:
- python-dotenv
- pytest >=6
- pytest-mock
- pytest-asyncio
- mypy ==1.6.1
- pre-commit
- types-aiofiles
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,16 +97,17 @@ ignore = ["E501"]

[tool.pytest.ini_options]
minversion = "6.0"
addopts = "-ra --tb=short"
addopts = "-ra --tb=short --asyncio-mode=auto"
testpaths = [
"tests",
]
filterwarnings = [
"error",
"ignore::ResourceWarning",
# httpx 0.27.0 deprecated some functionality that the test client of starlette /
# FastApi use. This should be resolved by the next release of these libraries.
# See https://github.com/encode/starlette/issues/2524
"ignore:The 'app' shortcut is now deprecated:DeprecationWarning"
"ignore:The 'app' shortcut is now deprecated:DeprecationWarning",
]
xfail_strict = true

Expand Down
8 changes: 2 additions & 6 deletions ragna/assistants/_ai21labs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import AsyncIterator, cast

from ragna.core import RagnaException, Source
from ragna.core import Source

from ._api import ApiAssistant

Expand Down Expand Up @@ -53,11 +53,7 @@ async def _call_api(
"system": self._make_system_content(sources),
},
)

if response.is_error:
raise RagnaException(
status_code=response.status_code, response=response.json()
)
await self._assert_api_call_is_success(response)

yield cast(str, response.json()["outputs"][0]["text"])

Expand Down
2 changes: 2 additions & 0 deletions ragna/assistants/_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ async def _call_api(
"stream": True,
},
) as event_source:
await self._assert_api_call_is_success(event_source.response)

async for sse in event_source.aiter_sse():
data = json.loads(sse.data)
if data["type"] != "completion":
Expand Down
21 changes: 20 additions & 1 deletion ragna/assistants/_api.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import abc
import contextlib
import json
import os
from typing import AsyncIterator

import httpx
from httpx import Response

import ragna
from ragna.core import Assistant, EnvVarRequirement, Requirement, Source
from ragna.core import Assistant, EnvVarRequirement, RagnaException, Requirement, Source


class ApiAssistant(Assistant):
Expand Down Expand Up @@ -35,3 +38,19 @@ async def _call_api(
self, prompt: str, sources: list[Source], *, max_new_tokens: int
) -> AsyncIterator[str]:
...

async def _assert_api_call_is_success(self, response: Response) -> None:
if response.is_success:
return

content = await response.aread()
with contextlib.suppress(Exception):
content = json.loads(content)

raise RagnaException(
"API call failed",
request_method=response.request.method,
request_url=str(response.request.url),
response_status_code=response.status_code,
response_content=content,
)
9 changes: 6 additions & 3 deletions ragna/assistants/_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,15 @@ async def _call_api(
"documents": self._make_source_documents(sources),
},
) as response:
if response.is_error:
raise RagnaException(status_code=response.status_code)
await self._assert_api_call_is_success(response)

async for chunk in response.aiter_lines():
event = json.loads(chunk)
if event["event_type"] == "stream-end":
break
if event["event_type"] == "COMPLETE":
break

raise RagnaException(event["error_message"])
if "text" in event:
yield cast(str, event["text"])

Expand Down
2 changes: 2 additions & 0 deletions ragna/assistants/_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ async def _call_api(
},
},
) as response:
await self._assert_api_call_is_success(response)

async for chunk in ijson.items(
AsyncIteratorReader(response.aiter_bytes(1024)),
"item.candidates.item.content.parts.item.text",
Expand Down
8 changes: 3 additions & 5 deletions ragna/assistants/_mosaicml.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import AsyncIterator, cast

from ragna.core import RagnaException, Source
from ragna.core import Source

from ._api import ApiAssistant

Expand Down Expand Up @@ -43,10 +43,8 @@ async def _call_api(
"parameters": {"temperature": 0.0, "max_new_tokens": max_new_tokens},
},
)
if response.is_error:
raise RagnaException(
status_code=response.status_code, response=response.json()
)
await self._assert_api_call_is_success(response)

yield cast(str, response.json()["outputs"][0]).replace(instruction, "").strip()


Expand Down
2 changes: 2 additions & 0 deletions ragna/assistants/_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ async def _call_api(
"stream": True,
},
) as event_source:
await self._assert_api_call_is_success(event_source.response)

async for sse in event_source.aiter_sse():
data = json.loads(sse.data)
choice = data["choices"][0]
Expand Down
Empty file added tests/assistants/__init__.py
Empty file.
28 changes: 28 additions & 0 deletions tests/assistants/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import os

import pytest

from ragna import assistants
from ragna._compat import anext
from ragna.assistants._api import ApiAssistant
from ragna.core import RagnaException
from tests.utils import skip_on_windows

API_ASSISTANTS = [
assistant
for assistant in assistants.__dict__.values()
if isinstance(assistant, type)
and issubclass(assistant, ApiAssistant)
and assistant is not ApiAssistant
]


@skip_on_windows
@pytest.mark.parametrize("assistant", API_ASSISTANTS)
async def test_api_call_error_smoke(mocker, assistant):
mocker.patch.dict(os.environ, {assistant._API_KEY_ENV_VAR: "SENTINEL"})

chunks = assistant().answer(prompt="?", sources=[])

with pytest.raises(RagnaException, match="API call failed"):
await anext(chunks)
7 changes: 7 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import platform

import pytest

skip_on_windows = pytest.mark.skipif(
platform.system() == "Windows", reason="Test is broken skipped on Windows"
)

0 comments on commit 1043241

Please sign in to comment.