Skip to content

Commit f83739d

Browse files
committed
test: add comprehensive retry_on_exceptions test with fail-then-succeed model
Add test_retry_on_custom_exception_with_fail_then_succeed_model that verifies the actual retry mechanism works end-to-end with a custom exception type. Uses a CustomErrorModel that fails on first call and succeeds on retry.
1 parent b4554c9 commit f83739d

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed

tests/unittests/plugins/test_llm_resilience_plugin.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,3 +218,56 @@ class CustomError(Exception):
218218
self.assertEqual(
219219
result.content.parts[0].text.strip(), "final response from mock"
220220
)
221+
222+
async def test_retry_on_custom_exception_with_fail_then_succeed_model(self):
223+
"""Test retry with a model that fails once then succeeds on custom exception."""
224+
225+
class MyCustomError(Exception):
226+
pass
227+
228+
class CustomErrorModel(BaseLlm):
229+
model: str = "custom-error-model"
230+
call_count: int = 0
231+
232+
@classmethod
233+
def supported_models(cls) -> list[str]:
234+
return ["custom-error-model"]
235+
236+
async def generate_content_async(
237+
self, llm_request: LlmRequest, stream: bool = False
238+
) -> AsyncGenerator[LlmResponse, None]:
239+
CustomErrorModel.call_count += 1
240+
if CustomErrorModel.call_count == 1:
241+
raise MyCustomError("Custom error!")
242+
243+
yield LlmResponse(
244+
content=types.Content(
245+
role="model",
246+
parts=[types.Part.from_text(text="Success!")],
247+
),
248+
partial=False,
249+
)
250+
251+
# Set call_count=1 to simulate the initial call already happened
252+
# (which raised the error that triggered on_model_error_callback)
253+
CustomErrorModel.call_count = 1
254+
LLMRegistry.register(CustomErrorModel)
255+
256+
agent = LlmAgent(name="agent", model="custom-error-model")
257+
invocation_context = await create_invocation_context(agent)
258+
plugin = LlmResiliencePlugin(
259+
max_retries=1,
260+
retry_on_exceptions=(MyCustomError,),
261+
)
262+
llm_request = LlmRequest(contents=[])
263+
264+
# The plugin should catch MyCustomError and retry.
265+
result = await plugin.on_model_error_callback(
266+
callback_context=invocation_context,
267+
llm_request=llm_request,
268+
error=MyCustomError(),
269+
)
270+
271+
self.assertIsNotNone(result)
272+
self.assertEqual(result.content.parts[0].text, "Success!")
273+
self.assertEqual(CustomErrorModel.call_count, 2)

0 commit comments

Comments
 (0)