# Commentary preamble: patch tools ignore everything above the first
# "diff --git" line. The hunks below are laid out one record per line so that
# each hunk's line counts agree with its "@@" header.
#
# src/anthropic/lib/bedrock/_client.py
#   Adds a `_should_retry` override to `BaseBedrockClient`:
#     * returns True whenever `super()._should_retry(response)` does;
#     * otherwise, for an HTTP 400 response, reads the `x-amzn-errortype`
#       header (default "") and returns True if it contains any of
#       ThrottlingException, TooManyRequestsException, ModelTimeoutException
#       or ServiceUnavailableException, logging the header value at debug
#       level first;
#     * returns False in every other case.
#   The header check is a substring test (`exc in error_type`), so a header
#   value that merely contains one of those names also matches.
#   NOTE(review): `log`, `override` and `httpx` are not introduced by this
#   patch -- presumably already in scope in _client.py; confirm.
#
# tests/lib/test_bedrock.py
#   Imports `BadRequestError` and appends three tests against a mocked
#   bedrock-runtime us-east-1 `/model/.*/invoke` endpoint:
#     * sync and async: a 400 carrying `x-amzn-errortype: ThrottlingException`
#       followed by a 200 must produce exactly two recorded calls;
#     * sync: a 400 carrying `x-amzn-errortype: ValidationException` must
#       raise `BadRequestError` after exactly one recorded call.
#   NOTE(review): `re`, `httpx`, `cast` and `MockRequestCall` are not
#   introduced by this patch -- presumably already imported by the test
#   module; confirm.
#   NOTE(review): the mocked 400 responses in the retry tests send
#   `retry-after-ms: 10` -- presumably to keep the retry delay short; confirm
#   the client honours that header.
diff --git a/src/anthropic/lib/bedrock/_client.py b/src/anthropic/lib/bedrock/_client.py
index e91ca176..5bb7c21d 100644
--- a/src/anthropic/lib/bedrock/_client.py
+++ b/src/anthropic/lib/bedrock/_client.py
@@ -93,6 +93,27 @@ def _infer_region() -> str:
 
 
 class BaseBedrockClient(BaseClient[_HttpxClientT, _DefaultStreamT]):
+    @override
+    def _should_retry(self, response: httpx.Response) -> bool:
+        if super()._should_retry(response):
+            return True
+
+        if response.status_code == 400:
+            error_type = response.headers.get("x-amzn-errortype", "")
+            if any(
+                exc in error_type
+                for exc in (
+                    "ThrottlingException",
+                    "TooManyRequestsException",
+                    "ModelTimeoutException",
+                    "ServiceUnavailableException",
+                )
+            ):
+                log.debug("Retrying due to Bedrock transient error: %s", error_type)
+                return True
+
+        return False
+
     @override
     def _make_status_error(
         self,
diff --git a/tests/lib/test_bedrock.py b/tests/lib/test_bedrock.py
index 2bfb458a..0827ed79 100644
--- a/tests/lib/test_bedrock.py
+++ b/tests/lib/test_bedrock.py
@@ -8,7 +8,7 @@
 import pytest
 from respx import MockRouter
 
-from anthropic import AnthropicBedrock, AsyncAnthropicBedrock
+from anthropic import BadRequestError, AnthropicBedrock, AsyncAnthropicBedrock
 from anthropic.lib.bedrock._stream_decoder import _chunk_bytes_to_sse
 
 sync_client = AnthropicBedrock(
@@ -319,3 +319,94 @@ def test_async_copy_x_stainless_helper_header_appends() -> None:
     client = async_client.with_options(default_headers={"x-stainless-helper": "parent"})
     copied = client.with_options(default_headers={"x-stainless-helper": "child"})
     assert copied.default_headers["x-stainless-helper"] == "parent, child"
+
+
+@pytest.mark.filterwarnings("ignore::DeprecationWarning")
+@pytest.mark.respx()
+def test_retries_on_bedrock_throttling_error(respx_mock: MockRouter) -> None:
+    respx_mock.post(re.compile(r"https://bedrock-runtime\.us-east-1\.amazonaws\.com/model/.*/invoke")).mock(
+        side_effect=[
+            httpx.Response(
+                400,
+                json={"message": "Too many requests, please wait before trying again."},
+                headers={"x-amzn-errortype": "ThrottlingException", "retry-after-ms": "10"},
+            ),
+            httpx.Response(200, json={"foo": "bar"}),
+        ]
+    )
+
+    sync_client.messages.create(
+        max_tokens=1024,
+        messages=[
+            {
+                "role": "user",
+                "content": "Say hello there!",
+            }
+        ],
+        model="anthropic.claude-3-5-sonnet-20241022-v2:0",
+    )
+
+    calls = cast("list[MockRequestCall]", respx_mock.calls)
+
+    assert len(calls) == 2
+
+
+@pytest.mark.filterwarnings("ignore::DeprecationWarning")
+@pytest.mark.respx()
+@pytest.mark.asyncio()
+async def test_retries_on_bedrock_throttling_error_async(respx_mock: MockRouter) -> None:
+    respx_mock.post(re.compile(r"https://bedrock-runtime\.us-east-1\.amazonaws\.com/model/.*/invoke")).mock(
+        side_effect=[
+            httpx.Response(
+                400,
+                json={"message": "Too many requests, please wait before trying again."},
+                headers={"x-amzn-errortype": "ThrottlingException", "retry-after-ms": "10"},
+            ),
+            httpx.Response(200, json={"foo": "bar"}),
+        ]
+    )
+
+    await async_client.messages.create(
+        max_tokens=1024,
+        messages=[
+            {
+                "role": "user",
+                "content": "Say hello there!",
+            }
+        ],
+        model="anthropic.claude-3-5-sonnet-20241022-v2:0",
+    )
+
+    calls = cast("list[MockRequestCall]", respx_mock.calls)
+
+    assert len(calls) == 2
+
+
+@pytest.mark.filterwarnings("ignore::DeprecationWarning")
+@pytest.mark.respx()
+def test_no_retry_on_bedrock_validation_error(respx_mock: MockRouter) -> None:
+    respx_mock.post(re.compile(r"https://bedrock-runtime\.us-east-1\.amazonaws\.com/model/.*/invoke")).mock(
+        side_effect=[
+            httpx.Response(
+                400,
+                json={"message": "Invalid input"},
+                headers={"x-amzn-errortype": "ValidationException"},
+            ),
+        ]
+    )
+
+    with pytest.raises(BadRequestError):
+        sync_client.messages.create(
+            max_tokens=1024,
+            messages=[
+                {
+                    "role": "user",
+                    "content": "Say hello there!",
+                }
+            ],
+            model="anthropic.claude-3-5-sonnet-20241022-v2:0",
+        )
+
+    calls = cast("list[MockRequestCall]", respx_mock.calls)
+
+    assert len(calls) == 1