-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Expand file tree
/
Copy pathauto_retry.py
More file actions
252 lines (194 loc) · 8.9 KB
/
auto_retry.py
File metadata and controls
252 lines (194 loc) · 8.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "agent-framework",
# "tenacity",
# ]
# ///
# Run with any PEP 723 compatible runner, e.g.:
# uv run samples/02-agents/auto_retry.py
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import functools
import logging
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar, cast

from agent_framework import ChatContext, ChatMiddleware, SupportsChatGetResponse, chat_middleware
from agent_framework.azure import AzureOpenAIChatClient
from azure.identity import AzureCliCredential
from dotenv import load_dotenv
from openai import RateLimitError
from tenacity import (
    AsyncRetrying,
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)
# Pull settings (endpoint, deployment name, ...) from a local .env file, if one exists.
load_dotenv()
"""
Auto-Retry Rate Limiting Sample
Every model inference API enforces rate limits, so production agents need retry logic
to handle 429 responses gracefully. This sample shows two ways to add automatic retry
using the `tenacity` library, keeping your application code free of boilerplate.
Approach 1 – Class decorator
Apply a class decorator to any client type implementing
SupportsChatGetResponse. The decorator patches get_response() with retry
behavior. Non-streaming responses are retried; streaming is returned as-is
(streaming retry requires more delicate handling).
Approach 2 – Chat middleware
Register middleware on the agent that catches RateLimitError raised inside
call_next() and retries the entire request pipeline. Two styles are shown:
a) Class-based middleware (ChatMiddleware subclass)
b) Function-based middleware (@chat_middleware decorator)
Both approaches use the same tenacity primitives:
- stop_after_attempt – cap the total number of tries
- wait_exponential – exponential back-off between retries
- retry_if_exception_type(RateLimitError) – only retry on 429 errors
- before_sleep_log – log each retry attempt at WARNING level
"""

# Default cap on total tries (first call included) handed to stop_after_attempt below.
RETRY_ATTEMPTS = 3

# Destination for the before_sleep_log retry notices emitted by every approach.
logger = logging.getLogger(__name__)
# =============================================================================
# Approach 1: Class decorator
# =============================================================================
ChatClientT = TypeVar("ChatClientT", bound=SupportsChatGetResponse[Any])


def with_rate_limit_retry(*, retry_attempts: int = RETRY_ATTEMPTS) -> Callable[[type[ChatClientT]], type[ChatClientT]]:
    """Class decorator that adds non-streaming retry behavior to ``get_response()``.

    The decorated class has its ``get_response`` attribute replaced in place.

    - A call made with the keyword ``stream=True`` is forwarded to the original
      method unchanged.
    - Every other call returns a coroutine. It retries the original call on
      ``RateLimitError`` with exponential back-off (4-10 seconds between tries)
      and logs each retry at WARNING level.

    Args:
        retry_attempts: Total number of attempts, the first call included. The
            last ``RateLimitError`` is re-raised once this many attempts have failed.

    Returns:
        A decorator that patches the given client class and returns that same class.
    """

    def decorator(client_cls: type[ChatClientT]) -> type[ChatClientT]:
        original_get_response = client_cls.get_response

        # functools.wraps copies __name__, __qualname__, __doc__ and sets __wrapped__,
        # so the patched method still reports itself (and its signature) as get_response.
        @functools.wraps(original_get_response)
        def get_response_with_retry(self, *args, **kwargs):  # type: ignore[no-untyped-def]
            # Only a keyword ``stream`` argument is detected here.
            if kwargs.get("stream", False):
                # Streaming retry is more complex; fall back to the original behaviour.
                return original_get_response(self, *args, **kwargs)

            async def _with_retry():
                async for attempt in AsyncRetrying(
                    stop=stop_after_attempt(retry_attempts),
                    wait=wait_exponential(multiplier=1, min=4, max=10),
                    retry=retry_if_exception_type(RateLimitError),
                    reraise=True,
                    before_sleep=before_sleep_log(logger, logging.WARNING),
                ):
                    with attempt:
                        return await original_get_response(self, *args, **kwargs)
                # Not expected to be reached: each attempt either returns above or
                # raises (reraise=True surfaces the final error); kept so all paths return.
                return None

            return _with_retry()

        client_cls.get_response = cast(Any, get_response_with_retry)
        return client_cls

    return decorator
@with_rate_limit_retry(retry_attempts=RETRY_ATTEMPTS)
class RetryingAzureOpenAIChatClient(AzureOpenAIChatClient):
    """Azure OpenAI chat client whose non-streaming ``get_response`` retries on rate limits.

    The class adds no members of its own. All retry behavior is installed by the
    ``with_rate_limit_retry`` class decorator, which patches ``get_response``.
    """
# =============================================================================
# Approach 2a: Class-based chat middleware
# =============================================================================
class RateLimitRetryMiddleware(ChatMiddleware):
    """Chat middleware that re-runs the downstream model call when it is rate limited.

    Attach it to an agent (or pass it for a single run). Any ``RateLimitError``
    raised by ``call_next()`` then triggers another try, with exponential back-off
    between tries. When a tool loop issues several model calls, each one is guarded
    on its own.
    """

    def __init__(self, *, max_attempts: int = RETRY_ATTEMPTS) -> None:
        """Remember how many total tries (first call included) to allow.

        Args:
            max_attempts: Value passed to ``stop_after_attempt``.
        """
        self.max_attempts = max_attempts

    async def process(
        self,
        context: ChatContext,
        call_next: Callable[[], Awaitable[None]],
    ) -> None:
        """Invoke ``call_next()``, retrying only on ``RateLimitError``.

        Once ``max_attempts`` tries have failed, the last error is re-raised
        (``reraise=True``). Other exceptions propagate immediately.
        """
        retrying = AsyncRetrying(
            retry=retry_if_exception_type(RateLimitError),
            stop=stop_after_attempt(self.max_attempts),
            wait=wait_exponential(multiplier=1, min=4, max=10),
            before_sleep=before_sleep_log(logger, logging.WARNING),
            reraise=True,
        )
        async for try_ctx in retrying:
            with try_ctx:
                await call_next()
# =============================================================================
# Approach 2b: Function-based chat middleware
# =============================================================================
@chat_middleware
async def rate_limit_retry_middleware(
    context: ChatContext,
    call_next: Callable[[], Awaitable[None]],
) -> None:
    """Function-style chat middleware that retries a model call on rate limits.

    A tenacity ``@retry`` decorator wraps a small coroutine around ``call_next()``.
    Any ``RateLimitError`` therefore causes another try after an exponential
    back-off, up to ``RETRY_ATTEMPTS`` tries in total. When a tool loop issues
    several model calls, each one is guarded on its own.
    """
    retry_on_rate_limit = retry(
        retry=retry_if_exception_type(RateLimitError),
        stop=stop_after_attempt(RETRY_ATTEMPTS),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        before_sleep=before_sleep_log(logger, logging.WARNING),
        reraise=True,
    )

    @retry_on_rate_limit
    async def _attempt_model_call() -> None:
        await call_next()

    await _attempt_model_call()
# =============================================================================
# Demo
# =============================================================================
async def class_decorator_example() -> None:
    """Run Approach 1: an agent built on the retry-decorated client class."""
    rule = "=" * 60
    print("\n" + rule)
    print("Approach 1: Class decorator (applied to client type)")
    print(rule)

    # For authentication, run `az login` command in terminal or replace
    # AzureCliCredential with your preferred authentication option.
    credential = AzureCliCredential()
    client = RetryingAzureOpenAIChatClient(credential=credential)
    agent = client.as_agent(instructions="You are a helpful assistant.")

    query = "Say hello!"
    print(f"User: {query}")
    response = await agent.run(query)
    print(f"Agent: {response.text}")
async def class_based_middleware_example() -> None:
    """Demonstrate Approach 2a: class-based chat middleware.

    Builds an agent on a plain ``AzureOpenAIChatClient`` with
    ``RateLimitRetryMiddleware`` registered, sends one query, and prints the reply.
    """
    print("\n" + "=" * 60)
    print("Approach 2a: Class-based chat middleware")
    print("=" * 60)
    # For authentication, run `az login` command in terminal or replace
    # AzureCliCredential with your preferred authentication option.
    agent = AzureOpenAIChatClient(credential=AzureCliCredential()).as_agent(
        instructions="You are a helpful assistant.",
        # Use the shared RETRY_ATTEMPTS constant rather than a literal so this demo
        # stays in sync with Approaches 1 and 2b when the limit is tuned.
        middleware=[RateLimitRetryMiddleware(max_attempts=RETRY_ATTEMPTS)],
    )
    query = "Say hello!"
    print(f"User: {query}")
    result = await agent.run(query)
    print(f"Agent: {result.text}")
async def function_based_middleware_example() -> None:
    """Run Approach 2b: a plain client plus the function-style retry middleware."""
    rule = "=" * 60
    print("\n" + rule)
    print("Approach 2b: Function-based chat middleware")
    print(rule)

    # For authentication, run `az login` command in terminal or replace
    # AzureCliCredential with your preferred authentication option.
    credential = AzureCliCredential()
    client = AzureOpenAIChatClient(credential=credential)
    agent = client.as_agent(
        instructions="You are a helpful assistant.",
        middleware=[rate_limit_retry_middleware],
    )

    query = "Say hello!"
    print(f"User: {query}")
    response = await agent.run(query)
    print(f"Agent: {response.text}")
async def main() -> None:
    """Print the sample banner, then run each auto-retry demo one after another."""
    print("=== Auto-Retry Rate Limiting Sample ===")
    print(
        "Demonstrates two approaches for automatic retry on rate limit (429) errors.\n"
        "Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME (and optionally\n"
        "AZURE_OPENAI_API_KEY) before running, or populate a .env file."
    )
    demos = (
        class_decorator_example,
        class_based_middleware_example,
        function_based_middleware_example,
    )
    for demo in demos:
        await demo()


if __name__ == "__main__":
    asyncio.run(main())