Skip to content

Commit

Permalink
Merge pull request #218 from Azure/vnet-missing-fix
Browse files Browse the repository at this point in the history
Token Validation + Fixes
  • Loading branch information
DCMattyG authored Dec 22, 2023
2 parents 470bddc + cb52117 commit 64ef2d0
Show file tree
Hide file tree
Showing 12 changed files with 1,220 additions and 956 deletions.
108 changes: 95 additions & 13 deletions engine/app/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,111 @@
from fastapi import Request, HTTPException

from requests import Session, adapters
from urllib3.util.retry import Retry
from cryptography.hazmat.primitives import serialization

import jwt
import time
import copy
import json

from app.routers.common.helper import (
cosmos_query
)

async def check_token_expired(request: Request):
now = int(time.time()) + 10
auth = request.headers.get('authorization')
from app.globals import globals

_session = None

async def fetch_jwks_keys():
global _session

if _session is None:
_session = Session()

retries = Retry(
total=5,
backoff_factor=0.1,
status_forcelist=[ 500, 502, 503, 504 ]
)

_session.mount('https://', adapters.HTTPAdapter(max_retries=retries))
_session.mount('http://', adapters.HTTPAdapter(max_retries=retries))

key_url = "https://" + globals.AUTHORITY_HOST + "/" + globals.TENANT_ID + "/discovery/v2.0/keys"

jwks = _session.get(key_url).json()

return jwks

async def get_token_auth_header(request: Request):
auth = request.headers.get("Authorization", None)

if not auth:
raise HTTPException(status_code=401, detail="Authorization header missing.")
raise HTTPException(status_code=401, detail="Authorization header is missing.")

parts = auth.split()

if parts[0].lower() != "bearer":
raise HTTPException(status_code=401, detail="Authorization header must start with 'Bearer'.")
elif len(parts) == 1:
raise HTTPException(status_code=401, detail="Token not found.")
elif len(parts) > 2:
raise HTTPException(status_code=401, detail="Authorization header must be of type Bearer token.")

token = parts[1]

user_assertion=auth.split(' ')[1]
return token

async def validate_token(request: Request):
try:
decoded = jwt.decode(user_assertion, options={"verify_signature": False})
except:
raise HTTPException(status_code=401, detail="Authorization token missing or invalid in header.")
token = await get_token_auth_header(request)
jwks = await fetch_jwks_keys()
unverified_header = jwt.get_unverified_header(token)

if(now >= int(decoded['exp'])):
raise HTTPException(status_code=401, detail="Token has expired.")
rsa_key = {}

request.state.tenant_id = decoded['tid']
for key in jwks["keys"]:
if key["kid"] == unverified_header["kid"]:
rsa_key = {
"kty": key["kty"],
"kid": key["kid"],
"use": key["use"],
"n": key["n"],
"e": key["e"]
}
except Exception:
raise HTTPException(status_code=401, detail="Unable to parse authorization token.")

await check_admin(request, decoded['oid'], decoded['tid'])
if rsa_key:
rsa_pem_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(rsa_key))
rsa_pem_key_bytes = rsa_pem_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
)

try:
payload = jwt.decode(
token,
key=rsa_pem_key_bytes,
verify=True,
algorithms=["RS256"],
audience=globals.CLIENT_ID,
issuer="https://" + globals.AUTHORITY_HOST + "/" + globals.TENANT_ID + "/v2.0"
)
except jwt.ExpiredSignatureError:
raise HTTPException(status_code=401, detail="Token has expired.")
except jwt.MissingRequiredClaimError:
raise HTTPException(status_code=401, detail="Incorrect token claims, please check the audience and issuer.")
except jwt.InvalidSignatureError:
raise HTTPException(status_code=401, detail="Invalid token signature.")
except Exception:
raise HTTPException(status_code=401, detail="Unable to parse authorization token.")
else:
raise HTTPException(status_code=401, detail="Unable to find appropriate signing key.")

request.state.tenant_id = payload['tid']

return payload

async def check_admin(request: Request, user_oid: str, user_tid: str):
admin_query = await cosmos_query("SELECT * FROM c WHERE c.type = 'admin'", user_tid)
Expand All @@ -44,6 +122,10 @@ async def check_admin(request: Request, user_oid: str, user_tid: str):

request.state.admin = True if is_admin else False

async def api_auth_checks(request: Request):
token_payload = await validate_token(request)
await check_admin(request, token_payload['oid'], token_payload['tid'])

async def get_admin(request: Request):
return request.state.admin

Expand Down
4 changes: 2 additions & 2 deletions engine/app/routers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import uuid

from app.dependencies import (
check_token_expired,
api_auth_checks,
get_admin,
get_tenant_id
)
Expand All @@ -36,7 +36,7 @@
router = APIRouter(
prefix="/admin",
tags=["admin"],
dependencies=[Depends(check_token_expired)]
dependencies=[Depends(api_auth_checks)]
)

async def new_admin_db(admin_list, exclusion_list, tenant_id):
Expand Down
4 changes: 2 additions & 2 deletions engine/app/routers/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from netaddr import IPSet, IPNetwork

from app.dependencies import (
check_token_expired,
api_auth_checks,
get_admin,
get_tenant_id
)
Expand All @@ -45,7 +45,7 @@
router = APIRouter(
prefix="/azure",
tags=["azure"],
dependencies=[Depends(check_token_expired)]
dependencies=[Depends(api_auth_checks)]
)

def str_to_list(input):
Expand Down
4 changes: 2 additions & 2 deletions engine/app/routers/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from netaddr import IPNetwork

from app.dependencies import (
check_token_expired,
api_auth_checks,
get_admin,
get_tenant_id
)
Expand All @@ -35,7 +35,7 @@
router = APIRouter(
prefix="/internal",
tags=["internal"],
dependencies=[Depends(check_token_expired)]
dependencies=[Depends(api_auth_checks)]
)

async def multi_helper(func, list, *args):
Expand Down
4 changes: 2 additions & 2 deletions engine/app/routers/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from netaddr import IPSet, IPNetwork

from app.dependencies import (
check_token_expired,
api_auth_checks,
get_admin,
get_tenant_id
)
Expand Down Expand Up @@ -54,7 +54,7 @@
router = APIRouter(
prefix="/spaces",
tags=["spaces"],
dependencies=[Depends(check_token_expired)]
dependencies=[Depends(api_auth_checks)]
)

async def valid_space_name_update(name, space_name, tenant_id):
Expand Down
4 changes: 2 additions & 2 deletions engine/app/routers/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from netaddr import IPSet, IPNetwork

from app.dependencies import (
check_token_expired,
api_auth_checks,
get_tenant_id
)

Expand All @@ -33,7 +33,7 @@
router = APIRouter(
prefix="/tools",
tags=["tools"],
dependencies=[Depends(check_token_expired)]
dependencies=[Depends(api_auth_checks)]
)

@router.post(
Expand Down
4 changes: 2 additions & 2 deletions engine/app/routers/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Union, List

from app.dependencies import (
check_token_expired,
api_auth_checks,
get_admin,
get_tenant_id
)
Expand All @@ -39,7 +39,7 @@
router = APIRouter(
prefix="/users",
tags=["users"],
dependencies=[Depends(check_token_expired)]
dependencies=[Depends(api_auth_checks)]
)

async def new_user(user_id, tenant_id):
Expand Down
Loading

0 comments on commit 64ef2d0

Please sign in to comment.