diff --git a/atlassian_jwt_auth/key.py b/atlassian_jwt_auth/key.py index 1dda30d..2d182b5 100644 --- a/atlassian_jwt_auth/key.py +++ b/atlassian_jwt_auth/key.py @@ -127,9 +127,12 @@ class HTTPSMultiRepositoryPublicKeyRetriever(BasePublicKeyRetriever): repository locations based upon key ids. """ - def __init__(self, key_repository_urls): + def __init__(self, key_repository_urls, failover_on=None): if not isinstance(key_repository_urls, list): raise TypeError('keystore_urls must be a list of urls.') + self.failover_on = set() + if failover_on is not None: + self.failover_on = set(failover_on) self._retrievers = self._create_retrievers(key_repository_urls) def _create_retrievers(self, key_repository_urls): @@ -142,7 +145,9 @@ def retrieve(self, key_identifier, **requests_kwargs): return retriever.retrieve(key_identifier, **requests_kwargs) except (RequestException, PublicKeyRetrieverException) as e: if isinstance(e, PublicKeyRetrieverException): - if e.status_code is None or e.status_code < 500: + if e.status_code is None or ( + e.status_code < 500 + and e.status_code not in self.failover_on): raise logger = logging.getLogger(__name__) logger.warn('Unable to retrieve public key from store', diff --git a/atlassian_jwt_auth/tests/test_public_key_provider.py b/atlassian_jwt_auth/tests/test_public_key_provider.py index b24f4d9..1689506 100644 --- a/atlassian_jwt_auth/tests/test_public_key_provider.py +++ b/atlassian_jwt_auth/tests/test_public_key_provider.py @@ -5,6 +5,7 @@ import httptest import requests +from atlassian_jwt_auth.exceptions import PublicKeyRetrieverException from atlassian_jwt_auth.key import ( HTTPSPublicKeyRetriever, HTTPSMultiRepositoryPublicKeyRetriever, @@ -173,6 +174,42 @@ def test_retrieve(self, mock_get_method): retriever.retrieve('example/eg'), self._public_key_pem) + @mock.patch.object(requests.Session, 'get') + def test_retrieve_with_400_error(self, mock_get_method): + """ tests that the retrieve method works as expected + when the first key repository returns a generic client error + response. + """ + retriever = HTTPSMultiRepositoryPublicKeyRetriever( + self.keystore_urls) + _setup_mock_response_for_retriever( + mock_get_method, self._public_key_pem) + valid_response = mock_get_method.return_value + del mock_get_method.return_value + server_exception = requests.exceptions.HTTPError( + response=mock.Mock(status_code=400)) + mock_get_method.side_effect = [server_exception, valid_response] + with self.assertRaises(PublicKeyRetrieverException): + retriever.retrieve('example/eg') + + @mock.patch.object(requests.Session, 'get') + def test_retrieve_with_404_error(self, mock_get_method): + """ tests that the retrieve method works as expected + when the first key repository returns a not found response. + """ + retriever = HTTPSMultiRepositoryPublicKeyRetriever( + self.keystore_urls, [404]) + _setup_mock_response_for_retriever( + mock_get_method, self._public_key_pem) + valid_response = mock_get_method.return_value + del mock_get_method.return_value + server_exception = requests.exceptions.HTTPError( + response=mock.Mock(status_code=404)) + mock_get_method.side_effect = [server_exception, valid_response] + self.assertEqual( + retriever.retrieve('example/eg'), + self._public_key_pem) + @mock.patch.object(requests.Session, 'get') def test_retrieve_with_500_error(self, mock_get_method): """ tests that the retrieve method works as expected