diff --git a/fusillade/utils/security.py b/fusillade/utils/security.py index 946f084f..6cd8f6bf 100644 --- a/fusillade/utils/security.py +++ b/fusillade/utils/security.py @@ -26,7 +26,7 @@ @functools.lru_cache(maxsize=32) -def get_openid_config(openid_provider=None): +def get_openid_config(openid_provider=None) -> dict: """ :param openid_provider: the openid provider's domain. @@ -43,7 +43,7 @@ def get_openid_config(openid_provider=None): return res.json() -def get_jwks_uri(openid_provider): +def get_jwks_uri(openid_provider) -> str: if openid_provider.endswith(gserviceaccount_domain): return f"https://www.googleapis.com/service_accounts/v1/jwk/{openid_provider}" else: @@ -51,20 +51,28 @@ def get_jwks_uri(openid_provider): @functools.lru_cache(maxsize=32) -def get_public_keys(openid_provider): +def get_public_key(issuer, kid: str) -> bytearray: """ Fetches the public key from an OIDC Identity provider to verify the JWT. - :param openid_provider: the openid provider's domain. - :return: Public Keys + :param issuer: the openid provider's domain. + :param kid: the key identifier for verifying the JWT + :return: A Public Key """ - keys = session.get(get_jwks_uri(openid_provider)).json()["keys"] - return { + public_keys = { key["kid"]: rsa.RSAPublicNumbers( e=int.from_bytes(base64.urlsafe_b64decode(key["e"] + "==="), byteorder="big"), n=int.from_bytes(base64.urlsafe_b64decode(key["n"] + "==="), byteorder="big") ).public_key(backend=default_backend()) - for key in keys + for key in session.get(get_jwks_uri(issuer)).json()["keys"] } + try: + return public_keys[kid] + except KeyError: + logger.error({"message": "Failed to fetched public key from openid provider.", + "public_keys": public_keys, + "issuer": issuer, + "kid": kid}) + raise FusilladeHTTPException(503, 'Service Unavailable', "Failed to fetched public key from openid provider.") def verify_jwt(token: str) -> typing.Optional[typing.Mapping]: @@ -84,15 +92,7 @@ def verify_jwt(token: str) -> typing.Optional[typing.Mapping]: raise FusilladeHTTPException(401, 'Unauthorized', 'Failed to decode token.') issuer = unverified_token['iss'] - public_keys = get_public_keys(issuer) - try: - public_key = public_keys[token_header["kid"]] - except KeyError: - logger.error({"message": "Failed to fetched public key from openid provider.", - "public_keys": public_keys, - "issuer": issuer, - "kid": token_header["kid"]}) - raise FusilladeHTTPException(503, 'Service Unavailable', "Failed to fetched public key from openid provider.") + public_key = get_public_key(issuer, token_header["kid"]) try: verified_tok = jwt.decode(token, key=public_key,