Skip to content

Commit

Permalink
Merge pull request #60 from dmtzs/development
Browse files Browse the repository at this point in the history
Update flask_authgen_jwt.py with new functionalities and error handlers
  • Loading branch information
dmtzs authored Jan 1, 2023
2 parents 194796b + c129279 commit 8e88c25
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions src/flask_authgen_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
class Core():
basic_auth_callback: Callable[[str, str], bool] = None
enc_dec_jwt_callback: dict[str, Union[bytes, str]] = None
get_user_roles_callback: list = None
get_user_roles_callback: list[str] = None
personal_credentials: tuple[str, str] = None

def enc_dec_jwt_config(self, func: Callable[[None], dict[str, Union[bytes, str]]]) -> Callable[[None], dict[str, Union[bytes, str]]]:
Expand Down Expand Up @@ -103,9 +103,10 @@ def ensure_sync(self, func: Callable) -> Callable:
return func

class GenJwt(Core):
def __init__(self, rsa_encrypt: bool = False) -> None:
def __init__(self, rsa_encrypt: bool = False, json_body_token: bool = False) -> None:
self.jwt_fields_attr: dict[str, datetime] = None
self.rsa_encrypt: bool = rsa_encrypt
self.json_body_token: bool = json_body_token

def __create_jwt_payload(self, bauth_credentials: dict[str, str]) -> dict[str, Union[str, datetime]]:
"""
Expand All @@ -114,6 +115,11 @@ def __create_jwt_payload(self, bauth_credentials: dict[str, str]) -> dict[str, U
"""
if not self.jwt_fields_attr:
self.gen_abort_error("jwt_claims decorator and function is not defined", 500)
if self.json_body_token:
if not request.is_json:
self.gen_abort_error("Missing JSON in request or not JSON format sent to endpoint", 400)
else:
bauth_credentials = request.get_json()
if self.personal_credentials is not None:
bauth_credentials[self.personal_credentials[0]] = bauth_credentials["username"]
bauth_credentials[self.personal_credentials[1]] = bauth_credentials["password"]
Expand All @@ -129,6 +135,8 @@ def __dec_set_basic_auth(self) -> Optional[bool]:
Method to decode and verify the basic auth credentials in the expected format
"""
auth_header = request.headers.get("Authorization")
if auth_header is None:
self.gen_abort_error("Authorization header is missing", 400)
auth_header = auth_header.split(" ")
if auth_header[0] != "Basic":
self.gen_abort_error("Authorization header must be Basic", 400)
Expand Down

0 comments on commit 8e88c25

Please sign in to comment.