1
+ import logging
1
2
from functools import cached_property
2
3
from typing import Any , Generic , TypeVar
3
4
6
7
from cachetools import TTLCache , cached
7
8
from fastapi import HTTPException
8
9
from jwt import algorithms
9
- from loguru import logger
10
10
from pydantic import BaseModel
11
11
12
12
from fastapi_jwks .models .types import JWKSConfig , JWTDecodeConfig
13
13
14
+ logger = logging .getLogger ("fastapi-jwks" )
15
+
14
16
DataT = TypeVar ("DataT" , bound = BaseModel )
15
17
16
18
@@ -41,10 +43,7 @@ def jwks_data(self) -> dict[str, Any]:
41
43
42
44
@staticmethod
43
45
def __extract_algorithms (jwks_response : dict [str , Any ]) -> list [str ]:
44
- if "keys" not in jwks_response :
45
- raise ValueError ("JWKS response does not contain keys" )
46
- keys = jwks_response ["keys" ]
47
- return [key ["alg" ] for key in keys ]
46
+ return [key ["alg" ] for key in jwks_response ["keys" ]]
48
47
49
48
@cached_property
50
49
def __is_generic_passed (self ):
@@ -63,7 +62,11 @@ def validate_token(self, token: str) -> DataT:
63
62
header = jwt .get_unverified_header (token )
64
63
kid = header ["kid" ]
65
64
jwks_data = self .jwks_data ()
66
- if header ["alg" ] not in self .__extract_algorithms (jwks_data ):
65
+ provided_algorithms = self .__extract_algorithms (jwks_data )
66
+ if header ["alg" ] not in provided_algorithms :
67
+ logger .debug (
68
+ f"Could not find '{ header ['alg' ]} ' in provided algorithms: { provided_algorithms } "
69
+ )
67
70
raise HTTPException (status_code = 401 , detail = "Invalid token" )
68
71
for key in jwks_data ["keys" ]:
69
72
if key ["kid" ] == kid :
@@ -72,6 +75,9 @@ def validate_token(self, token: str) -> DataT:
72
75
].from_jwk (key )
73
76
break
74
77
if public_key is None :
78
+ logger .debug (
79
+ f"No public key for provided algorithm '{ header ['alg' ]} ' found in JWKS data"
80
+ )
75
81
raise HTTPException (status_code = 401 , detail = "Invalid token" )
76
82
return self .__orig_class__ .__args__ [0 ].model_validate ( # type: ignore
77
83
# This line gets the generic value in runtime to transform it to the correct pydantic model
@@ -83,6 +89,8 @@ def validate_token(self, token: str) -> DataT:
83
89
)
84
90
)
85
91
except jwt .ExpiredSignatureError :
92
+ logger .debug ("Expired token" , exc_info = True )
86
93
raise HTTPException (status_code = 401 , detail = "Token has expired" ) from None
87
94
except jwt .InvalidTokenError :
95
+ logger .debug ("Invalid token" , exc_info = True )
88
96
raise HTTPException (status_code = 401 , detail = "Invalid token" ) from None
0 commit comments