-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into feat/add-translation
Showing
32 changed files
with
1,017 additions
and
412 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
import os | ||
from typing import Literal, Optional, cast | ||
|
||
from fastapi import Request, Response | ||
from fastapi.exceptions import HTTPException | ||
from fastapi.security.base import SecurityBase | ||
from fastapi.security.utils import get_authorization_scheme_param | ||
from starlette.status import HTTP_401_UNAUTHORIZED | ||
|
||
""" Module level cookie settings. """ | ||
_cookie_samesite = cast( | ||
Literal["lax", "strict", "none"], | ||
os.environ.get("CHAINLIT_COOKIE_SAMESITE", "lax"), | ||
) | ||
|
||
assert ( | ||
_cookie_samesite | ||
in [ | ||
"lax", | ||
"strict", | ||
"none", | ||
] | ||
), "Invalid value for CHAINLIT_COOKIE_SAMESITE. Must be one of 'lax', 'strict' or 'none'." | ||
_cookie_secure = _cookie_samesite == "none" | ||
|
||
_auth_cookie_lifetime = 60 * 60 # 1 hour | ||
_state_cookie_lifetime = 3 * 60 # 3m | ||
_auth_cookie_name = "access_token" | ||
_state_cookie_name = "oauth_state" | ||
|
||
|
||
class OAuth2PasswordBearerWithCookie(SecurityBase): | ||
""" | ||
OAuth2 password flow with cookie support with fallback to bearer token. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
tokenUrl: str, | ||
scheme_name: Optional[str] = None, | ||
auto_error: bool = True, | ||
): | ||
self.tokenUrl = tokenUrl | ||
self.scheme_name = scheme_name or self.__class__.__name__ | ||
self.auto_error = auto_error | ||
|
||
async def __call__(self, request: Request) -> Optional[str]: | ||
# First try to get the token from the cookie | ||
token = request.cookies.get(_auth_cookie_name) | ||
|
||
# If no cookie, try the Authorization header as fallback | ||
if not token: | ||
# TODO: Only bother to check if cookie auth is explicitly disabled. | ||
authorization = request.headers.get("Authorization") | ||
if authorization: | ||
scheme, token = get_authorization_scheme_param(authorization) | ||
if scheme.lower() != "bearer": | ||
if self.auto_error: | ||
raise HTTPException( | ||
status_code=HTTP_401_UNAUTHORIZED, | ||
detail="Invalid authentication credentials", | ||
headers={"WWW-Authenticate": "Bearer"}, | ||
) | ||
else: | ||
return None | ||
else: | ||
if self.auto_error: | ||
raise HTTPException( | ||
status_code=HTTP_401_UNAUTHORIZED, | ||
detail="Not authenticated", | ||
headers={"WWW-Authenticate": "Bearer"}, | ||
) | ||
else: | ||
return None | ||
|
||
return token | ||
|
||
|
||
def set_auth_cookie(response: Response, token: str): | ||
""" | ||
Helper function to set the authentication cookie with secure parameters | ||
""" | ||
|
||
response.set_cookie( | ||
key=_auth_cookie_name, | ||
value=token, | ||
httponly=True, | ||
secure=_cookie_secure, | ||
samesite=_cookie_samesite, | ||
max_age=_auth_cookie_lifetime, | ||
path="/", # Why is path set here and not below? | ||
) | ||
|
||
|
||
def clear_auth_cookie(response: Response): | ||
""" | ||
Helper function to clear the authentication cookie | ||
""" | ||
response.delete_cookie(key=_auth_cookie_name, path="/") | ||
|
||
|
||
def set_oauth_state_cookie(response: Response, token: str): | ||
response.set_cookie( | ||
_state_cookie_name, | ||
token, | ||
httponly=True, | ||
samesite=_cookie_samesite, | ||
secure=_cookie_secure, | ||
max_age=_state_cookie_lifetime, | ||
) | ||
|
||
|
||
def validate_oauth_state_cookie(request: Request, state: str): | ||
"""Check the state from the oauth provider against the browser cookie.""" | ||
|
||
oauth_state = request.cookies.get(_state_cookie_name) | ||
|
||
if oauth_state != state: | ||
raise Exception("oauth state does not correspond") | ||
|
||
|
||
def clear_oauth_state_cookie(response: Response): | ||
"""Oauth complete, delete state token.""" | ||
response.delete_cookie(_state_cookie_name) # Do we set path here? |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import datetime | ||
import os | ||
from typing import Any, Dict, Optional | ||
|
||
import jwt as pyjwt | ||
|
||
from chainlit.config import config | ||
from chainlit.user import User | ||
|
||
|
||
def get_jwt_secret() -> Optional[str]: | ||
return os.environ.get("CHAINLIT_AUTH_SECRET") | ||
|
||
|
||
def create_jwt(data: User) -> str: | ||
to_encode: Dict[str, Any] = data.to_dict() | ||
to_encode.update( | ||
{ | ||
"exp": datetime.datetime.utcnow() | ||
+ datetime.timedelta(seconds=config.project.user_session_timeout), | ||
} | ||
) | ||
secret = get_jwt_secret() | ||
assert secret | ||
encoded_jwt = pyjwt.encode(to_encode, secret, algorithm="HS256") | ||
return encoded_jwt | ||
|
||
|
||
def decode_jwt(token: str) -> User: | ||
dict = pyjwt.decode( | ||
token, | ||
get_jwt_secret(), | ||
algorithms=["HS256"], | ||
options={"verify_signature": True}, | ||
) | ||
del dict["exp"] | ||
return User(**dict) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.