Skip to content

Commit 7632cda

Browse files
authored
feat: added headers to the DIAL exception class (#192)
1 parent 45681f3 commit 7632cda

File tree

5 files changed

+87
-56
lines changed

5 files changed

+87
-56
lines changed

aidial_sdk/_errors.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@ def fastapi_exception_handler(request: Request, exc: Exception) -> JSONResponse:
2323
return JSONResponse(
2424
status_code=exc.status_code,
2525
content=exc.detail,
26+
headers=exc.headers,
2627
)
2728

2829

aidial_sdk/exceptions.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import functools
22
import warnings
33
from http import HTTPStatus
4-
from typing import Optional
4+
from typing import Dict, Optional
55

66
from fastapi import HTTPException as FastAPIException
77
from fastapi.responses import JSONResponse
@@ -18,6 +18,7 @@ def __init__(
1818
param: Optional[str] = None,
1919
code: Optional[str] = None,
2020
display_message: Optional[str] = None,
21+
headers: Optional[Dict[str, str]] = None,
2122
) -> None:
2223
status_code = int(status_code)
2324

@@ -27,8 +28,11 @@ def __init__(
2728
self.param = param
2829
self.code = code or str(status_code)
2930
self.display_message = display_message
31+
self.headers = headers
3032

3133
def __repr__(self):
34+
# headers field is omitted deliberately
35+
# since it may contain sensitive information
3236
return (
3337
"%s(message=%r, status_code=%r, type=%r, param=%r, code=%r, display_message=%r)"
3438
% (
@@ -59,12 +63,14 @@ def to_fastapi_response(self) -> JSONResponse:
5963
return JSONResponse(
6064
status_code=self.status_code,
6165
content=self.json_error(),
66+
headers=self.headers,
6267
)
6368

6469
def to_fastapi_exception(self) -> FastAPIException:
6570
return FastAPIException(
6671
status_code=self.status_code,
6772
detail=self.json_error(),
73+
headers=self.headers,
6874
)
6975

7076

tests/applications/broken_immediately.py renamed to tests/applications/broken.py

Lines changed: 25 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
from aidial_sdk.chat_completion import ChatCompletion, Request, Response
55

66

7-
def raise_exception(exception_type: str):
7+
def _raise_exception(exception_type: str):
88
if exception_type == "sdk_exception":
99
raise DIALException("Test error", 503)
1010
elif exception_type == "fastapi_exception":
@@ -15,16 +15,38 @@ def raise_exception(exception_type: str):
1515
return 1 / 0
1616
elif exception_type == "sdk_exception_with_display_message":
1717
raise DIALException("Test error", 503, display_message="I'm broken")
18+
elif exception_type == "sdk_exception_with_headers":
19+
raise DIALException(
20+
"Too many requests", 429, headers={"Retry-After": "42"}
21+
)
1822
else:
1923
raise DIALException("Unexpected error")
2024

2125

22-
class BrokenApplication(ChatCompletion):
26+
class ImmediatelyBrokenApplication(ChatCompletion):
2327
"""
2428
Application which breaks immediately after receiving a request.
2529
"""
2630

2731
async def chat_completion(
2832
self, request: Request, response: Response
2933
) -> None:
30-
raise_exception(request.messages[0].text())
34+
_raise_exception(request.messages[0].text())
35+
36+
37+
class RuntimeBrokenApplication(ChatCompletion):
38+
"""
39+
Application which breaks after producing some output.
40+
"""
41+
42+
async def chat_completion(
43+
self, request: Request, response: Response
44+
) -> None:
45+
response.set_response_id("test_id")
46+
response.set_created(0)
47+
48+
with response.create_single_choice() as choice:
49+
choice.append_content("Test content")
50+
await response.aflush()
51+
52+
_raise_exception(request.messages[0].text())

tests/applications/broken_in_runtime.py

Lines changed: 0 additions & 20 deletions
This file was deleted.

tests/test_errors.py

Lines changed: 54 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,15 @@
1+
import dataclasses
12
import json
3+
from typing import Any, Dict, List
24

35
import pytest
46
from starlette.testclient import TestClient
57

68
from aidial_sdk import DIALApp
7-
from tests.applications.broken_immediately import BrokenApplication
8-
from tests.applications.broken_in_runtime import RuntimeBrokenApplication
9+
from tests.applications.broken import (
10+
ImmediatelyBrokenApplication,
11+
RuntimeBrokenApplication,
12+
)
913
from tests.applications.noop import NoopApplication
1014

1115
DEFAULT_RUNTIME_ERROR = {
@@ -24,11 +28,20 @@
2428
}
2529
}
2630

27-
error_testdata = [
28-
("fastapi_exception", 500, DEFAULT_RUNTIME_ERROR),
29-
("value_error_exception", 500, DEFAULT_RUNTIME_ERROR),
30-
("zero_division_exception", 500, DEFAULT_RUNTIME_ERROR),
31-
(
31+
32+
@dataclasses.dataclass
33+
class ErrorTestCase:
34+
content: Any
35+
response_code: int
36+
response_error: dict
37+
response_headers: Dict[str, str] = dataclasses.field(default_factory=dict)
38+
39+
40+
error_testcases: List[ErrorTestCase] = [
41+
ErrorTestCase("fastapi_exception", 500, DEFAULT_RUNTIME_ERROR),
42+
ErrorTestCase("value_error_exception", 500, DEFAULT_RUNTIME_ERROR),
43+
ErrorTestCase("zero_division_exception", 500, DEFAULT_RUNTIME_ERROR),
44+
ErrorTestCase(
3245
"sdk_exception",
3346
503,
3447
{
@@ -39,7 +52,7 @@
3952
}
4053
},
4154
),
42-
(
55+
ErrorTestCase(
4356
"sdk_exception_with_display_message",
4457
503,
4558
{
@@ -51,7 +64,7 @@
5164
}
5265
},
5366
),
54-
(
67+
ErrorTestCase(
5568
None,
5669
400,
5770
{
@@ -62,7 +75,7 @@
6275
}
6376
},
6477
),
65-
(
78+
ErrorTestCase(
6679
[{"type": "text", "text": "hello"}],
6780
400,
6881
{
@@ -73,57 +86,66 @@
7386
}
7487
},
7588
),
89+
ErrorTestCase(
90+
"sdk_exception_with_headers",
91+
429,
92+
{
93+
"error": {
94+
"message": "Too many requests",
95+
"type": "runtime_error",
96+
"code": "429",
97+
}
98+
},
99+
{"Retry-after": "42"},
100+
),
76101
]
77102

78103

79-
@pytest.mark.parametrize(
80-
"type, response_status_code, response_content", error_testdata
81-
)
82-
def test_error(type, response_status_code, response_content):
104+
@pytest.mark.parametrize("test_case", error_testcases)
105+
def test_error(test_case: ErrorTestCase):
83106
dial_app = DIALApp()
84-
dial_app.add_chat_completion("test_app", BrokenApplication())
107+
dial_app.add_chat_completion("test_app", ImmediatelyBrokenApplication())
85108

86109
test_app = TestClient(dial_app)
87110

88111
response = test_app.post(
89112
"/openai/deployments/test_app/chat/completions",
90113
json={
91-
"messages": [{"role": "user", "content": type}],
114+
"messages": [{"role": "user", "content": test_case.content}],
92115
"stream": False,
93116
},
94117
headers={"Api-Key": "TEST_API_KEY"},
95118
)
96119

97-
assert response.status_code == response_status_code
98-
assert response.json() == response_content
120+
assert response.status_code == test_case.response_code
121+
assert response.json() == test_case.response_error
99122

123+
for k, v in test_case.response_headers.items():
124+
assert response.headers.get(k) == v
100125

101-
@pytest.mark.parametrize(
102-
"type, response_status_code, response_content", error_testdata
103-
)
104-
def test_streaming_error(type, response_status_code, response_content):
126+
127+
@pytest.mark.parametrize("test_case", error_testcases)
128+
def test_streaming_error(test_case: ErrorTestCase):
105129
dial_app = DIALApp()
106-
dial_app.add_chat_completion("test_app", BrokenApplication())
130+
dial_app.add_chat_completion("test_app", ImmediatelyBrokenApplication())
107131

108132
test_app = TestClient(dial_app)
109133

110134
response = test_app.post(
111135
"/openai/deployments/test_app/chat/completions",
112136
json={
113-
"messages": [{"role": "user", "content": type}],
137+
"messages": [{"role": "user", "content": test_case.content}],
114138
"stream": True,
115139
},
116140
headers={"Api-Key": "TEST_API_KEY"},
117141
)
118142

119-
assert response.status_code == response_status_code
120-
assert response.json() == response_content
143+
assert response.status_code == test_case.response_code
144+
assert response.json() == test_case.response_error
121145

122146

123-
@pytest.mark.parametrize(
124-
"type, response_status_code, response_content", error_testdata
125-
)
126-
def test_runtime_streaming_error(type, response_status_code, response_content):
147+
@pytest.mark.parametrize("test_case", error_testcases)
148+
def test_runtime_streaming_error(test_case: ErrorTestCase):
127149
dial_app = DIALApp()
128150
dial_app.add_chat_completion("test_app", RuntimeBrokenApplication())
129151

@@ -132,7 +154,7 @@ def test_runtime_streaming_error(type, response_status_code, response_content):
132154
response = test_app.post(
133155
"/openai/deployments/test_app/chat/completions",
134156
json={
135-
"messages": [{"role": "user", "content": type}],
157+
"messages": [{"role": "user", "content": test_case.content}],
136158
"stream": True,
137159
},
138160
headers={"Api-Key": "TEST_API_KEY"},
@@ -183,7 +205,7 @@ def test_runtime_streaming_error(type, response_status_code, response_content):
183205
"object": "chat.completion.chunk",
184206
}
185207
elif index == 6:
186-
assert json.loads(data) == response_content
208+
assert json.loads(data) == test_case.response_error
187209
elif index == 8:
188210
assert data == "[DONE]"
189211

0 commit comments

Comments
 (0)