Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Token Validation + Fixes #218

Merged
merged 4 commits into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 95 additions & 13 deletions engine/app/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,111 @@
from fastapi import Request, HTTPException

from requests import Session, adapters
from urllib3.util.retry import Retry
from cryptography.hazmat.primitives import serialization

import jwt
import time
import copy
import json

from app.routers.common.helper import (
cosmos_query
)

async def check_token_expired(request: Request):
now = int(time.time()) + 10
auth = request.headers.get('authorization')
from app.globals import globals

_session = None

async def fetch_jwks_keys():
global _session

if _session is None:
_session = Session()

retries = Retry(
total=5,
backoff_factor=0.1,
status_forcelist=[ 500, 502, 503, 504 ]
)

_session.mount('https://', adapters.HTTPAdapter(max_retries=retries))
_session.mount('http://', adapters.HTTPAdapter(max_retries=retries))

key_url = "https://" + globals.AUTHORITY_HOST + "/" + globals.TENANT_ID + "/discovery/v2.0/keys"

jwks = _session.get(key_url).json()

return jwks

async def get_token_auth_header(request: Request):
auth = request.headers.get("Authorization", None)

if not auth:
raise HTTPException(status_code=401, detail="Authorization header missing.")
raise HTTPException(status_code=401, detail="Authorization header is missing.")

parts = auth.split()

if parts[0].lower() != "bearer":
raise HTTPException(status_code=401, detail="Authorization header must start with 'Bearer'.")
elif len(parts) == 1:
raise HTTPException(status_code=401, detail="Token not found.")
elif len(parts) > 2:
raise HTTPException(status_code=401, detail="Authorization header must be of type Bearer token.")

token = parts[1]

user_assertion=auth.split(' ')[1]
return token

async def validate_token(request: Request):
try:
decoded = jwt.decode(user_assertion, options={"verify_signature": False})
except:
raise HTTPException(status_code=401, detail="Authorization token missing or invalid in header.")
token = await get_token_auth_header(request)
jwks = await fetch_jwks_keys()
unverified_header = jwt.get_unverified_header(token)

if(now >= int(decoded['exp'])):
raise HTTPException(status_code=401, detail="Token has expired.")
rsa_key = {}

request.state.tenant_id = decoded['tid']
for key in jwks["keys"]:
if key["kid"] == unverified_header["kid"]:
rsa_key = {
"kty": key["kty"],
"kid": key["kid"],
"use": key["use"],
"n": key["n"],
"e": key["e"]
}
except Exception:
raise HTTPException(status_code=401, detail="Unable to parse authorization token.")

await check_admin(request, decoded['oid'], decoded['tid'])
if rsa_key:
rsa_pem_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(rsa_key))
rsa_pem_key_bytes = rsa_pem_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
)

try:
payload = jwt.decode(
token,
key=rsa_pem_key_bytes,
verify=True,
algorithms=["RS256"],
audience=globals.CLIENT_ID,
issuer="https://" + globals.AUTHORITY_HOST + "/" + globals.TENANT_ID + "/v2.0"
)
except jwt.ExpiredSignatureError:
raise HTTPException(status_code=401, detail="Token has expired.")
except jwt.MissingRequiredClaimError:
raise HTTPException(status_code=401, detail="Incorrect token claims, please check the audience and issuer.")
except jwt.InvalidSignatureError:
raise HTTPException(status_code=401, detail="Invalid token signature.")
except Exception:
raise HTTPException(status_code=401, detail="Unable to parse authorization token.")
else:
raise HTTPException(status_code=401, detail="Unable to find appropriate signing key.")

request.state.tenant_id = payload['tid']

return payload

async def check_admin(request: Request, user_oid: str, user_tid: str):
admin_query = await cosmos_query("SELECT * FROM c WHERE c.type = 'admin'", user_tid)
Expand All @@ -44,6 +122,10 @@ async def check_admin(request: Request, user_oid: str, user_tid: str):

request.state.admin = True if is_admin else False

async def api_auth_checks(request: Request):
token_payload = await validate_token(request)
await check_admin(request, token_payload['oid'], token_payload['tid'])

async def get_admin(request: Request):
return request.state.admin

Expand Down
4 changes: 2 additions & 2 deletions engine/app/routers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import uuid

from app.dependencies import (
check_token_expired,
api_auth_checks,
get_admin,
get_tenant_id
)
Expand All @@ -36,7 +36,7 @@
router = APIRouter(
prefix="/admin",
tags=["admin"],
dependencies=[Depends(check_token_expired)]
dependencies=[Depends(api_auth_checks)]
)

async def new_admin_db(admin_list, exclusion_list, tenant_id):
Expand Down
4 changes: 2 additions & 2 deletions engine/app/routers/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from netaddr import IPSet, IPNetwork

from app.dependencies import (
check_token_expired,
api_auth_checks,
get_admin,
get_tenant_id
)
Expand All @@ -45,7 +45,7 @@
router = APIRouter(
prefix="/azure",
tags=["azure"],
dependencies=[Depends(check_token_expired)]
dependencies=[Depends(api_auth_checks)]
)

def str_to_list(input):
Expand Down
4 changes: 2 additions & 2 deletions engine/app/routers/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from netaddr import IPNetwork

from app.dependencies import (
check_token_expired,
api_auth_checks,
get_admin,
get_tenant_id
)
Expand All @@ -35,7 +35,7 @@
router = APIRouter(
prefix="/internal",
tags=["internal"],
dependencies=[Depends(check_token_expired)]
dependencies=[Depends(api_auth_checks)]
)

async def multi_helper(func, list, *args):
Expand Down
4 changes: 2 additions & 2 deletions engine/app/routers/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from netaddr import IPSet, IPNetwork

from app.dependencies import (
check_token_expired,
api_auth_checks,
get_admin,
get_tenant_id
)
Expand Down Expand Up @@ -54,7 +54,7 @@
router = APIRouter(
prefix="/spaces",
tags=["spaces"],
dependencies=[Depends(check_token_expired)]
dependencies=[Depends(api_auth_checks)]
)

async def valid_space_name_update(name, space_name, tenant_id):
Expand Down
4 changes: 2 additions & 2 deletions engine/app/routers/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from netaddr import IPSet, IPNetwork

from app.dependencies import (
check_token_expired,
api_auth_checks,
get_tenant_id
)

Expand All @@ -33,7 +33,7 @@
router = APIRouter(
prefix="/tools",
tags=["tools"],
dependencies=[Depends(check_token_expired)]
dependencies=[Depends(api_auth_checks)]
)

@router.post(
Expand Down
4 changes: 2 additions & 2 deletions engine/app/routers/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Union, List

from app.dependencies import (
check_token_expired,
api_auth_checks,
get_admin,
get_tenant_id
)
Expand All @@ -39,7 +39,7 @@
router = APIRouter(
prefix="/users",
tags=["users"],
dependencies=[Depends(check_token_expired)]
dependencies=[Depends(api_auth_checks)]
)

async def new_user(user_id, tenant_id):
Expand Down
Loading