1
+ import dataclasses
1
2
import json
3
+ from typing import Any , Dict , List
2
4
3
5
import pytest
4
6
from starlette .testclient import TestClient
5
7
6
8
from aidial_sdk import DIALApp
7
- from tests .applications .broken_immediately import BrokenApplication
8
- from tests .applications .broken_in_runtime import RuntimeBrokenApplication
9
+ from tests .applications .broken import (
10
+ ImmediatelyBrokenApplication ,
11
+ RuntimeBrokenApplication ,
12
+ )
9
13
from tests .applications .noop import NoopApplication
10
14
11
15
DEFAULT_RUNTIME_ERROR = {
24
28
}
25
29
}
26
30
27
- error_testdata = [
28
- ("fastapi_exception" , 500 , DEFAULT_RUNTIME_ERROR ),
29
- ("value_error_exception" , 500 , DEFAULT_RUNTIME_ERROR ),
30
- ("zero_division_exception" , 500 , DEFAULT_RUNTIME_ERROR ),
31
- (
31
+
32
+ @dataclasses .dataclass
33
+ class ErrorTestCase :
34
+ content : Any
35
+ response_code : int
36
+ response_error : dict
37
+ response_headers : Dict [str , str ] = dataclasses .field (default_factory = dict )
38
+
39
+
40
+ error_testcases : List [ErrorTestCase ] = [
41
+ ErrorTestCase ("fastapi_exception" , 500 , DEFAULT_RUNTIME_ERROR ),
42
+ ErrorTestCase ("value_error_exception" , 500 , DEFAULT_RUNTIME_ERROR ),
43
+ ErrorTestCase ("zero_division_exception" , 500 , DEFAULT_RUNTIME_ERROR ),
44
+ ErrorTestCase (
32
45
"sdk_exception" ,
33
46
503 ,
34
47
{
39
52
}
40
53
},
41
54
),
42
- (
55
+ ErrorTestCase (
43
56
"sdk_exception_with_display_message" ,
44
57
503 ,
45
58
{
51
64
}
52
65
},
53
66
),
54
- (
67
+ ErrorTestCase (
55
68
None ,
56
69
400 ,
57
70
{
62
75
}
63
76
},
64
77
),
65
- (
78
+ ErrorTestCase (
66
79
[{"type" : "text" , "text" : "hello" }],
67
80
400 ,
68
81
{
73
86
}
74
87
},
75
88
),
89
+ ErrorTestCase (
90
+ "sdk_exception_with_headers" ,
91
+ 429 ,
92
+ {
93
+ "error" : {
94
+ "message" : "Too many requests" ,
95
+ "type" : "runtime_error" ,
96
+ "code" : "429" ,
97
+ }
98
+ },
99
+ {"Retry-after" : "42" },
100
+ ),
76
101
]
77
102
78
103
79
- @pytest .mark .parametrize (
80
- "type, response_status_code, response_content" , error_testdata
81
- )
82
- def test_error (type , response_status_code , response_content ):
104
+ @pytest .mark .parametrize ("test_case" , error_testcases )
105
+ def test_error (test_case : ErrorTestCase ):
83
106
dial_app = DIALApp ()
84
- dial_app .add_chat_completion ("test_app" , BrokenApplication ())
107
+ dial_app .add_chat_completion ("test_app" , ImmediatelyBrokenApplication ())
85
108
86
109
test_app = TestClient (dial_app )
87
110
88
111
response = test_app .post (
89
112
"/openai/deployments/test_app/chat/completions" ,
90
113
json = {
91
- "messages" : [{"role" : "user" , "content" : type }],
114
+ "messages" : [{"role" : "user" , "content" : test_case . content }],
92
115
"stream" : False ,
93
116
},
94
117
headers = {"Api-Key" : "TEST_API_KEY" },
95
118
)
96
119
97
- assert response .status_code == response_status_code
98
- assert response .json () == response_content
120
+ assert response .status_code == test_case . response_code
121
+ assert response .json () == test_case . response_error
99
122
123
+ for k , v in test_case .response_headers .items ():
124
+ assert response .headers .get (k ) == v
100
125
101
- @pytest .mark .parametrize (
102
- "type, response_status_code, response_content" , error_testdata
103
- )
104
- def test_streaming_error (type , response_status_code , response_content ):
126
+
127
+ @pytest .mark .parametrize ("test_case" , error_testcases )
128
+ def test_streaming_error (test_case : ErrorTestCase ):
105
129
dial_app = DIALApp ()
106
- dial_app .add_chat_completion ("test_app" , BrokenApplication ())
130
+ dial_app .add_chat_completion ("test_app" , ImmediatelyBrokenApplication ())
107
131
108
132
test_app = TestClient (dial_app )
109
133
110
134
response = test_app .post (
111
135
"/openai/deployments/test_app/chat/completions" ,
112
136
json = {
113
- "messages" : [{"role" : "user" , "content" : type }],
137
+ "messages" : [{"role" : "user" , "content" : test_case . content }],
114
138
"stream" : True ,
115
139
},
116
140
headers = {"Api-Key" : "TEST_API_KEY" },
117
141
)
118
142
119
- assert response .status_code == response_status_code
120
- assert response .json () == response_content
143
+ assert response .status_code == test_case . response_code
144
+ assert response .json () == test_case . response_error
121
145
122
146
123
- @pytest .mark .parametrize (
124
- "type, response_status_code, response_content" , error_testdata
125
- )
126
- def test_runtime_streaming_error (type , response_status_code , response_content ):
147
+ @pytest .mark .parametrize ("test_case" , error_testcases )
148
+ def test_runtime_streaming_error (test_case : ErrorTestCase ):
127
149
dial_app = DIALApp ()
128
150
dial_app .add_chat_completion ("test_app" , RuntimeBrokenApplication ())
129
151
@@ -132,7 +154,7 @@ def test_runtime_streaming_error(type, response_status_code, response_content):
132
154
response = test_app .post (
133
155
"/openai/deployments/test_app/chat/completions" ,
134
156
json = {
135
- "messages" : [{"role" : "user" , "content" : type }],
157
+ "messages" : [{"role" : "user" , "content" : test_case . content }],
136
158
"stream" : True ,
137
159
},
138
160
headers = {"Api-Key" : "TEST_API_KEY" },
@@ -183,7 +205,7 @@ def test_runtime_streaming_error(type, response_status_code, response_content):
183
205
"object" : "chat.completion.chunk" ,
184
206
}
185
207
elif index == 6 :
186
- assert json .loads (data ) == response_content
208
+ assert json .loads (data ) == test_case . response_error
187
209
elif index == 8 :
188
210
assert data == "[DONE]"
189
211
0 commit comments