diff --git a/redis/asyncio/retry.py b/redis/asyncio/retry.py index a20f8b4849..98b2d9c6f8 100644 --- a/redis/asyncio/retry.py +++ b/redis/asyncio/retry.py @@ -2,18 +2,16 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Tuple, Type, TypeVar from redis.exceptions import ConnectionError, RedisError, TimeoutError - -if TYPE_CHECKING: - from redis.backoff import AbstractBackoff - +from redis.retry import AbstractRetry T = TypeVar("T") +if TYPE_CHECKING: + from redis.backoff import AbstractBackoff -class Retry: - """Retry a specific number of times after a failure""" - __slots__ = "_backoff", "_retries", "_supported_errors" +class Retry(AbstractRetry[RedisError]): + __hash__ = AbstractRetry.__hash__ def __init__( self, @@ -24,36 +22,17 @@ def __init__( TimeoutError, ), ): - """ - Initialize a `Retry` object with a `Backoff` object - that retries a maximum of `retries` times. - `retries` can be negative to retry forever. - You can specify the types of supported errors which trigger - a retry with the `supported_errors` parameter. - """ - self._backoff = backoff - self._retries = retries - self._supported_errors = supported_errors + super().__init__(backoff, retries, supported_errors) - def update_supported_errors(self, specified_errors: list): - """ - Updates the supported errors with the specified error types - """ - self._supported_errors = tuple( - set(self._supported_errors + tuple(specified_errors)) - ) - - def get_retries(self) -> int: - """ - Get the number of retries. - """ - return self._retries + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Retry): + return NotImplemented - def update_retries(self, value: int) -> None: - """ - Set the number of retries. - """ - self._retries = value + return ( + self._backoff == other._backoff + and self._retries == other._retries + and set(self._supported_errors) == set(other._supported_errors) + ) async def call_with_retry( self, do: Callable[[], Awaitable[T]], fail: Callable[[RedisError], Any] diff --git a/redis/backoff.py b/redis/backoff.py index 22a3ed0abb..6e1f68a7ba 100644 --- a/redis/backoff.py +++ b/redis/backoff.py @@ -170,7 +170,7 @@ def __hash__(self) -> int: return hash((self._base, self._cap)) def __eq__(self, other) -> bool: - if not isinstance(other, EqualJitterBackoff): + if not isinstance(other, ExponentialWithJitterBackoff): return NotImplemented return self._base == other._base and self._cap == other._cap diff --git a/redis/retry.py b/redis/retry.py index c93f34e65f..75778635e8 100644 --- a/redis/retry.py +++ b/redis/retry.py @@ -1,27 +1,27 @@ +import abc import socket from time import sleep -from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Tuple, Type, TypeVar from redis.exceptions import ConnectionError, TimeoutError T = TypeVar("T") +E = TypeVar("E", bound=Exception, covariant=True) if TYPE_CHECKING: from redis.backoff import AbstractBackoff -class Retry: +class AbstractRetry(Generic[E], abc.ABC): """Retry a specific number of times after a failure""" + _supported_errors: Tuple[Type[E], ...] + def __init__( self, backoff: "AbstractBackoff", retries: int, - supported_errors: Tuple[Type[Exception], ...] = ( - ConnectionError, - TimeoutError, - socket.timeout, - ), + supported_errors: Tuple[Type[E], ...], ): """ Initialize a `Retry` object with a `Backoff` object @@ -34,22 +34,14 @@ def __init__( self._retries = retries self._supported_errors = supported_errors + @abc.abstractmethod def __eq__(self, other: Any) -> bool: - if not isinstance(other, Retry): - return NotImplemented - - return ( - self._backoff == other._backoff - and self._retries == other._retries - and set(self._supported_errors) == set(other._supported_errors) - ) + return NotImplemented def __hash__(self) -> int: return hash((self._backoff, self._retries, frozenset(self._supported_errors))) - def update_supported_errors( - self, specified_errors: Iterable[Type[Exception]] - ) -> None: + def update_supported_errors(self, specified_errors: Iterable[Type[E]]) -> None: """ Updates the supported errors with the specified error types """ @@ -69,6 +61,32 @@ def update_retries(self, value: int) -> None: """ self._retries = value + +class Retry(AbstractRetry[Exception]): + __hash__ = AbstractRetry.__hash__ + + def __init__( + self, + backoff: "AbstractBackoff", + retries: int, + supported_errors: Tuple[Type[Exception], ...] = ( + ConnectionError, + TimeoutError, + socket.timeout, + ), + ): + super().__init__(backoff, retries, supported_errors) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Retry): + return NotImplemented + + return ( + self._backoff == other._backoff + and self._retries == other._retries + and set(self._supported_errors) == set(other._supported_errors) + ) + def call_with_retry( self, do: Callable[[], T], diff --git a/tests/test_retry.py b/tests/test_retry.py index 4f4f04caca..9c0ca65d81 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -1,6 +1,7 @@ from unittest.mock import patch import pytest +from redis.asyncio.retry import Retry as AsyncRetry from redis.backoff import ( AbstractBackoff, ConstantBackoff, @@ -89,6 +90,7 @@ def test_retry_on_error_retry(self, Class, retries): assert c.retry._retries == retries +@pytest.mark.parametrize("retry_class", [Retry, AsyncRetry]) @pytest.mark.parametrize( "args", [ @@ -108,8 +110,8 @@ def test_retry_on_error_retry(self, Class, retries): for backoff in ((Backoff(), 2), (Backoff(25), 5), (Backoff(25, 5), 5)) ], ) -def test_retry_eq_and_hashable(args): - assert Retry(*args) == Retry(*args) +def test_retry_eq_and_hashable(retry_class, args): + assert retry_class(*args) == retry_class(*args) # create another retry object with different parameters copy = list(args) @@ -118,9 +120,19 @@ def test_retry_eq_and_hashable(args): else: copy[0] = ConstantBackoff(9000) - assert Retry(*args) != Retry(*copy) - assert Retry(*copy) != Retry(*args) - assert len({Retry(*args), Retry(*args), Retry(*copy), Retry(*copy)}) == 2 + assert retry_class(*args) != retry_class(*copy) + assert retry_class(*copy) != retry_class(*args) + assert ( + len( + { + retry_class(*args), + retry_class(*args), + retry_class(*copy), + retry_class(*copy), + } + ) + == 2 + ) class TestRetry: