Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/anthropic/lib/bedrock/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,27 @@ def _infer_region() -> str:


class BaseBedrockClient(BaseClient[_HttpxClientT, _DefaultStreamT]):
@override
def _should_retry(self, response: httpx.Response) -> bool:
if super()._should_retry(response):
return True

if response.status_code == 400:
error_type = response.headers.get("x-amzn-errortype", "")
if any(
exc in error_type
for exc in (
"ThrottlingException",
"TooManyRequestsException",
"ModelTimeoutException",
"ServiceUnavailableException",
)
):
log.debug("Retrying due to Bedrock transient error: %s", error_type)
return True

return False

@override
def _make_status_error(
self,
Expand Down
93 changes: 92 additions & 1 deletion tests/lib/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
from respx import MockRouter

from anthropic import AnthropicBedrock, AsyncAnthropicBedrock
from anthropic import BadRequestError, AnthropicBedrock, AsyncAnthropicBedrock
from anthropic.lib.bedrock._stream_decoder import _chunk_bytes_to_sse

sync_client = AnthropicBedrock(
Expand Down Expand Up @@ -319,3 +319,94 @@ def test_async_copy_x_stainless_helper_header_appends() -> None:
client = async_client.with_options(default_headers={"x-stainless-helper": "parent"})
copied = client.with_options(default_headers={"x-stainless-helper": "child"})
assert copied.default_headers["x-stainless-helper"] == "parent, child"


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
@pytest.mark.respx()
def test_retries_on_bedrock_throttling_error(respx_mock: MockRouter) -> None:
respx_mock.post(re.compile(r"https://bedrock-runtime\.us-east-1\.amazonaws\.com/model/.*/invoke")).mock(
side_effect=[
httpx.Response(
400,
json={"message": "Too many requests, please wait before trying again."},
headers={"x-amzn-errortype": "ThrottlingException", "retry-after-ms": "10"},
),
httpx.Response(200, json={"foo": "bar"}),
]
)

sync_client.messages.create(
max_tokens=1024,
messages=[
{
"role": "user",
"content": "Say hello there!",
}
],
model="anthropic.claude-3-5-sonnet-20241022-v2:0",
)

calls = cast("list[MockRequestCall]", respx_mock.calls)

assert len(calls) == 2


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
@pytest.mark.respx()
@pytest.mark.asyncio()
async def test_retries_on_bedrock_throttling_error_async(respx_mock: MockRouter) -> None:
respx_mock.post(re.compile(r"https://bedrock-runtime\.us-east-1\.amazonaws\.com/model/.*/invoke")).mock(
side_effect=[
httpx.Response(
400,
json={"message": "Too many requests, please wait before trying again."},
headers={"x-amzn-errortype": "ThrottlingException", "retry-after-ms": "10"},
),
httpx.Response(200, json={"foo": "bar"}),
]
)

await async_client.messages.create(
max_tokens=1024,
messages=[
{
"role": "user",
"content": "Say hello there!",
}
],
model="anthropic.claude-3-5-sonnet-20241022-v2:0",
)

calls = cast("list[MockRequestCall]", respx_mock.calls)

assert len(calls) == 2


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
@pytest.mark.respx()
def test_no_retry_on_bedrock_validation_error(respx_mock: MockRouter) -> None:
respx_mock.post(re.compile(r"https://bedrock-runtime\.us-east-1\.amazonaws\.com/model/.*/invoke")).mock(
side_effect=[
httpx.Response(
400,
json={"message": "Invalid input"},
headers={"x-amzn-errortype": "ValidationException"},
),
]
)

with pytest.raises(BadRequestError):
sync_client.messages.create(
max_tokens=1024,
messages=[
{
"role": "user",
"content": "Say hello there!",
}
],
model="anthropic.claude-3-5-sonnet-20241022-v2:0",
)

calls = cast("list[MockRequestCall]", respx_mock.calls)

assert len(calls) == 1