Skip to content

Added support for Pipeline and transactions #3707

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: vv-multi-db-client
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions redis/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,21 +251,21 @@ def nodes(self) -> dict:
def credential_provider(self) -> Union[CredentialProvider, None]:
return self._credential_provider

class OnCommandFailEvent:
class OnCommandsFailEvent:
"""
Event fired whenever a command fails during the execution.
"""
def __init__(
self,
command: tuple,
commands: tuple,
exception: Exception,
):
self._command = command
self._commands = commands
self._exception = exception

@property
def command(self) -> tuple:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this also be changed to commands?

return self._command
return self._commands

@property
def exception(self) -> Exception:
Expand Down
116 changes: 97 additions & 19 deletions redis/multidb/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import threading
import socket
from typing import Callable
from typing import List, Any, Callable

from redis.background import BackgroundScheduler
from redis.exceptions import ConnectionError, TimeoutError
Expand Down Expand Up @@ -30,23 +30,22 @@ def __init__(self, config: MultiDbConfig):
self._failover_strategy.set_databases(self._databases)
self._auto_fallback_interval = config.auto_fallback_interval
self._event_dispatcher = config.event_dispatcher
self._command_executor = DefaultCommandExecutor(
self._command_retry = config.command_retry
self._command_retry.update_supported_errors((ConnectionRefusedError,))
self.command_executor = DefaultCommandExecutor(
failure_detectors=self._failure_detectors,
databases=self._databases,
command_retry=config.command_retry,
command_retry=self._command_retry,
failover_strategy=self._failover_strategy,
event_dispatcher=self._event_dispatcher,
auto_fallback_interval=self._auto_fallback_interval,
)

for fd in self._failure_detectors:
fd.set_command_executor(command_executor=self._command_executor)

self._initialized = False
self.initialized = False
self._hc_lock = threading.RLock()
self._bg_scheduler = BackgroundScheduler()
self._config = config

def _initialize(self):
def initialize(self):
"""
Perform initialization of databases to define their initial state.
"""
Expand All @@ -72,7 +71,7 @@ def raise_exception_on_failed_hc(error):
# Set states according to a weights and circuit state
if database.circuit.state == CBState.CLOSED and not is_active_db_found:
database.state = DBState.ACTIVE
self._command_executor.active_database = database
self.command_executor.active_database = database
is_active_db_found = True
elif database.circuit.state == CBState.CLOSED and is_active_db_found:
database.state = DBState.PASSIVE
Expand All @@ -82,7 +81,7 @@ def raise_exception_on_failed_hc(error):
if not is_active_db_found:
raise NoValidDatabaseException('Initial connection failed - no active database found')

self._initialized = True
self.initialized = True

def get_databases(self) -> Databases:
"""
Expand Down Expand Up @@ -110,7 +109,7 @@ def set_active_database(self, database: AbstractDatabase) -> None:
highest_weighted_db, _ = self._databases.get_top_n(1)[0]
highest_weighted_db.state = DBState.PASSIVE
database.state = DBState.ACTIVE
self._command_executor.active_database = database
self.command_executor.active_database = database
return

raise NoValidDatabaseException('Cannot set active database, database is unhealthy')
Expand All @@ -132,7 +131,7 @@ def add_database(self, database: AbstractDatabase):
def _change_active_database(self, new_database: AbstractDatabase, highest_weight_database: AbstractDatabase):
if new_database.weight > highest_weight_database.weight and new_database.circuit.state == CBState.CLOSED:
new_database.state = DBState.ACTIVE
self._command_executor.active_database = new_database
self.command_executor.active_database = new_database
highest_weight_database.state = DBState.PASSIVE

def remove_database(self, database: Database):
Expand All @@ -144,7 +143,7 @@ def remove_database(self, database: Database):

if highest_weight <= weight and highest_weighted_db.circuit.state == CBState.CLOSED:
highest_weighted_db.state = DBState.ACTIVE
self._command_executor.active_database = highest_weighted_db
self.command_executor.active_database = highest_weighted_db

def update_database_weight(self, database: AbstractDatabase, weight: float):
"""
Expand Down Expand Up @@ -182,10 +181,25 @@ def execute_command(self, *args, **options):
"""
Executes a single command and return its result.
"""
if not self._initialized:
self._initialize()
if not self.initialized:
self.initialize()

return self.command_executor.execute_command(*args, **options)

def pipeline(self):
"""
Enters into pipeline mode of the client.
"""
return Pipeline(self)

return self._command_executor.execute_command(*args, **options)
def transaction(self, func: Callable[["Pipeline"], None], *watches, **options):
"""
Executes callable as transaction.
"""
if not self.initialized:
self.initialize()

return self.command_executor.execute_transaction(func, *watches, *options)

def _check_db_health(self, database: AbstractDatabase, on_error: Callable[[Exception], None] = None) -> None:
"""
Expand All @@ -207,7 +221,7 @@ def _check_db_health(self, database: AbstractDatabase, on_error: Callable[[Excep
database.circuit.state = CBState.OPEN
elif is_healthy and database.circuit.state != CBState.CLOSED:
database.circuit.state = CBState.CLOSED
except (ConnectionError, TimeoutError, socket.timeout) as e:
except (ConnectionError, TimeoutError, socket.timeout, ConnectionRefusedError) as e:
if database.circuit.state != CBState.OPEN:
database.circuit.state = CBState.OPEN
is_healthy = False
Expand All @@ -219,7 +233,9 @@ def _check_db_health(self, database: AbstractDatabase, on_error: Callable[[Excep
def _check_databases_health(self, on_error: Callable[[Exception], None] = None):
"""
Runs health checks as a recurring task.
Runs health checks against all databases.
"""

for database, _ in self._databases:
self._check_db_health(database, on_error)

Expand All @@ -232,4 +248,66 @@ def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state:
self._bg_scheduler.run_once(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit)

def _half_open_circuit(circuit: CircuitBreaker):
circuit.state = CBState.HALF_OPEN
circuit.state = CBState.HALF_OPEN


class Pipeline(RedisModuleCommands, CoreCommands, SentinelCommands):
"""
Pipeline implementation for multiple logical Redis databases.
"""
def __init__(self, client: MultiDBClient):
self._command_stack = []
self._client = client

def __enter__(self) -> "Pipeline":
return self

def __exit__(self, exc_type, exc_value, traceback):
self.reset()

def __del__(self):
try:
self.reset()
except Exception:
pass

def __len__(self) -> int:
return len(self._command_stack)

def __bool__(self) -> bool:
"""Pipeline instances should always evaluate to True"""
return True

def reset(self) -> None:
self._command_stack = []

def close(self) -> None:
"""Close the pipeline"""
self.reset()

def pipeline_execute_command(self, *args, **options) -> "Pipeline":
"""
Stage a command to be executed when execute() is next called

Returns the current Pipeline object back so commands can be
chained together, such as:

pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')

At some other point, you can then run: pipe.execute(),
which will execute all commands queued in the pipe.
"""
self._command_stack.append((args, options))
return self

def execute_command(self, *args, **kwargs):
return self.pipeline_execute_command(*args, **kwargs)

def execute(self) -> List[Any]:
if not self._client.initialized:
self._client.initialize()

try:
return self._client.command_executor.execute_pipeline(tuple(self._command_stack))
finally:
self.reset()
56 changes: 45 additions & 11 deletions redis/multidb/command_executor.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from typing import List, Union, Optional
from typing import List, Union, Optional, Callable

from redis.event import EventDispatcherInterface, OnCommandFailEvent
from redis.client import Pipeline
from redis.event import EventDispatcherInterface, OnCommandsFailEvent
from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL
from redis.multidb.database import Database, AbstractDatabase, Databases
from redis.multidb.circuit import State as CBState
Expand Down Expand Up @@ -92,6 +93,9 @@ def __init__(
:param auto_fallback_interval: Interval between fallback attempts. Fallback to a new database according to
failover_strategy.
"""
for fd in failure_detectors:
fd.set_command_executor(command_executor=self)

self._failure_detectors = failure_detectors
self._databases = databases
self._command_retry = command_retry
Expand Down Expand Up @@ -139,19 +143,49 @@ def auto_fallback_interval(self, auto_fallback_interval: int) -> None:
self._auto_fallback_interval = auto_fallback_interval

def execute_command(self, *args, **options):
self._check_active_database()
def callback():
return self._active_database.client.execute_command(*args, **options)

return self._execute_with_failure_detection(callback, args)

def execute_pipeline(self, command_stack: tuple):
"""
Executes a stack of commands in pipeline.
"""
def callback():
with self._active_database.client.pipeline() as pipe:
for command, options in command_stack:
pipe.execute_command(*command, **options)

return pipe.execute()

return self._execute_with_failure_detection(callback, command_stack)

def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options):
"""
Executes a transaction block wrapped in callback.
"""
def callback():
return self._active_database.client.transaction(transaction, *watches, **options)

return self._execute_with_failure_detection(callback)

def _execute_with_failure_detection(self, callback: Callable, cmds: tuple = ()):
"""
Execute a commands execution callback with failure detection.
"""
def wrapper():
# On each retry we need to check active database as it might change.
self._check_active_database()
return callback()

return self._command_retry.call_with_retry(
lambda: self._execute_command(*args, **options),
lambda error: self._on_command_fail(error, *args),
lambda: wrapper(),
lambda error: self._on_command_fail(error, *cmds),
)

def _execute_command(self, *args, **options):
self._check_active_database()
return self._active_database.client.execute_command(*args, **options)

def _on_command_fail(self, error, *args):
self._event_dispatcher.dispatch(OnCommandFailEvent(args, error))
self._event_dispatcher.dispatch(OnCommandsFailEvent(args, error))

def _check_active_database(self):
"""
Expand Down Expand Up @@ -180,5 +214,5 @@ def _setup_event_dispatcher(self):
"""
event_listener = RegisterCommandFailure(self._failure_detectors)
self._event_dispatcher.register_listeners({
OnCommandFailEvent: [event_listener],
OnCommandsFailEvent: [event_listener],
})
4 changes: 2 additions & 2 deletions redis/multidb/event.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List

from redis.event import EventListenerInterface, OnCommandFailEvent
from redis.event import EventListenerInterface, OnCommandsFailEvent
from redis.multidb.failure_detector import FailureDetector


Expand All @@ -11,6 +11,6 @@ class RegisterCommandFailure(EventListenerInterface):
def __init__(self, failure_detectors: List[FailureDetector]):
self._failure_detectors = failure_detectors

def listen(self, event: OnCommandFailEvent) -> None:
def listen(self, event: OnCommandsFailEvent) -> None:
for failure_detector in self._failure_detectors:
failure_detector.register_failure(event.exception, event.command)
6 changes: 3 additions & 3 deletions tests/test_multidb/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pybreaker
import pytest

from redis.event import EventDispatcher, OnCommandFailEvent
from redis.event import EventDispatcher, OnCommandsFailEvent
from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter
from redis.multidb.config import DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF, DEFAULT_FAILOVER_RETRIES, \
DEFAULT_FAILOVER_BACKOFF
Expand Down Expand Up @@ -455,8 +455,8 @@ def test_add_new_failure_detector(
mock_fd = mock_multi_db_config.failure_detectors[0]

# Event fired if command against mock_db1 would fail
command_fail_event = OnCommandFailEvent(
command=('SET', 'key', 'value'),
command_fail_event = OnCommandsFailEvent(
commands=('SET', 'key', 'value'),
exception=Exception(),
)

Expand Down
1 change: 0 additions & 1 deletion tests/test_multidb/test_command_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@ def test_execute_command_fallback_to_another_db_after_failure_detection(
auto_fallback_interval=0.1,
command_retry=Retry(NoBackoff(), threshold),
)
fd.set_command_executor(command_executor=executor)

assert executor.execute_command('SET', 'key', 'value') == 'OK1'
assert executor.execute_command('SET', 'key', 'value') == 'OK2'
Expand Down
Loading