11from datetime import timedelta
2+ from logging import getLogger
3+ from secrets import token_urlsafe
4+ from typing import List
25
36from fastapi import APIRouter , Depends , HTTPException , Request , Security
47from fastapi .responses import HTMLResponse , JSONResponse , RedirectResponse
5- from pydantic import Field , PrivateAttr
8+ from pydantic import Field , PrivateAttr , field_validator
69from starlette .status import HTTP_403_FORBIDDEN
710
811from csp_gateway .server import GatewayChannels , GatewayModule
912
13+ from ..shared import ChannelSelection
14+
1015# separate to avoid circular
1116from ..web import GatewayWebApp
1217from .hacks .api_key_middleware_websocket_fix .api_key import (
1520 APIKeyQuery ,
1621)
1722
23+ _log = getLogger (__name__ )
1824
19- class MountAPIKeyMiddleware (GatewayModule ):
20- api_key_timeout : timedelta = Field (description = "Cookie timeout for API Key authentication" , default = timedelta (hours = 12 ))
25+ __all__ = (
26+ "MountAuthMiddleware" ,
27+ "MountAPIKeyMiddleware" ,
28+ )
29+
30+ # TODO: More eventually
31+
32+
33+ class MountAuthMiddleware (GatewayModule ):
34+ enforce : list = Field (default = (), description = "Routes to enforce, default empty means 'all'" )
35+ channels : ChannelSelection = Field (
36+ default_factory = ChannelSelection ,
37+ description = "Channels or subroutes to enforce. If route is not present in `enforce`, implies 'allow all'" ,
38+ )
2139
22- # NOTE: don't make this publically configureable
23- # as it is needed in gateway.py
24- _api_key_name : str = PrivateAttr ("token" )
25- _api_key_secret : str = PrivateAttr ("" )
40+ enforce_controls : bool = Field (default = False , description = "Whether to allow access to controls routes. Defaults to True" )
41+ enforce_ui : bool = Field (default = True , description = "Whether to allow web access to the API Key authentication routes. Defaults to True" )
2642
2743 unauthorized_status_message : str = "unauthorized"
2844
45+ _enforced_channels : List [str ] = PrivateAttr (default_factory = list )
46+
2947 def connect (self , channels : GatewayChannels ) -> None :
3048 # NO-OP
3149 ...
3250
51+
52+ class MountAPIKeyMiddleware (MountAuthMiddleware ):
53+ api_key : str = Field (default = token_urlsafe (32 ), description = "API Key to use" )
54+ api_key_name : str = Field (default = "token" , description = "API Key to use" )
55+ api_key_timeout : timedelta = Field (description = "Cookie timeout for API Key authentication" , default = timedelta (hours = 12 ))
56+
57+ _instance_count = 0
58+
59+ @field_validator ("api_key_name" , mode = "before" )
60+ @classmethod
61+ def _validate_api_key_name (cls , value : str ) -> str :
62+ if not value :
63+ raise ValueError ("API Key name must be a non-empty string" )
64+ value = f"{ value .strip ().lower ()} -{ cls ._instance_count } "
65+ cls ._instance_count += 1
66+ return value
67+
3368 def rest (self , app : GatewayWebApp ) -> None :
3469 if app .settings .AUTHENTICATE :
35- # first, pull out the api key secret from the settings
36- self ._api_key_secret = app .settings .API_KEY
37-
38- # reinitialize header
39- api_key_query = APIKeyQuery (name = self ._api_key_name , auto_error = False )
40- api_key_header = APIKeyHeader (name = self ._api_key_name , auto_error = False )
41- api_key_cookie = APIKeyCookie (name = self ._api_key_name , auto_error = False )
42-
43- # routers
44- auth_router : APIRouter = app .get_router ("auth" )
45- public_router : APIRouter = app .get_router ("public" )
46-
47- # now mount middleware
48- async def get_api_key (
49- api_key_query : str = Security (api_key_query ),
50- api_key_header : str = Security (api_key_header ),
51- api_key_cookie : str = Security (api_key_cookie ),
52- ):
53- if api_key_query == self ._api_key_secret or api_key_header == self ._api_key_secret or api_key_cookie == self ._api_key_secret :
54- return self ._api_key_secret
55- else :
70+ # Use configuration to determine allowed routes
71+ # for this API key
72+ self ._calculate_auth (app )
73+
74+ # Setup the routes for authentication
75+ self ._setup_routes (app )
76+
77+ def _calculate_auth (self , app : GatewayWebApp ) -> None :
78+ self ._enforced_channels = self .channels .select_from (app .gateway .channels_model )
79+
80+ # Fully form the url
81+ self ._api_str = app .settings .API_STR
82+
83+ def _setup_routes (self , app : GatewayWebApp ) -> None :
84+ # reinitialize header
85+ api_key_query = APIKeyQuery (name = self .api_key_name , auto_error = False )
86+ api_key_header = APIKeyHeader (name = self .api_key_name , auto_error = False )
87+ api_key_cookie = APIKeyCookie (name = self .api_key_name , auto_error = False )
88+
89+ # routers
90+ auth_router : APIRouter = app .get_router ("auth" )
91+ public_router : APIRouter = app .get_router ("public" )
92+
93+ # now mount middleware
94+ async def get_api_key (
95+ request : Request = None ,
96+ api_key_query : str = Security (api_key_query ),
97+ api_key_header : str = Security (api_key_header ),
98+ api_key_cookie : str = Security (api_key_cookie ),
99+ ):
100+ if request is None :
101+ # If request is None, we are not in a request context, return None
102+ _log .warning ("API Key check: request is None, returning None" )
103+ return None
104+
105+ if hasattr (request .state , "auth" ):
106+ # Already authenticated, return the API key
107+ _log .info (f"API Key check: already authenticated, returning { self .api_key_name } " )
108+ return request .state .auth
109+
110+ resolved_path = request .url .path .rstrip ("/" ).replace (self ._api_str , "" ).lstrip ("/" ).rsplit ("/" , 1 )
111+
112+ if len (resolved_path ) == 1 :
113+ root = resolved_path [0 ]
114+ channel = ""
115+
116+ elif len (resolved_path ) > 1 :
117+ root = resolved_path [0 ]
118+ channel = resolved_path [1 ]
119+
120+ if self .enforce and root not in self .enforce :
121+ # Route not in enforce, allow
122+ _log .info (f"API Key check: { root } /{ channel } not in enforced list { self .enforce } , allowing" )
123+ return ""
124+
125+ if root == "controls" and not self .enforce_controls :
126+ # Controls route not enforced, allow
127+ _log .info (f"API Key check: root { root } not enforced, allowing" )
128+ return ""
129+
130+ # TODO
131+ if root in ("" , "auth" , "perspective" ) and not self .enforce_ui :
132+ # UI route not enforced, allow
133+ _log .info (f"API Key check: root { root } not enforced, allowing" )
134+ return ""
135+
136+ if root not in ("controls" , "auth" , "perspective" ) and channel and channel not in self ._enforced_channels :
137+ # Channel not in enforce, allow
138+ _log .info (f"API Key check: channel { root } /{ channel } not in enforced channels { self ._enforced_channels } , allowing" )
139+ return ""
140+
141+ # Else, enforce
142+ if api_key_query == self .api_key or api_key_header == self .api_key or api_key_cookie == self .api_key :
143+ # Return the API key secret to allow access
144+ _log .info (f"API Key check: { self .api_key_name } matched for { root } /{ channel } , allowing access" )
145+
146+ # NOTE: only set this if we are the one validating, not if we are ignoring
147+ request .state .auth = self .api_key
148+ return self .api_key
149+
150+ _log .warning (f"API Key check: { self .api_key_name } did not match, denying access" )
151+ raise HTTPException (
152+ status_code = HTTP_403_FORBIDDEN ,
153+ detail = self .unauthorized_status_message ,
154+ )
155+
156+ # add auth to all other routes
157+ app .add_middleware (Depends (get_api_key ))
158+
159+ if self .enforce_ui :
160+
161+ @auth_router .get ("/login" )
162+ async def route_login_and_add_cookie (api_key : str = Depends (get_api_key )):
163+ if not api_key :
56164 raise HTTPException (
57165 status_code = HTTP_403_FORBIDDEN ,
58166 detail = self .unauthorized_status_message ,
59167 )
60-
61- @auth_router .get ("/login" )
62- async def route_login_and_add_cookie (api_key : str = Depends (get_api_key )):
63168 response = RedirectResponse (url = "/" )
64169 response .set_cookie (
65- self ._api_key_name ,
170+ self .api_key_name ,
66171 value = api_key ,
67172 domain = app .settings .AUTHENTICATION_DOMAIN ,
68173 httponly = True ,
@@ -74,44 +179,40 @@ async def route_login_and_add_cookie(api_key: str = Depends(get_api_key)):
74179 @auth_router .get ("/logout" )
75180 async def route_logout_and_remove_cookie ():
76181 response = RedirectResponse (url = "/login" )
77- response .delete_cookie (self ._api_key_name , domain = app .settings .AUTHENTICATION_DOMAIN )
182+ response .delete_cookie (self .api_key_name , domain = app .settings .AUTHENTICATION_DOMAIN )
78183 return response
79184
80185 # I'm hand rolling these for now...
81186 @public_router .get ("/login" , response_class = HTMLResponse , include_in_schema = False )
82187 async def get_login_page (token : str = "" , request : Request = None ):
83- if token :
84- if token != "" :
85- return RedirectResponse (url = f"{ app .settings .API_V1_STR } /auth/login?token={ token } " )
188+ if token and token != "" :
189+ return RedirectResponse (url = f"{ self ._api_str } /auth/login?token={ token } " )
86190 return app .templates .TemplateResponse (
87191 "login.html.j2" ,
88- {"request" : request , "api_key_name" : self ._api_key_name },
192+ {"request" : request , "api_key_name" : self .api_key_name },
89193 )
90194
91195 @public_router .get ("/logout" , response_class = HTMLResponse , include_in_schema = False )
92196 async def get_logout_page (request : Request = None ):
93197 return app .templates .TemplateResponse ("logout.html.j2" , {"request" : request })
94198
95- # add auth to all other routes
96- app .add_middleware (Depends (get_api_key ))
97-
98- @app .app .exception_handler (403 )
99- async def custom_403_handler (request : Request = None , * args ):
100- if "/api" in request .url .path :
101- # programmatic api access, return json
102- return JSONResponse (
103- {
104- "detail" : self .unauthorized_status_message ,
105- "status_code" : 403 ,
106- },
107- status_code = 403 ,
108- )
109- return app .templates .TemplateResponse (
110- "login.html.j2" ,
199+ @app .app .exception_handler (403 )
200+ async def custom_403_handler (request : Request = None , * args ):
201+ if "/api" in request .url .path :
202+ # programmatic api access, return json
203+ return JSONResponse (
111204 {
112- "request" : request ,
113- "api_key_name" : self ._api_key_name ,
114- "status_code" : 403 ,
115205 "detail" : self .unauthorized_status_message ,
206+ "status_code" : 403 ,
116207 },
208+ status_code = 403 ,
117209 )
210+ return app .templates .TemplateResponse (
211+ "login.html.j2" ,
212+ {
213+ "request" : request ,
214+ "api_key_name" : self .api_key_name ,
215+ "status_code" : 403 ,
216+ "detail" : self .unauthorized_status_message ,
217+ },
218+ )
0 commit comments