11from datetime import timedelta
2+ from logging import getLogger
3+ from secrets import token_urlsafe
4+ from typing import List
25
36from fastapi import APIRouter , Depends , HTTPException , Request , Security
47from fastapi .responses import HTMLResponse , JSONResponse , RedirectResponse
5- from pydantic import Field , PrivateAttr
8+ from pydantic import Field , PrivateAttr , field_validator
69from starlette .status import HTTP_403_FORBIDDEN
710
811from csp_gateway .server import GatewayChannels , GatewayModule
912
13+ from ..shared import ChannelSelection
14+
1015# separate to avoid circular
1116from ..web import GatewayWebApp
1217from .hacks .api_key_middleware_websocket_fix .api_key import (
1520 APIKeyQuery ,
1621)
1722
23+ _log = getLogger (__name__ )
1824
19- class MountAPIKeyMiddleware (GatewayModule ):
20- api_key_timeout : timedelta = Field (description = "Cookie timeout for API Key authentication" , default = timedelta (hours = 12 ))
25+ __all__ = (
26+ "MountAuthMiddleware" ,
27+ "MountAPIKeyMiddleware" ,
28+ )
2129
22- # NOTE: don't make this publically configureable
23- # as it is needed in gateway.py
24- _api_key_name : str = PrivateAttr ("token" )
25- _api_key_secret : str = PrivateAttr ("" )
30+ # TODO: More eventually
31+
32+ class MountAuthMiddleware (GatewayModule ):
33+ enforce : list = Field (default = (), description = "Routes to enforce, default empty means 'all'" )
34+ channels : ChannelSelection = Field (
35+ default_factory = ChannelSelection ,
36+ description = "Channels or subroutes to enforce. If route is not present in `enforce`, implies 'allow all'" ,
37+ )
38+
39+ enforce_controls : bool = Field (default = False , description = "Whether to allow access to controls routes. Defaults to True" )
40+ enforce_ui : bool = Field (default = True , description = "Whether to allow web access to the API Key authentication routes. Defaults to True" )
2641
2742 unauthorized_status_message : str = "unauthorized"
2843
44+ _enforced_channels : List [str ] = PrivateAttr (default_factory = list )
45+
2946 def connect (self , channels : GatewayChannels ) -> None :
3047 # NO-OP
3148 ...
3249
50+
51+ class MountAPIKeyMiddleware (MountAuthMiddleware ):
52+ api_key : str = Field (default = token_urlsafe (32 ), description = "API Key to use" )
53+ api_key_name : str = Field (default = "token" , description = "API Key to use" )
54+ api_key_timeout : timedelta = Field (description = "Cookie timeout for API Key authentication" , default = timedelta (hours = 12 ))
55+
56+ _instance_count = 0
57+
58+ @field_validator ("api_key_name" , mode = "before" )
59+ @classmethod
60+ def _validate_api_key_name (cls , value : str ) -> str :
61+ if not value :
62+ raise ValueError ("API Key name must be a non-empty string" )
63+ value = f"{ value .strip ().lower ()} -{ cls ._instance_count } "
64+ cls ._instance_count += 1
65+ return value
66+
3367 def rest (self , app : GatewayWebApp ) -> None :
3468 if app .settings .AUTHENTICATE :
35- # first, pull out the api key secret from the settings
36- self ._api_key_secret = app .settings .API_KEY
69+ # Use configuration to determine allowed routes
70+ # for this API key
71+ self ._calculate_auth (app )
72+
73+ # Setup the routes for authentication
74+ self ._setup_routes (app )
75+
76+
77+ def _calculate_auth (self , app : GatewayWebApp ) -> None :
78+ self ._enforced_channels = self .channels .select_from (app .gateway .channels_model )
79+
80+ # Fully form the url
81+ self ._api_str = app .settings .API_STR
3782
83+ def _setup_routes (self , app : GatewayWebApp ) -> None :
3884 # reinitialize header
39- api_key_query = APIKeyQuery (name = self ._api_key_name , auto_error = False )
40- api_key_header = APIKeyHeader (name = self ._api_key_name , auto_error = False )
41- api_key_cookie = APIKeyCookie (name = self ._api_key_name , auto_error = False )
85+ api_key_query = APIKeyQuery (name = self .api_key_name , auto_error = False )
86+ api_key_header = APIKeyHeader (name = self .api_key_name , auto_error = False )
87+ api_key_cookie = APIKeyCookie (name = self .api_key_name , auto_error = False )
4288
4389 # routers
4490 auth_router : APIRouter = app .get_router ("auth" )
4591 public_router : APIRouter = app .get_router ("public" )
4692
4793 # now mount middleware
4894 async def get_api_key (
95+ request : Request = None ,
4996 api_key_query : str = Security (api_key_query ),
5097 api_key_header : str = Security (api_key_header ),
5198 api_key_cookie : str = Security (api_key_cookie ),
5299 ):
53- if api_key_query == self ._api_key_secret or api_key_header == self ._api_key_secret or api_key_cookie == self ._api_key_secret :
54- return self ._api_key_secret
55- else :
56- raise HTTPException (
57- status_code = HTTP_403_FORBIDDEN ,
58- detail = self .unauthorized_status_message ,
59- )
60-
61- @auth_router .get ("/login" )
62- async def route_login_and_add_cookie (api_key : str = Depends (get_api_key )):
63- response = RedirectResponse (url = "/" )
64- response .set_cookie (
65- self ._api_key_name ,
66- value = api_key ,
67- domain = app .settings .AUTHENTICATION_DOMAIN ,
68- httponly = True ,
69- max_age = self .api_key_timeout .total_seconds (),
70- expires = self .api_key_timeout .total_seconds (),
71- )
72- return response
73-
74- @auth_router .get ("/logout" )
75- async def route_logout_and_remove_cookie ():
76- response = RedirectResponse (url = "/login" )
77- response .delete_cookie (self ._api_key_name , domain = app .settings .AUTHENTICATION_DOMAIN )
78- return response
79-
80- # I'm hand rolling these for now...
81- @public_router .get ("/login" , response_class = HTMLResponse , include_in_schema = False )
82- async def get_login_page (token : str = "" , request : Request = None ):
83- if token :
84- if token != "" :
85- return RedirectResponse (url = f"{ app .settings .API_V1_STR } /auth/login?token={ token } " )
86- return app .templates .TemplateResponse (
87- "login.html.j2" ,
88- {"request" : request , "api_key_name" : self ._api_key_name },
100+ if request is None :
101+ # If request is None, we are not in a request context, return None
102+ _log .warning ("API Key check: request is None, returning None" )
103+ return None
104+
105+ if hasattr (request .state , "auth" ):
106+ # Already authenticated, return the API key
107+ _log .info (f"API Key check: already authenticated, returning { self .api_key_name } " )
108+ return request .state .auth
109+
110+ resolved_path = request .url .path .rstrip ("/" ).replace (self ._api_str , "" ).lstrip ("/" ).rsplit ("/" , 1 )
111+
112+ if len (resolved_path ) == 1 :
113+ root = resolved_path [0 ]
114+ channel = ""
115+
116+ elif len (resolved_path ) > 1 :
117+ root = resolved_path [0 ]
118+ channel = resolved_path [1 ]
119+
120+ if self .enforce and root not in self .enforce :
121+ # Route not in enforce, allow
122+ _log .info (f"API Key check: { root } /{ channel } not in enforced list { self .enforce } , allowing" )
123+ return ""
124+
125+ if root == "controls" and not self .enforce_controls :
126+ # Controls route not enforced, allow
127+ _log .info (f"API Key check: root { root } not enforced, allowing" )
128+ return ""
129+
130+ # TODO
131+ if root in ("" , "auth" , "perspective" ) and not self .enforce_ui :
132+ # UI route not enforced, allow
133+ _log .info (f"API Key check: root { root } not enforced, allowing" )
134+ return ""
135+
136+ if root not in ("controls" , "auth" , "perspective" ) and channel and channel not in self ._enforced_channels :
137+ # Channel not in enforce, allow
138+ _log .info (f"API Key check: channel { root } /{ channel } not in enforced channels { self ._enforced_channels } , allowing" )
139+ return ""
140+
141+ # Else, enforce
142+ if api_key_query == self .api_key or api_key_header == self .api_key or api_key_cookie == self .api_key :
143+ # Return the API key secret to allow access
144+ _log .info (f"API Key check: { self .api_key_name } matched for { root } /{ channel } , allowing access" )
145+
146+ # NOTE: only set this if we are the one validating, not if we are ignoring
147+ request .state .auth = self .api_key
148+ return self .api_key
149+
150+ _log .warning (f"API Key check: { self .api_key_name } did not match, denying access" )
151+ raise HTTPException (
152+ status_code = HTTP_403_FORBIDDEN ,
153+ detail = self .unauthorized_status_message ,
89154 )
90155
91- @public_router .get ("/logout" , response_class = HTMLResponse , include_in_schema = False )
92- async def get_logout_page (request : Request = None ):
93- return app .templates .TemplateResponse ("logout.html.j2" , {"request" : request })
94-
95156 # add auth to all other routes
96157 app .add_middleware (Depends (get_api_key ))
97158
159+ if self .enforce_ui :
160+ @auth_router .get ("/login" )
161+ async def route_login_and_add_cookie (api_key : str = Depends (get_api_key )):
162+ if not api_key :
163+ raise HTTPException (
164+ status_code = HTTP_403_FORBIDDEN ,
165+ detail = self .unauthorized_status_message ,
166+ )
167+ response = RedirectResponse (url = "/" )
168+ response .set_cookie (
169+ self .api_key_name ,
170+ value = api_key ,
171+ domain = app .settings .AUTHENTICATION_DOMAIN ,
172+ httponly = True ,
173+ max_age = self .api_key_timeout .total_seconds (),
174+ expires = self .api_key_timeout .total_seconds (),
175+ )
176+ return response
177+
178+ @auth_router .get ("/logout" )
179+ async def route_logout_and_remove_cookie ():
180+ response = RedirectResponse (url = "/login" )
181+ response .delete_cookie (self .api_key_name , domain = app .settings .AUTHENTICATION_DOMAIN )
182+ return response
183+
184+ # I'm hand rolling these for now...
185+ @public_router .get ("/login" , response_class = HTMLResponse , include_in_schema = False )
186+ async def get_login_page (token : str = "" , request : Request = None ):
187+ if token and token != "" :
188+ return RedirectResponse (url = f"{ self ._api_str } /auth/login?token={ token } " )
189+ return app .templates .TemplateResponse (
190+ "login.html.j2" ,
191+ {"request" : request , "api_key_name" : self .api_key_name },
192+ )
193+
194+ @public_router .get ("/logout" , response_class = HTMLResponse , include_in_schema = False )
195+ async def get_logout_page (request : Request = None ):
196+ return app .templates .TemplateResponse ("logout.html.j2" , {"request" : request })
197+
198+
98199 @app .app .exception_handler (403 )
99200 async def custom_403_handler (request : Request = None , * args ):
100201 if "/api" in request .url .path :
@@ -110,7 +211,7 @@ async def custom_403_handler(request: Request = None, *args):
110211 "login.html.j2" ,
111212 {
112213 "request" : request ,
113- "api_key_name" : self ._api_key_name ,
214+ "api_key_name" : self .api_key_name ,
114215 "status_code" : 403 ,
115216 "detail" : self .unauthorized_status_message ,
116217 },
0 commit comments