11import logging
22import secrets
3- from typing import Dict , List , NoReturn , Union
3+ from typing import List , NoReturn
44
55from fastapi import status
66from fastapi .exceptions import HTTPException
@@ -73,17 +73,12 @@ def raise_error(request: Request) -> NoReturn:
7373 )
7474
7575
76- def extract_credentials (
77- authorization : HTTPAuthorizationCredentials , host : str
78- ) -> List [str ]:
76+ def extract_credentials (authorization : HTTPAuthorizationCredentials ) -> List [str ]:
7977 """Extract the credentials from ``Authorization`` headers and decode it before returning as a list of strings.
8078
8179 Args:
8280 authorization: Authorization header from the request.
83- host: Host header from the request.
8481 """
85- if not authorization :
86- raise_error (host )
8782 decoded_auth = secure .base64_decode (authorization .credentials )
8883 # convert hex to a string
8984 auth = secure .hex_decode (decoded_auth )
@@ -95,7 +90,7 @@ def verify_login(
9590 request : Request ,
9691 env_username : str ,
9792 env_password : str ,
98- ) -> Dict [ str , Union [ str , int ]] :
93+ ) -> str | NoReturn :
9994 """Verifies authentication and generates session token for each user.
10095
10196 Args:
@@ -105,12 +100,13 @@ def verify_login(
105100 env_password: Environment variable for the password.
106101
107102 Returns:
108- Dict[ str, str] :
109- Returns a dictionary with the payload required to create the session token.
103+ str:
104+ Returns the session token.
110105 """
111- username , signature , timestamp = extract_credentials (
112- authorization , request .client .host
113- )
106+ if authorization :
107+ username , signature , timestamp = extract_credentials (authorization )
108+ else :
109+ raise_error (request )
114110 if secrets .compare_digest (username , env_username ):
115111 hex_user = secure .hex_encode (env_username )
116112 hex_pass = secure .hex_encode (env_password )
@@ -122,15 +118,14 @@ def verify_login(
122118 if secrets .compare_digest (signature , expected_signature ):
123119 models .ws_session .invalid [request .client .host ] = 0
124120 key = secrets .token_urlsafe (64 )
125- # fixme: By setting a path instead of timestamp, this can handle path specific sessions
126- models .ws_session .client_auth [request .client .host ] = dict (
127- username = username , token = key , timestamp = int (timestamp )
128- )
129- return models .ws_session .client_auth [request .client .host ]
121+ models .ws_session .client_auth [request .client .host ] = key
122+ return key
130123 raise_error (request )
131124
132125
133- def session_check (api_request : Request = None , api_websocket : WebSocket = None ) -> None :
126+ def verify_session (
127+ api_request : Request = None , api_websocket : WebSocket = None
128+ ) -> None :
134129 """Check if the session is still valid.
135130
136131 Args:
@@ -150,9 +145,7 @@ def session_check(api_request: Request = None, api_websocket: WebSocket = None)
150145 detail = "Request or WebSocket connection is required for session check." ,
151146 )
152147 session_token = request .cookies .get ("session_token" )
153- stored_token = models .ws_session .client_auth .get (request .client .host , {}).get (
154- "token"
155- )
148+ stored_token = models .ws_session .client_auth .get (request .client .host )
156149 if (
157150 stored_token
158151 and session_token
0 commit comments