Skip to content

Wip allow specifying status for fetcher failover #124

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions atlassian_jwt_auth/key.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,12 @@ class HTTPSMultiRepositoryPublicKeyRetriever(BasePublicKeyRetriever):
repository locations based upon key ids.
"""

def __init__(self, key_repository_urls):
def __init__(self, key_repository_urls, failover_on=None):
if not isinstance(key_repository_urls, list):
raise TypeError('keystore_urls must be a list of urls.')
self.failover_on = set()
if failover_on is not None:
self.failover_on = set(failover_on)
self._retrievers = self._create_retrievers(key_repository_urls)

def _create_retrievers(self, key_repository_urls):
Expand All @@ -142,7 +145,9 @@ def retrieve(self, key_identifier, **requests_kwargs):
return retriever.retrieve(key_identifier, **requests_kwargs)
except (RequestException, PublicKeyRetrieverException) as e:
if isinstance(e, PublicKeyRetrieverException):
if e.status_code is None or e.status_code < 500:
if e.status_code is None or (
e.status_code < 500
and e.status_code not in self.failover_on):
raise
logger = logging.getLogger(__name__)
logger.warn('Unable to retrieve public key from store',
Expand Down
37 changes: 37 additions & 0 deletions atlassian_jwt_auth/tests/test_public_key_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import httptest
import requests

from atlassian_jwt_auth.exceptions import PublicKeyRetrieverException
from atlassian_jwt_auth.key import (
HTTPSPublicKeyRetriever,
HTTPSMultiRepositoryPublicKeyRetriever,
Expand Down Expand Up @@ -173,6 +174,42 @@ def test_retrieve(self, mock_get_method):
retriever.retrieve('example/eg'),
self._public_key_pem)

@mock.patch.object(requests.Session, 'get')
def test_retrieve_with_400_error(self, mock_get_method):
""" tests that the retrieve method works as expected
when the first key repository returns a generic client error
response.
"""
retriever = HTTPSMultiRepositoryPublicKeyRetriever(
self.keystore_urls)
_setup_mock_response_for_retriever(
mock_get_method, self._public_key_pem)
valid_response = mock_get_method.return_value
del mock_get_method.return_value
server_exception = requests.exceptions.HTTPError(
response=mock.Mock(status_code=400))
mock_get_method.side_effect = [server_exception, valid_response]
with self.assertRaises(PublicKeyRetrieverException):
retriever.retrieve('example/eg')

@mock.patch.object(requests.Session, 'get')
def test_retrieve_with_404_error(self, mock_get_method):
""" tests that the retrieve method works as expected
when the first key repository returns a not found response.
"""
retriever = HTTPSMultiRepositoryPublicKeyRetriever(
self.keystore_urls, [404])
_setup_mock_response_for_retriever(
mock_get_method, self._public_key_pem)
valid_response = mock_get_method.return_value
del mock_get_method.return_value
server_exception = requests.exceptions.HTTPError(
response=mock.Mock(status_code=404))
mock_get_method.side_effect = [server_exception, valid_response]
self.assertEqual(
retriever.retrieve('example/eg'),
self._public_key_pem)

@mock.patch.object(requests.Session, 'get')
def test_retrieve_with_500_error(self, mock_get_method):
""" tests that the retrieve method works as expected
Expand Down