From 5df9a5b6b6b2c18f31bd38a652206c8886ac5521 Mon Sep 17 00:00:00 2001 From: Kevin Johnson Date: Tue, 20 Aug 2024 16:30:49 -0400 Subject: [PATCH 1/3] Allows cluster mode connections to wait for free connection when at max. --- redis/asyncio/cluster.py | 40 +++++++++++++++++++++--------- tests/test_asyncio/test_cluster.py | 28 +++++++++++++++++++++ 2 files changed, 56 insertions(+), 12 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 40b2948a7f..714f3c795b 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -65,9 +65,9 @@ RedisClusterException, ResponseError, SlotNotCoveredError, - TimeoutError, TryAgainError, ) +from redis.exceptions import TimeoutError as RedisTimeoutError from redis.typing import AnyKeyT, EncodableT, KeyT from redis.utils import ( deprecated_function, @@ -264,6 +264,7 @@ def __init__( socket_timeout: Optional[float] = None, retry: Optional["Retry"] = None, retry_on_error: Optional[List[Type[Exception]]] = None, + wait_for_connections: bool = False, # SSL related kwargs ssl: bool = False, ssl_ca_certs: Optional[str] = None, @@ -326,6 +327,7 @@ def __init__( "socket_timeout": socket_timeout, "retry": retry, "protocol": protocol, + "wait_for_connections": wait_for_connections, # Client cache related kwargs "cache_enabled": cache_enabled, "client_cache": client_cache, @@ -364,7 +366,7 @@ def __init__( ) if not retry_on_error: # Default errors for retrying - retry_on_error = [ConnectionError, TimeoutError] + retry_on_error = [ConnectionError, RedisTimeoutError] self.retry.update_supported_errors(retry_on_error) kwargs.update({"retry": self.retry}) @@ -800,7 +802,7 @@ async def _execute_command( return await target_node.execute_command(*args, **kwargs) except (BusyLoadingError, MaxConnectionsError): raise - except (ConnectionError, TimeoutError): + except (ConnectionError, RedisTimeoutError): # Connection retries are being handled in the node's # Retry object. # Remove the failed node from the startup nodes before we try @@ -962,6 +964,7 @@ class ClusterNode: __slots__ = ( "_connections", "_free", + "acquire_connection_timeout", "connection_class", "connection_kwargs", "host", @@ -970,6 +973,7 @@ class ClusterNode: "port", "response_callbacks", "server_type", + "wait_for_connections", ) def __init__( @@ -980,6 +984,7 @@ def __init__( *, max_connections: int = 2**31, connection_class: Type[Connection] = Connection, + wait_for_connections: bool = False, **connection_kwargs: Any, ) -> None: if host == "localhost": @@ -996,9 +1001,11 @@ def __init__( self.connection_class = connection_class self.connection_kwargs = connection_kwargs self.response_callbacks = connection_kwargs.pop("response_callbacks", {}) + self.acquire_connection_timeout = connection_kwargs.get('socket_timeout', 30) self._connections: List[Connection] = [] - self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections) + self._free: asyncio.Queue[Connection] = asyncio.Queue() + self.wait_for_connections = wait_for_connections def __repr__(self) -> str: return ( @@ -1039,14 +1046,23 @@ async def disconnect(self) -> None: if exc: raise exc - def acquire_connection(self) -> Connection: + async def acquire_connection(self) -> Connection: try: - return self._free.popleft() - except IndexError: + return self._free.get_nowait() + except asyncio.QueueEmpty: if len(self._connections) < self.max_connections: connection = self.connection_class(**self.connection_kwargs) self._connections.append(connection) return connection + elif self.wait_for_connections: + try: + connection = await asyncio.wait_for( + self._free.get(), + self.acquire_connection_timeout + ) + return connection + except TimeoutError: + raise RedisTimeoutError("Timeout reached waiting for a free connection") raise MaxConnectionsError() @@ -1075,12 +1091,12 @@ async def parse_response( async def execute_command(self, *args: Any, **kwargs: Any) -> Any: # Acquire connection - connection = self.acquire_connection() + connection = await self.acquire_connection() keys = kwargs.pop("keys", None) response_from_cache = await connection._get_from_local_cache(args) if response_from_cache is not None: - self._free.append(connection) + await self._free.put(connection) return response_from_cache else: # Execute command @@ -1094,11 +1110,11 @@ async def execute_command(self, *args: Any, **kwargs: Any) -> Any: return response finally: # Release connection - self._free.append(connection) + await self._free.put(connection) async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool: # Acquire connection - connection = self.acquire_connection() + connection = await self.acquire_connection() # Execute command await connection.send_packed_command( @@ -1117,7 +1133,7 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool: ret = True # Release connection - self._free.append(connection) + await self._free.put(connection) return ret diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index c16272bb5b..804ebc5580 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -464,6 +464,34 @@ async def read_response_mocked(*args: Any, **kwargs: Any) -> None: await rc.aclose() + async def test_max_connections_waited( + self, create_redis: Callable[..., RedisCluster] + ) -> None: + rc = await create_redis( + cls=RedisCluster, + max_connections=10, + wait_for_connections=True + ) + for node in rc.get_nodes(): + assert node.max_connections == 10 + + with mock.patch.object(Connection, "read_response") as read_response: + + async def read_response_mocked(*args: Any, **kwargs: Any) -> None: + await asyncio.sleep(1) + + read_response.side_effect = read_response_mocked + + await asyncio.gather( + *( + rc.ping(target_nodes=RedisCluster.DEFAULT_NODE) + for _ in range(20) + ) + ) + + assert len(rc.get_default_node()._connections) == 10 + await rc.aclose() + async def test_execute_command_errors(self, r: RedisCluster) -> None: """ Test that if no key is provided then exception should be raised. From 422d9c178bc36c9274c92b28ff8e441850f66993 Mon Sep 17 00:00:00 2001 From: Kevin Johnson Date: Tue, 20 Aug 2024 16:43:33 -0400 Subject: [PATCH 2/3] Linter fixes --- redis/asyncio/cluster.py | 13 ++++++------- tests/test_asyncio/test_cluster.py | 10 ++-------- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 714f3c795b..7af155436f 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -1,5 +1,4 @@ import asyncio -import collections import random import socket import ssl @@ -7,7 +6,6 @@ from typing import ( Any, Callable, - Deque, Dict, Generator, List, @@ -65,9 +63,9 @@ RedisClusterException, ResponseError, SlotNotCoveredError, - TryAgainError, ) from redis.exceptions import TimeoutError as RedisTimeoutError +from redis.exceptions import TryAgainError from redis.typing import AnyKeyT, EncodableT, KeyT from redis.utils import ( deprecated_function, @@ -1001,7 +999,7 @@ def __init__( self.connection_class = connection_class self.connection_kwargs = connection_kwargs self.response_callbacks = connection_kwargs.pop("response_callbacks", {}) - self.acquire_connection_timeout = connection_kwargs.get('socket_timeout', 30) + self.acquire_connection_timeout = connection_kwargs.get("socket_timeout", 30) self._connections: List[Connection] = [] self._free: asyncio.Queue[Connection] = asyncio.Queue() @@ -1057,12 +1055,13 @@ async def acquire_connection(self) -> Connection: elif self.wait_for_connections: try: connection = await asyncio.wait_for( - self._free.get(), - self.acquire_connection_timeout + self._free.get(), self.acquire_connection_timeout ) return connection except TimeoutError: - raise RedisTimeoutError("Timeout reached waiting for a free connection") + raise RedisTimeoutError( + "Timeout reached waiting for a free connection" + ) raise MaxConnectionsError() diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 804ebc5580..3889e60542 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -468,9 +468,7 @@ async def test_max_connections_waited( self, create_redis: Callable[..., RedisCluster] ) -> None: rc = await create_redis( - cls=RedisCluster, - max_connections=10, - wait_for_connections=True + cls=RedisCluster, max_connections=10, wait_for_connections=True ) for node in rc.get_nodes(): assert node.max_connections == 10 @@ -483,12 +481,8 @@ async def read_response_mocked(*args: Any, **kwargs: Any) -> None: read_response.side_effect = read_response_mocked await asyncio.gather( - *( - rc.ping(target_nodes=RedisCluster.DEFAULT_NODE) - for _ in range(20) - ) + *(rc.ping(target_nodes=RedisCluster.DEFAULT_NODE) for _ in range(20)) ) - assert len(rc.get_default_node()._connections) == 10 await rc.aclose() From 104739dbe9d064ab82e9f42a8b28209898b47122 Mon Sep 17 00:00:00 2001 From: Kevin Johnson Date: Wed, 21 Aug 2024 10:05:56 -0400 Subject: [PATCH 3/3] Update changes. --- CHANGES | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGES b/CHANGES index f0d75a45ce..a5e4f526ef 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,4 @@ + * Adds capability for cluster mode to await free connection instead of raising. * Move doctests (doc code examples) to main branch * Update `ResponseT` type hint * Allow to control the minimum SSL version