diff --git a/src/claude_agent_sdk/__init__.py b/src/claude_agent_sdk/__init__.py index 31f3df5f1..805f5ce20 100644 --- a/src/claude_agent_sdk/__init__.py +++ b/src/claude_agent_sdk/__init__.py @@ -24,6 +24,7 @@ CLIJSONDecodeError, CLINotFoundError, ProcessError, + RateLimitError, ) from ._internal.session_import import import_session_to_store from ._internal.session_mutations import ( @@ -666,4 +667,5 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> Any: "CLINotFoundError", "ProcessError", "CLIJSONDecodeError", + "RateLimitError", ] diff --git a/src/claude_agent_sdk/_errors.py b/src/claude_agent_sdk/_errors.py index c86bf235c..603b6cc5b 100644 --- a/src/claude_agent_sdk/_errors.py +++ b/src/claude_agent_sdk/_errors.py @@ -54,3 +54,15 @@ class MessageParseError(ClaudeSDKError): def __init__(self, message: str, data: dict[str, Any] | None = None): self.data = data super().__init__(message) + + +class RateLimitError(ClaudeSDKError): + """Raised when the API rate limit is exceeded (HTTP 429).""" + + def __init__( + self, message: str = "Rate limit exceeded", retry_after: int | None = None + ): + self.retry_after = retry_after + if retry_after is not None: + message = f"{message} (retry after {retry_after}s)" + super().__init__(message) diff --git a/src/claude_agent_sdk/_internal/query.py b/src/claude_agent_sdk/_internal/query.py index 7a4f8a447..0081f3c92 100644 --- a/src/claude_agent_sdk/_internal/query.py +++ b/src/claude_agent_sdk/_internal/query.py @@ -39,6 +39,11 @@ logger = logging.getLogger(__name__) +def _is_rate_limit_error(text: str) -> bool: + lower = text.lower() + return "429" in lower or "rate limit" in lower or "too many requests" in lower + + def _convert_hook_output_for_cli(hook_output: dict[str, Any]) -> dict[str, Any]: """Convert Python-safe field names to CLI-expected field names. @@ -350,7 +355,12 @@ async def _read_messages(self) -> None: error_text = str(e) logger.error(f"Fatal error in message reader: {e}") # Put error in stream so iterators can handle it - await self._message_send.send({"type": "error", "error": error_text}) + if _is_rate_limit_error(error_text or str(e)): + await self._message_send.send( + {"type": "error", "error": error_text, "is_rate_limit": True} + ) + else: + await self._message_send.send({"type": "error", "error": error_text}) finally: # Flush any remaining transcript mirror entries before closing so # an early stdout EOF or transport error doesn't drop entries @@ -849,6 +859,10 @@ async def receive_messages(self) -> AsyncIterator[dict[str, Any]]: if message.get("type") == "end": break elif message.get("type") == "error": + if message.get("is_rate_limit"): + from .._errors import RateLimitError + + raise RateLimitError(message.get("error", "Rate limit exceeded")) raise Exception(message.get("error", "Unknown error")) yield message diff --git a/tests/test_errors.py b/tests/test_errors.py index 9490d075b..3229a6e02 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -6,7 +6,9 @@ CLIJSONDecodeError, CLINotFoundError, ProcessError, + RateLimitError, ) +from claude_agent_sdk._internal.query import _is_rate_limit_error class TestErrorTypes: @@ -50,3 +52,61 @@ def test_json_decode_error(self): assert error.line == "{invalid json}" assert error.original_error == e assert "Failed to decode JSON" in str(error) + + def test_rate_limit_error_is_subclass(self): + """Test RateLimitError is a subclass of ClaudeSDKError.""" + error = RateLimitError() + assert isinstance(error, ClaudeSDKError) + assert isinstance(error, Exception) + + def test_rate_limit_error_default_message(self): + """Test RateLimitError default message.""" + error = RateLimitError() + assert "Rate limit exceeded" in str(error) + assert error.retry_after is None + + def test_rate_limit_error_custom_message(self): + """Test RateLimitError with custom message.""" + error = RateLimitError("Custom rate limit message") + assert "Custom rate limit message" in str(error) + + def test_rate_limit_error_retry_after(self): + """Test RateLimitError with retry_after stores value and appends to message.""" + error = RateLimitError(retry_after=60) + assert error.retry_after == 60 + assert "retry after 60s" in str(error) + + def test_rate_limit_error_retry_after_with_message(self): + """Test RateLimitError with both message and retry_after.""" + error = RateLimitError("Too many requests", retry_after=30) + assert error.retry_after == 30 + assert "Too many requests" in str(error) + assert "retry after 30s" in str(error) + + +class TestIsRateLimitError: + """Test the _is_rate_limit_error helper.""" + + def test_detects_429(self): + assert _is_rate_limit_error("HTTP error 429") is True + + def test_detects_rate_limit(self): + assert _is_rate_limit_error("rate limit exceeded") is True + + def test_detects_rate_limit_mixed_case(self): + assert _is_rate_limit_error("Rate Limit Exceeded") is True + + def test_detects_too_many_requests(self): + assert _is_rate_limit_error("Too Many Requests") is True + + def test_detects_too_many_requests_lowercase(self): + assert _is_rate_limit_error("too many requests") is True + + def test_returns_false_for_other_errors(self): + assert _is_rate_limit_error("connection refused") is False + + def test_returns_false_for_empty_string(self): + assert _is_rate_limit_error("") is False + + def test_returns_false_for_generic_error(self): + assert _is_rate_limit_error("Unknown error occurred") is False