1414LOGGER = logging .getLogger ("uvicorn.default" )
1515
1616
17- def failed_auth_counter (host : str ) -> None :
17+ def failed_auth_counter (request : Request ) -> None :
1818 """Keeps track of failed login attempts from each host, and redirects if failed for 3 or more times.
1919
2020 Args:
21- host: Host header from the request .
21+ request: Request object containing client information .
2222 """
2323 try :
24- models .ws_session .invalid [host ] += 1
24+ models .ws_session .invalid [request . client . host ] += 1
2525 except KeyError :
26- models .ws_session .invalid [host ] = 1
27- # todo: fix this
28- # if models.ws_session.invalid[host] >= 3:
29- # raise models.RedirectException(location=enums.APIEndpoints.fastapi_error)
26+ models .ws_session .invalid [request .client .host ] = 1
27+ if models .ws_session .invalid [request .client .host ] >= 3 :
28+ raise models .RedirectException (destination = enums .APIEndpoints .fastapi_error )
3029
3130
3231def redirect_exception_handler (
@@ -42,12 +41,8 @@ def redirect_exception_handler(
4241 JSONResponse:
4342 Returns the JSONResponse with content, status code and cookie.
4443 """
45- LOGGER .warning ("Exception headers: %s" , request .headers )
46- LOGGER .warning ("Exception cookies: %s" , request .cookies )
4744 if request .url .path == enums .APIEndpoints .fastapi_verify_login :
48- response = JSONResponse (
49- content = {"redirect_url" : exception .destination }, status_code = 200
50- )
45+ response = JSONResponse (content = {"redirect_url" : exception .destination })
5146 else :
5247 response = RedirectResponse (url = exception .destination )
5348 if exception .detail :
@@ -60,16 +55,16 @@ def redirect_exception_handler(
6055 return response
6156
6257
63- def raise_error (host : str ) -> NoReturn :
58+ def raise_error (request : Request ) -> NoReturn :
6459 """Raises a 401 Unauthorized error in case of bad credentials.
6560
6661 Args:
67- host: Host header from the request .
62+ request: Request object containing client information .
6863 """
69- failed_auth_counter (host )
64+ failed_auth_counter (request )
7065 LOGGER .error (
7166 "Incorrect username or password: %d" ,
72- models .ws_session .invalid [host ],
67+ models .ws_session .invalid [request . client . host ],
7368 )
7469 raise HTTPException (
7570 status_code = status .HTTP_401_UNAUTHORIZED ,
@@ -97,33 +92,42 @@ def extract_credentials(
9792
9893def verify_login (
9994 authorization : HTTPAuthorizationCredentials ,
100- host : str ,
95+ request : Request ,
10196 env_username : str ,
10297 env_password : str ,
10398) -> Dict [str , Union [str , int ]]:
10499 """Verifies authentication and generates session token for each user.
105100
101+ Args:
102+ authorization: Authorization header from the request.
103+ request: Request object containing client information.
104+ env_username: Environment variable for the username.
105+ env_password: Environment variable for the password.
106+
106107 Returns:
107108 Dict[str, str]:
108109 Returns a dictionary with the payload required to create the session token.
109110 """
110- username , signature , timestamp = extract_credentials (authorization , host )
111+ username , signature , timestamp = extract_credentials (
112+ authorization , request .client .host
113+ )
111114 if secrets .compare_digest (username , env_username ):
112115 hex_user = secure .hex_encode (env_username )
113116 hex_pass = secure .hex_encode (env_password )
114117 else :
115118 LOGGER .warning ("User '%s' not allowed" , username )
116- raise_error (host )
119+ raise_error (request )
117120 message = f"{ hex_user } { hex_pass } { timestamp } "
118121 expected_signature = secure .calculate_hash (message )
119122 if secrets .compare_digest (signature , expected_signature ):
120- models .ws_session .invalid [host ] = 0
123+ models .ws_session .invalid [request . client . host ] = 0
121124 key = secrets .token_urlsafe (64 )
122- models .ws_session .client_auth [host ] = dict (
125+ # fixme: By setting a path instead of timestamp, this can handle path specific sessions
126+ models .ws_session .client_auth [request .client .host ] = dict (
123127 username = username , token = key , timestamp = int (timestamp )
124128 )
125- return models .ws_session .client_auth [host ]
126- raise_error (host )
129+ return models .ws_session .client_auth [request . client . host ]
130+ raise_error (request )
127131
128132
129133def session_check (api_request : Request = None , api_websocket : WebSocket = None ) -> None :
@@ -156,18 +160,31 @@ def session_check(api_request: Request = None, api_websocket: WebSocket = None)
156160 ):
157161 LOGGER .info ("Session is valid for host: %s" , request .client .host )
158162 return
159- LOGGER .warning ("Session is invalid or expired for host: %s" , request .client .host )
160- raise models .RedirectException (
161- source = request .url .path ,
162- destination = enums .APIEndpoints .fastapi_login ,
163- )
163+ elif not session_token :
164+ LOGGER .warning (
165+ "Session is invalid or expired for host: %s" , request .client .host
166+ )
167+ raise models .RedirectException (
168+ source = request .url .path ,
169+ destination = enums .APIEndpoints .fastapi_login ,
170+ )
171+ else :
172+ LOGGER .warning (
173+ "Session token mismatch for host: %s. Expected: %s, Received: %s" ,
174+ request .client .host ,
175+ stored_token ,
176+ session_token ,
177+ )
178+ raise models .RedirectException (
179+ source = request .url .path ,
180+ destination = enums .APIEndpoints .fastapi_session ,
181+ )
164182
165183
166- def clear_session ( request : Request , response : HTMLResponse ) -> HTMLResponse :
167- """Clear the session token from the response.
184+ def deauthorize ( response : HTMLResponse ) -> HTMLResponse :
185+ """Remove authorization headers and clear session token from the response.
168186
169187 Args:
170- request: FastAPI ``request`` object.
171188 response: FastAPI ``response`` object.
172189
173190 Returns:
@@ -176,4 +193,18 @@ def clear_session(request: Request, response: HTMLResponse) -> HTMLResponse:
176193 """
177194 response .headers ["Cache-Control" ] = "no-cache, no-store, must-revalidate"
178195 response .headers ["Authorization" ] = ""
196+ response .delete_cookie ("session_token" )
179197 return response
198+
199+
200+ def clear_session (host : str ) -> None :
201+ """Clear the session for the given host.
202+
203+ Args:
204+ host: Host header from the request.
205+ """
206+ if models .ws_session .client_auth .get (host ):
207+ models .ws_session .client_auth .pop (host )
208+ LOGGER .info ("Session cleared for host: %s" , host )
209+ else :
210+ LOGGER .warning ("No session found for host: %s" , host )
0 commit comments