Skip to content

Commit 64d0299

Browse files
committed
Include the ability to handle multiple sessions
1 parent 8ee23a4 commit 64d0299

File tree

6 files changed

+69
-130
lines changed

6 files changed

+69
-130
lines changed

fastapiauthenticator/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from fastapiauthenticator.enums import APIEndpoints, APIMethods # noqa: F401,E402
2-
from fastapiauthenticator.models import Params # noqa: F401,E402
2+
from fastapiauthenticator.models import Parameters # noqa: F401,E402
33
from fastapiauthenticator.service import Authenticator # noqa: F401,E402
44
from fastapiauthenticator.version import version # noqa: F401,E402
55

fastapiauthenticator/models.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,31 @@
33

44
from fastapi.routing import APIRoute, APIWebSocketRoute
55
from fastapi.templating import Jinja2Templates
6-
from pydantic import BaseModel, Field, PositiveInt
6+
from pydantic import BaseModel, Field
77

88
from fastapiauthenticator.enums import APIMethods
99

1010
templates = Jinja2Templates(directory=pathlib.Path(__file__).parent / "templates")
1111

1212

13-
class Params(BaseModel):
13+
class Parameters(BaseModel):
1414
"""Parameters for the Authenticator class.
1515
16-
>>> Params
16+
>>> Parameters
1717
1818
Attributes:
19+
path: Path for the secure route, must start with '/'.
1920
function: Function to be called for secure routes after authentication.
2021
methods: List of HTTP methods that the secure function will handle.
2122
route: Type of route to be used for secure routes, either APIWebSocketRoute or APIRoute.
22-
path: Path for the secure route, must start with '/'.
2323
"""
2424

25-
function: Callable
26-
methods: List[APIMethods] = None
27-
route: Type[APIWebSocketRoute] | Type[APIRoute]
2825
path: str = Field(
2926
pattern="^/.*$", description="Path for the secure route, must start with '/'"
3027
)
31-
timeout: PositiveInt = Field(
32-
ge=0, default=300, description="Session timeout in seconds."
33-
)
28+
function: Callable
29+
methods: List[APIMethods] = [APIMethods.GET]
30+
route: Type[APIWebSocketRoute] | Type[APIRoute] = APIRoute
3431

3532

3633
class WSSession(BaseModel):
@@ -71,15 +68,17 @@ class RedirectException(Exception):
7168
https://fastapi.tiangolo.com/tutorial/handling-errors/#install-custom-exception-handlers
7269
"""
7370

74-
def __init__(self, location: str, detail: Optional[str] = ""):
71+
def __init__(self, source: str, destination: str, detail: Optional[str] = ""):
7572
"""Instantiates the ``RedirectException`` object with the required parameters.
7673
7774
Args:
78-
location: Location for redirect.
75+
source: Source from where the redirect is initiated.
76+
destination: Location to redirect.
7977
detail: Reason for redirect.
8078
"""
81-
self.location = location
8279
self.detail = detail
80+
self.source = source
81+
self.destination = destination
8382

8483

8584
ws_session = WSSession()

fastapiauthenticator/service.py

Lines changed: 24 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import logging
22
import os
3-
from threading import Timer
43
from typing import Dict, List
54

65
import dotenv
7-
from fastapi import FastAPI
6+
from fastapi.applications import FastAPI
87
from fastapi.params import Depends
98
from fastapi.requests import Request
109
from fastapi.responses import Response
@@ -29,7 +28,8 @@ class Authenticator:
2928
def __init__(
3029
self,
3130
app: FastAPI,
32-
params: models.Params | List[models.Params],
31+
params: models.Parameters | List[models.Parameters],
32+
timeout: int = 300,
3333
username: str = os.environ.get("USERNAME"),
3434
password: str = os.environ.get("PASSWORD"),
3535
fallback_button: str = models.fallback.button,
@@ -39,6 +39,8 @@ def __init__(
3939
4040
Args:
4141
app: FastAPI application instance to which the authenticator will be added.
42+
params: Parameters for the secure routes, can be a single `Parameters` object or a list of `Parameters`.
43+
timeout: Session timeout in seconds, default is 300 seconds (5 minutes).
4244
username: Username for authentication, can be set via environment variable 'USERNAME'.
4345
password: Password for authentication, can be set via environment variable 'PASSWORD'.
4446
fallback_button: Title for the fallback button, defaults to "LOGIN".
@@ -51,10 +53,10 @@ def __init__(
5153

5254
if isinstance(params, list):
5355
self.params = params
54-
elif isinstance(params, models.Params):
56+
elif isinstance(params, models.Parameters):
5557
self.params = [params]
5658

57-
self.route_map: Dict[str, models.Params] = {
59+
self.route_map: Dict[str, models.Parameters] = {
5860
param.path: param for param in self.params if param.route is APIRoute
5961
}
6062

@@ -69,6 +71,7 @@ def __init__(
6971

7072
self.username = username
7173
self.password = password
74+
self.timeout = timeout
7275

7376
self._secure()
7477

@@ -95,84 +98,28 @@ def _verify_auth(
9598
env_username=self.username,
9699
env_password=self.password,
97100
)
98-
referer = request.headers.get("Referer")
99-
origin = request.headers.get("Origin")
100-
destination = referer.replace(origin, "")
101+
destination = request.cookies.get("X-Requested-By")
101102
parameter = self.route_map.get(destination)
102-
private_route = APIRoute(
103-
path=parameter.path,
104-
endpoint=parameter.function,
105-
methods=parameter.methods,
106-
dependencies=[Depends(utils.session_check)],
107-
)
108-
for route in self.app.routes:
109-
if route.path == private_route.path:
110-
LOGGER.info(
111-
"Route %s already exists, removing it to replace with secure route.",
112-
private_route.path,
113-
)
114-
self.app.routes.remove(route)
115-
break
116-
self.app.routes.append(private_route)
117-
LOGGER.info("Setting session timeout for %s seconds", parameter.timeout)
118-
self._handle_session(
119-
response=response,
120-
request=request,
121-
secure_route=private_route,
122-
timeout=parameter.timeout,
123-
)
124-
return {"redirect_url": parameter.path}
125-
126-
def _setup_session_route(self, secure_route: APIRoute) -> None:
127-
"""Removes the secure route and adds a routing logic for invalid sessions.
128-
129-
Args:
130-
secure_route: Secure route to be removed from the app after the session timeout.
131-
"""
132-
LOGGER.info("Session expired, removing secure route: %s", secure_route.path)
133-
self.app.routes.remove(secure_route)
134-
LOGGER.info(
135-
"Adding session route to handle expired sessions at %s", secure_route.path
136-
)
137-
self.app.routes.append(
138-
APIRoute(
139-
path=secure_route.path,
140-
endpoint=endpoints.session,
141-
methods=["GET"],
142-
)
143-
)
144-
145-
def _handle_session(
146-
self,
147-
response: Response,
148-
request: Request,
149-
secure_route: APIRoute,
150-
timeout: int,
151-
) -> None:
152-
"""Handle session management by setting a cookie and scheduling session removal.
153-
154-
Args:
155-
response: Response object to set the session cookie.
156-
request: Request object containing client information.
157-
secure_route: Secure route to be removed from the app after the session timeout.
158-
"""
159-
# Remove the secure route after the session timeout - backend
160-
Timer(
161-
function=self._setup_session_route,
162-
args=(secure_route,),
163-
interval=timeout,
164-
).start()
165-
# Set the max age in session cookie to session timeout - frontend
103+
LOGGER.info("Setting session timeout for %s seconds", self.timeout)
104+
# Set session_token cookie with a timeout, to be used for session validation when redirected
166105
response.set_cookie(
167106
key="session_token",
168107
value=models.ws_session.client_auth[request.client.host].get("token"),
169108
httponly=True,
170109
samesite="strict",
171-
max_age=timeout,
110+
max_age=self.timeout,
172111
)
112+
# todo: Session should be cleared at client side after timeout
113+
response.delete_cookie(key="X-Requested-By")
114+
return {"redirect_url": parameter.path}
173115

174116
def _secure(self) -> None:
175117
"""Create the login and verification routes for the APIAuthenticator."""
118+
login_route = APIRoute(
119+
path=enums.APIEndpoints.fastapi_login,
120+
endpoint=endpoints.login,
121+
methods=["GET"],
122+
)
176123
error_route = APIRoute(
177124
path=enums.APIEndpoints.fastapi_error,
178125
endpoint=endpoints.error,
@@ -190,6 +137,7 @@ def _secure(self) -> None:
190137
)
191138
for param in self.params:
192139
if param.route is APIWebSocketRoute:
140+
# WebSocket routes will not have a login path, they will be protected by session check
193141
secure_route = APIWebSocketRoute(
194142
path=param.path,
195143
endpoint=param.function,
@@ -198,8 +146,9 @@ def _secure(self) -> None:
198146
else:
199147
secure_route = APIRoute(
200148
path=param.path,
201-
endpoint=endpoints.login,
149+
endpoint=param.function,
202150
methods=["GET"],
151+
dependencies=[Depends(utils.session_check)],
203152
)
204153
self.app.routes.append(secure_route)
205-
self.app.routes.extend([session_route, verify_route, error_route])
154+
self.app.routes.extend([login_route, session_route, verify_route, error_route])

fastapiauthenticator/utils.py

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
import secrets
33
from typing import Dict, List, NoReturn, Union
44

5-
from fastapi import WebSocket, status
5+
from fastapi import status
66
from fastapi.exceptions import HTTPException
77
from fastapi.requests import Request
88
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
99
from fastapi.security import HTTPAuthorizationCredentials
10+
from fastapi.websockets import WebSocket
1011

1112
from fastapiauthenticator import enums, models, secure
1213

@@ -23,8 +24,9 @@ def failed_auth_counter(host: str) -> None:
2324
models.ws_session.invalid[host] += 1
2425
except KeyError:
2526
models.ws_session.invalid[host] = 1
26-
if models.ws_session.invalid[host] >= 3:
27-
raise models.RedirectException(location=enums.APIEndpoints.fastapi_error)
27+
# todo: fix this
28+
# if models.ws_session.invalid[host] >= 3:
29+
# raise models.RedirectException(location=enums.APIEndpoints.fastapi_error)
2830

2931

3032
def redirect_exception_handler(
@@ -44,14 +46,17 @@ def redirect_exception_handler(
4446
LOGGER.warning("Exception cookies: %s", request.cookies)
4547
if request.url.path == enums.APIEndpoints.fastapi_verify_login:
4648
response = JSONResponse(
47-
content={"redirect_url": exception.location}, status_code=200
49+
content={"redirect_url": exception.destination}, status_code=200
4850
)
4951
else:
50-
response = RedirectResponse(url=exception.location)
52+
response = RedirectResponse(url=exception.destination)
5153
if exception.detail:
5254
response.set_cookie(
5355
"detail", exception.detail.upper(), httponly=True, samesite="strict"
5456
)
57+
response.set_cookie(
58+
"X-Requested-By", exception.source, httponly=True, samesite="strict"
59+
)
5560
return response
5661

5762

@@ -121,43 +126,40 @@ def verify_login(
121126
raise_error(host)
122127

123128

124-
def session_check(request: Request = None, websocket: WebSocket = None) -> None:
129+
def session_check(api_request: Request = None, api_websocket: WebSocket = None) -> None:
125130
"""Check if the session is still valid.
126131
127132
Args:
128-
request: Request containing client information.
129-
websocket: WebSocket connection object.
133+
api_request: Request containing client information.
134+
api_websocket: WebSocket connection object.
130135
131136
Raises:
132137
HTTPException: If the session is invalid or expired.
133138
"""
134-
if request:
135-
host = request.client.host
136-
session_token = request.cookies.get("session_token")
137-
elif websocket:
138-
host = websocket.client.host
139-
session_token = websocket.cookies.get("session_token")
139+
if api_request:
140+
request = api_request
141+
elif api_websocket:
142+
request = api_websocket
140143
else:
141144
raise HTTPException(
142145
status_code=status.HTTP_400_BAD_REQUEST,
143146
detail="Request or WebSocket connection is required for session check.",
144147
)
145-
stored_token = models.ws_session.client_auth.get(host, {}).get("token")
148+
session_token = request.cookies.get("session_token")
149+
stored_token = models.ws_session.client_auth.get(request.client.host, {}).get(
150+
"token"
151+
)
146152
if (
147153
stored_token
148154
and session_token
149155
and secrets.compare_digest(session_token, stored_token)
150156
):
151-
LOGGER.info("Session is valid for host: %s", host)
157+
LOGGER.info("Session is valid for host: %s", request.client.host)
152158
return
153-
# todo: this will fail all new sessions
154-
# the auth page route will be removed from the app if session1 is valid
155-
# when session2 is created, the auth page route will not be available, and since session2 is not authenticated,
156-
# the content cannot be rendered
157-
LOGGER.warning("Session is invalid for host: %s", host)
159+
LOGGER.warning("Session is invalid or expired for host: %s", request.client.host)
158160
raise models.RedirectException(
159-
location=enums.APIEndpoints.fastapi_session,
160-
detail="Session expired or invalid. Please log in again.",
161+
source=request.url.path,
162+
destination=enums.APIEndpoints.fastapi_login,
161163
)
162164

163165

@@ -172,10 +174,6 @@ def clear_session(request: Request, response: HTMLResponse) -> HTMLResponse:
172174
HTMLResponse:
173175
Returns the response object with the session token cleared.
174176
"""
175-
for cookie in request.cookies:
176-
# Deletes all cookies stored in current session
177-
LOGGER.info("Deleting cookie: '%s'", cookie)
178-
response.delete_cookie(cookie)
179177
response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
180178
response.headers["Authorization"] = ""
181179
return response

verify/index.html

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
amplitude: 0,
4646
frequency: 6,
4747
color: "#ffffff",
48-
autostart: true,
48+
autostart: false,
4949
cover: true,
5050
});
5151

0 commit comments

Comments
 (0)