Skip to content

Commit

Permalink
feature: host fallback support (#184)
Browse files Browse the repository at this point in the history
  • Loading branch information
aMahanna authored Dec 10, 2021
1 parent c65f5ee commit 1900f4a
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 59 deletions.
11 changes: 8 additions & 3 deletions arango/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ class ArangoClient:
multiple host URLs are provided). Accepted values are "roundrobin" and
"random". Any other value defaults to round robin.
:type host_resolver: str
:param resolver_max_tries: Number of attempts to process an HTTP request
before throwing a ConnectionAbortedError. Must not be lower than the
number of hosts.
:type resolver_max_tries: int
:param http_client: User-defined HTTP client.
:type http_client: arango.http.HTTPClient
:param serializer: User-defined JSON serializer. Must be a callable
Expand All @@ -48,6 +52,7 @@ def __init__(
self,
hosts: Union[str, Sequence[str]] = "http://127.0.0.1:8529",
host_resolver: str = "roundrobin",
resolver_max_tries: Optional[int] = None,
http_client: Optional[HTTPClient] = None,
serializer: Callable[..., str] = lambda x: dumps(x),
deserializer: Callable[[str], Any] = lambda x: loads(x),
Expand All @@ -61,11 +66,11 @@ def __init__(
self._host_resolver: HostResolver

if host_count == 1:
self._host_resolver = SingleHostResolver()
self._host_resolver = SingleHostResolver(1, resolver_max_tries)
elif host_resolver == "random":
self._host_resolver = RandomHostResolver(host_count)
self._host_resolver = RandomHostResolver(host_count, resolver_max_tries)
else:
self._host_resolver = RoundRobinHostResolver(host_count)
self._host_resolver = RoundRobinHostResolver(host_count, resolver_max_tries)

self._http = http_client or DefaultHTTPClient()
self._serializer = serializer
Expand Down
96 changes: 50 additions & 46 deletions arango/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
"JwtSuperuserConnection",
]

import logging
import sys
import time
from abc import abstractmethod
from typing import Any, Callable, Optional, Sequence, Union
from typing import Any, Callable, Optional, Sequence, Set, Tuple, Union

import jwt
from requests import Session
from requests import ConnectionError, Session
from requests_toolbelt import MultipartEncoder

from arango.exceptions import JWTAuthError, ServerConnectionError
Expand Down Expand Up @@ -110,6 +111,48 @@ def prep_response(self, resp: Response, deserialize: bool = True) -> Response:
resp.is_success = http_ok and resp.error_code is None
return resp

def process_request(
self, host_index: int, request: Request, auth: Optional[Tuple[str, str]] = None
) -> Response:
"""Execute a request until a valid response has been returned.
:param host_index: The index of the first host to try
:type host_index: int
:param request: HTTP request.
:type request: arango.request.Request
:return: HTTP response.
:rtype: arango.response.Response
"""
tries = 0
indexes_to_filter: Set[int] = set()
while tries < self._host_resolver.max_tries:
try:
resp = self._http.send_request(
session=self._sessions[host_index],
method=request.method,
url=self._url_prefixes[host_index] + request.endpoint,
params=request.params,
data=self.normalize_data(request.data),
headers=request.headers,
auth=auth,
)

return self.prep_response(resp, request.deserialize)
except ConnectionError:
url = self._url_prefixes[host_index] + request.endpoint
logging.debug(f"ConnectionError: {url}")

if len(indexes_to_filter) == self._host_resolver.host_count - 1:
indexes_to_filter.clear()
indexes_to_filter.add(host_index)

host_index = self._host_resolver.get_host_index(indexes_to_filter)
tries += 1

raise ConnectionAbortedError(
f"Can't connect to host(s) within limit ({self._host_resolver.max_tries})"
)

def prep_bulk_err_response(self, parent_response: Response, body: Json) -> Response:
"""Build and return a bulk error response.
Expand Down Expand Up @@ -227,16 +270,7 @@ def send_request(self, request: Request) -> Response:
:rtype: arango.response.Response
"""
host_index = self._host_resolver.get_host_index()
resp = self._http.send_request(
session=self._sessions[host_index],
method=request.method,
url=self._url_prefixes[host_index] + request.endpoint,
params=request.params,
data=self.normalize_data(request.data),
headers=request.headers,
auth=self._auth,
)
return self.prep_response(resp, request.deserialize)
return self.process_request(host_index, request, auth=self._auth)


class JwtConnection(BaseConnection):
Expand Down Expand Up @@ -302,15 +336,7 @@ def send_request(self, request: Request) -> Response:
if self._auth_header is not None:
request.headers["Authorization"] = self._auth_header

resp = self._http.send_request(
session=self._sessions[host_index],
method=request.method,
url=self._url_prefixes[host_index] + request.endpoint,
params=request.params,
data=self.normalize_data(request.data),
headers=request.headers,
)
resp = self.prep_response(resp, request.deserialize)
resp = self.process_request(host_index, request)

# Refresh the token and retry on HTTP 401 and error code 11.
if resp.error_code != 11 or resp.status_code != 401:
Expand All @@ -325,15 +351,7 @@ def send_request(self, request: Request) -> Response:
if self._auth_header is not None:
request.headers["Authorization"] = self._auth_header

resp = self._http.send_request(
session=self._sessions[host_index],
method=request.method,
url=self._url_prefixes[host_index] + request.endpoint,
params=request.params,
data=self.normalize_data(request.data),
headers=request.headers,
)
return self.prep_response(resp, request.deserialize)
return self.process_request(host_index, request)

def refresh_token(self) -> None:
"""Get a new JWT token for the current user (cannot be a superuser).
Expand All @@ -349,13 +367,7 @@ def refresh_token(self) -> None:

host_index = self._host_resolver.get_host_index()

resp = self._http.send_request(
session=self._sessions[host_index],
method=request.method,
url=self._url_prefixes[host_index] + request.endpoint,
data=self.normalize_data(request.data),
)
resp = self.prep_response(resp)
resp = self.process_request(host_index, request)

if not resp.is_success:
raise JWTAuthError(resp, request)
Expand Down Expand Up @@ -429,12 +441,4 @@ def send_request(self, request: Request) -> Response:
host_index = self._host_resolver.get_host_index()
request.headers["Authorization"] = self._auth_header

resp = self._http.send_request(
session=self._sessions[host_index],
method=request.method,
url=self._url_prefixes[host_index] + request.endpoint,
params=request.params,
data=self.normalize_data(request.data),
headers=request.headers,
)
return self.prep_response(resp, request.deserialize)
return self.process_request(host_index, request)
42 changes: 32 additions & 10 deletions arango/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,40 +7,62 @@

import random
from abc import ABC, abstractmethod
from typing import Optional, Set


class HostResolver(ABC): # pragma: no cover
"""Abstract base class for host resolvers."""

def __init__(self, host_count: int = 1, max_tries: Optional[int] = None) -> None:
max_tries = max_tries or host_count * 3
if max_tries < host_count:
raise ValueError("max_tries cannot be less than host_count")

self._host_count = host_count
self._max_tries = max_tries

@abstractmethod
def get_host_index(self) -> int:
def get_host_index(self, indexes_to_filter: Optional[Set[int]] = None) -> int:
raise NotImplementedError

@property
def host_count(self) -> int:
return self._host_count

@property
def max_tries(self) -> int:
return self._max_tries


class SingleHostResolver(HostResolver):
"""Single host resolver."""

def get_host_index(self) -> int:
def get_host_index(self, indexes_to_filter: Optional[Set[int]] = None) -> int:
return 0


class RandomHostResolver(HostResolver):
"""Random host resolver."""

def __init__(self, host_count: int) -> None:
self._max = host_count - 1
def __init__(self, host_count: int, max_tries: Optional[int] = None) -> None:
super().__init__(host_count, max_tries)

def get_host_index(self, indexes_to_filter: Optional[Set[int]] = None) -> int:
host_index = None
indexes_to_filter = indexes_to_filter or set()
while host_index is None or host_index in indexes_to_filter:
host_index = random.randint(0, self.host_count - 1)

def get_host_index(self) -> int:
return random.randint(0, self._max)
return host_index


class RoundRobinHostResolver(HostResolver):
"""Round-robin host resolver."""

def __init__(self, host_count: int) -> None:
def __init__(self, host_count: int, max_tries: Optional[int] = None) -> None:
super().__init__(host_count, max_tries)
self._index = -1
self._count = host_count

def get_host_index(self) -> int:
self._index = (self._index + 1) % self._count
def get_host_index(self, indexes_to_filter: Optional[Set[int]] = None) -> int:
self._index = (self._index + 1) % self.host_count
return self._index
24 changes: 24 additions & 0 deletions tests/test_resolver.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
from typing import Set

import pytest

from arango.resolver import (
RandomHostResolver,
RoundRobinHostResolver,
SingleHostResolver,
)


def test_bad_resolver():
with pytest.raises(ValueError):
RandomHostResolver(3, 2)


def test_resolver_single_host():
resolver = SingleHostResolver()
for _ in range(20):
Expand All @@ -16,6 +25,21 @@ def test_resolver_random_host():
for _ in range(20):
assert 0 <= resolver.get_host_index() < 10

resolver = RandomHostResolver(3)
indexes_to_filter: Set[int] = set()

index_a = resolver.get_host_index()
indexes_to_filter.add(index_a)

index_b = resolver.get_host_index(indexes_to_filter)
indexes_to_filter.add(index_b)
assert index_b != index_a

index_c = resolver.get_host_index(indexes_to_filter)
indexes_to_filter.clear()
indexes_to_filter.add(index_c)
assert index_c not in [index_a, index_b]


def test_resolver_round_robin():
resolver = RoundRobinHostResolver(10)
Expand Down

0 comments on commit 1900f4a

Please sign in to comment.