Skip to content

Commit

Permalink
Checking for missing keys earlier
Browse files Browse the repository at this point in the history
This will help catch errors where public keys are not retrievable. This is a partial solution to #283.
  • Loading branch information
Bento007 committed Oct 26, 2019
1 parent daa6ef8 commit e3d4f3f
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions fusillade/utils/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


@functools.lru_cache(maxsize=32)
def get_openid_config(openid_provider=None):
def get_openid_config(openid_provider=None) -> dict:
"""
:param openid_provider: the openid provider's domain.
Expand All @@ -43,28 +43,36 @@ def get_openid_config(openid_provider=None):
return res.json()


def get_jwks_uri(openid_provider):
def get_jwks_uri(openid_provider) -> str:
if openid_provider.endswith(gserviceaccount_domain):
return f"https://www.googleapis.com/service_accounts/v1/jwk/{openid_provider}"
else:
return get_openid_config(openid_provider)["jwks_uri"]


@functools.lru_cache(maxsize=32)
def get_public_keys(openid_provider):
def get_public_key(issuer, kid: str) -> bytearray:
"""
Fetches the public key from an OIDC Identity provider to verify the JWT.
:param openid_provider: the openid provider's domain.
:return: Public Keys
:param issuer: the openid provider's domain.
:param kid: the key identifier for verifying the JWT
:return: A Public Key
"""
keys = session.get(get_jwks_uri(openid_provider)).json()["keys"]
return {
public_keys = {
key["kid"]: rsa.RSAPublicNumbers(
e=int.from_bytes(base64.urlsafe_b64decode(key["e"] + "==="), byteorder="big"),
n=int.from_bytes(base64.urlsafe_b64decode(key["n"] + "==="), byteorder="big")
).public_key(backend=default_backend())
for key in keys
for key in session.get(get_jwks_uri(issuer)).json()["keys"]
}
try:
return public_keys[kid]
except KeyError:
logger.error({"message": "Failed to fetched public key from openid provider.",
"public_keys": public_keys,
"issuer": issuer,
"kid": kid})
raise FusilladeHTTPException(503, 'Service Unavailable', "Failed to fetched public key from openid provider.")


def verify_jwt(token: str) -> typing.Optional[typing.Mapping]:
Expand All @@ -84,15 +92,7 @@ def verify_jwt(token: str) -> typing.Optional[typing.Mapping]:
raise FusilladeHTTPException(401, 'Unauthorized', 'Failed to decode token.')

issuer = unverified_token['iss']
public_keys = get_public_keys(issuer)
try:
public_key = public_keys[token_header["kid"]]
except KeyError:
logger.error({"message": "Failed to fetched public key from openid provider.",
"public_keys": public_keys,
"issuer": issuer,
"kid": token_header["kid"]})
raise FusilladeHTTPException(503, 'Service Unavailable', "Failed to fetched public key from openid provider.")
public_key = get_public_key(issuer, token_header["kid"])
try:
verified_tok = jwt.decode(token,
key=public_key,
Expand Down

0 comments on commit e3d4f3f

Please sign in to comment.