From cc79b79d99b21327034aeceb9b873dfb565cacc3 Mon Sep 17 00:00:00 2001 From: Tochukwu Date: Fri, 17 Mar 2023 06:42:31 +0100 Subject: [PATCH] More customisation control in controller schemas (#26) * refactored schema and gave them a more proper name * Added controller schemas to settings. For more dynamic support * Added schema_control helper functions and apply schemas to controller * removed ninja_schema dependency * removed ninja_schema dependency * Added some test for custom controller schema * fixed failing test * linting * added doc and some refactoring * added more test * check_authentication_fix --- .pre-commit-config.yaml | 51 ++- Makefile | 8 +- README.md | 1 - docs/customizing_token_claims.md | 75 ++++- docs/development_and_contributing.md | 21 +- docs/getting_started.md | 4 +- docs/settings.md | 12 +- mkdocs.yml | 7 +- ninja_jwt/__init__.py | 2 +- ninja_jwt/controller.py | 112 ++++--- ninja_jwt/schema.py | 206 ++++++++---- ninja_jwt/schema_control.py | 94 ++++++ ninja_jwt/settings.py | 21 +- ninja_jwt/tokens.py | 6 +- pyproject.toml | 6 +- tests/test_custom_schema.py | 459 +++++++++++++++++++++++++++ tests/test_token_blacklist.py | 6 +- 17 files changed, 930 insertions(+), 161 deletions(-) create mode 100644 ninja_jwt/schema_control.py create mode 100644 tests/test_custom_schema.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5675c2637..0fbadcd76 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,37 +1,37 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: 'v4.3.0' + rev: v2.3.0 hooks: - id: check-merge-conflict - repo: https://github.com/asottile/yesqa rev: v1.3.0 hooks: - id: yesqa -- repo: https://github.com/pycqa/isort - rev: '5.10.1' +- repo: local hooks: - - id: isort - args: ["--profile", "black"] -- repo: https://github.com/psf/black - rev: '22.6.0' - hooks: - - id: black - language_version: python3 # Should be a command that runs python3.6+ + - id: code_formatting + args: [] + name: Code Formatting + entry: "make fmt" + types: [python] + language_version: python3.8 + language: python + - id: code_linting + args: [ ] + name: Code Linting + entry: "make lint" + types: [ python ] + language_version: python3.8 + language: python - repo: https://github.com/pre-commit/pre-commit-hooks - rev: 'v4.3.0' + rev: v2.3.0 hooks: - id: end-of-file-fixer exclude: >- - ^docs/[^/]*\.svg$ + ^examples/[^/]*\.svg$ - id: requirements-txt-fixer - id: trailing-whitespace types: [python] - - id: file-contents-sorter - files: | - CONTRIBUTORS.txt| - docs/spelling_wordlist.txt| - .gitignore| - .gitattributes - id: check-case-conflict - id: check-json - id: check-xml @@ -43,19 +43,4 @@ repos: - id: check-added-large-files - id: check-symlinks - id: debug-statements - - id: detect-aws-credentials - args: ['--allow-missing-credentials'] - - id: detect-private-key exclude: ^tests/ -- repo: https://github.com/asottile/pyupgrade - rev: 'v2.37.1' - hooks: - - id: pyupgrade - args: ['--py37-plus', '--keep-mock'] - -- repo: https://github.com/Lucas-C/pre-commit-hooks-markup - rev: v1.0.1 - hooks: - - id: rst-linter - files: >- - ^[^/]+[.]rst$ diff --git a/Makefile b/Makefile index 780a9b046..a6af208d6 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,9 @@ clean: ## Removing cached python compiled files find . -name __pycache__ | xargs rm -rfv install: ## Install dependencies + make clean flit install --deps develop --symlink + pre-commit install -f lint: ## Run code linters make clean @@ -35,4 +37,8 @@ test-cov: ## Run tests with coverage doc-deploy: ## Run Deploy Documentation make clean - mkdocs gh-deploy --force \ No newline at end of file + mkdocs gh-deploy --force + +doc-serve: ## Run Deploy Documentation + make clean + mkdocs serve diff --git a/README.md b/README.md index 72e480834..54947b878 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,6 @@ For full documentation, [visit](https://eadwincode.github.io/django-ninja-jwt/). - Python >= 3.6 - Django >= 2.1 - Django-Ninja >= 0.16.1 -- Ninja-Schema >= 0.12.8 - Django-Ninja-Extra >= 0.14.2 ## Example diff --git a/docs/customizing_token_claims.md b/docs/customizing_token_claims.md index 23c4f938b..5290615cd 100644 --- a/docs/customizing_token_claims.md +++ b/docs/customizing_token_claims.md @@ -5,13 +5,14 @@ views, create a subclass for the desired controller as well as a subclass for its corresponding serializer. Here\'s an example : !!! info - if you are interested in Asynchronous version of the class, checkout `AsyncNinjaJWTDefaultController` and `AsyncNinjaJWTSlidingController` + if you are interested in Asynchronous version of the class, use `AsyncNinjaJWTDefaultController` and `AsyncNinjaJWTSlidingController`. + Also note, it's only available for Django versions that supports asynchronous actions. ```python -from ninja_jwt.schema import TokenObtainPairSerializer +from ninja_jwt.schema import TokenObtainPairInputSchema from ninja_jwt.controller import TokenObtainPairController from ninja_extra import api_controller, route -from ninja_schema import Schema +from ninja import Schema class UserSchema(Schema): @@ -25,12 +26,13 @@ class MyTokenObtainPairOutSchema(Schema): user: UserSchema -class MyTokenObtainPairSchema(TokenObtainPairSerializer): +class MyTokenObtainPairSchema(TokenObtainPairInputSchema): def output_schema(self): out_dict = self.dict(exclude={"password"}) out_dict.update(user=UserSchema.from_orm(self._user)) return MyTokenObtainPairOutSchema(**out_dict) + @api_controller('/token', tags=['Auth']) class MyTokenObtainPairController(TokenObtainPairController): @route.post( @@ -49,7 +51,6 @@ Here is an example ```python from ninja import router -from ninja_schema import Schema router = router('/token') @@ -67,3 +68,67 @@ from ninja import NinjaAPI api = NinjaAPI() api.add_router('', tags=['Auth'], router=router) ``` + + +### Controller Schema Swapping + +You can now swap controller schema in `NINJA_JWT` settings without having to inherit or override Ninja JWT controller function. + +All controller input schema must inherit from `ninja_jwt.schema.InputSchemaMixin` and token generating schema should inherit +from `ninja_jwt.schema.TokenObtainInputSchemaBase` or `ninja_jwt.schema.TokenInputSchemaMixin` if you want to have more control. + +Using the example above: + +```python +# project/schema.py +from typing import Type, Dict +from ninja_jwt.schema import TokenObtainInputSchemaBase +from ninja import Schema +from ninja_jwt.tokens import RefreshToken + +class UserSchema(Schema): + first_name: str + email: str + + +class MyTokenObtainPairOutSchema(Schema): + refresh: str + access: str + user: UserSchema + + +class MyTokenObtainPairInputSchema(TokenObtainInputSchemaBase): + @classmethod + def get_response_schema(cls) -> Type[Schema]: + return MyTokenObtainPairOutSchema + + @classmethod + def get_token(cls, user) -> Dict: + values = {} + refresh = RefreshToken.for_user(user) + values["refresh"] = str(refresh) + values["access"] = str(refresh.access_token) + values.update(user=UserSchema.from_orm(user)) # this will be needed when creating output schema + return values +``` + +In the `MyTokenObtainPairInputSchema` we override `get_token` to define our token and some data needed for our output schema. +We also override `get_response_schema` to define our output schema `MyTokenObtainPairOutSchema`. + +Next, we apply the `MyTokenObtainPairInputSchema` schema to controller. This is simply done in `NINJA_JWT` settings. + +```python +# project/settings.py + +NINJA_JWT = { + 'TOKEN_OBTAIN_PAIR_INPUT_SCHEMA': 'project.schema.MyTokenObtainPairInputSchema', +} +``` +Other swappable schemas can be found in [settings](../settings) + +![token_customization_git](./img/token_customize.gif) + +!!! Note + `Controller Schema Swapping` is only available from **v5.2.4** + + diff --git a/docs/development_and_contributing.md b/docs/development_and_contributing.md index 2168620b7..1b7d6a34d 100644 --- a/docs/development_and_contributing.md +++ b/docs/development_and_contributing.md @@ -3,17 +3,26 @@ To do development work for Ninja JWT, make your own fork on Github, clone it locally, make and activate a virtualenv for it, then from within the project directory: -``` {.sourceCode .bash} -make install +After that, install flit + +```shell +$(venv) pip install flit +``` + +Install development libraries and pre-commit hooks for code linting and styles + +```shell +$(venv) make install ``` To run the tests: -``` {.sourceCode .bash} -make test +```shell +$(venv) make test ``` + To run the tests with coverage: -``` {.sourceCode .bash} -make test-cov +```shell +$(venv) make test-cov ``` diff --git a/docs/getting_started.md b/docs/getting_started.md index 8abacfb71..8ea72656b 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -3,7 +3,6 @@ - Python >= 3.6 - Django >= 2.1 - Django-Ninja >= 0.16.1 -- Ninja-Schema >= 0.12.2 - Django-Ninja-Extra >= 0.11.0 These are the officially supported python and package versions. Other @@ -12,7 +11,6 @@ see what is possible. Installation ============ - Ninja JWT can be installed with pip: pip install django-ninja-jwt @@ -112,4 +110,4 @@ extra in the `django-ninja-jwt` requirement: The `django-ninja-jwt[crypto]` format is recommended in requirements files in projects using `Ninja JWT`, as a separate `cryptography` requirement line may later be mistaken for an unused requirement and removed. -[cryptography](https://cryptography.io) \ No newline at end of file +[cryptography](https://cryptography.io) diff --git a/docs/settings.md b/docs/settings.md index 436ad2f26..6bff056f0 100644 --- a/docs/settings.md +++ b/docs/settings.md @@ -37,6 +37,17 @@ NINJA_JWT = { 'SLIDING_TOKEN_REFRESH_EXP_CLAIM': 'refresh_exp', 'SLIDING_TOKEN_LIFETIME': timedelta(minutes=5), 'SLIDING_TOKEN_REFRESH_LIFETIME': timedelta(days=1), + + # For Controller Schemas + # FOR OBTAIN PAIR + 'TOKEN_OBTAIN_PAIR_INPUT_SCHEMA': "ninja_jwt.schema.TokenObtainPairInputSchema", + 'TOKEN_OBTAIN_PAIR_REFRESH_INPUT_SCHEMA': "ninja_jwt.schema.TokenRefreshInputSchema", + # FOR SLIDING TOKEN + 'TOKEN_OBTAIN_SLIDING_INPUT_SCHEMA': "ninja_jwt.schema.TokenObtainSlidingInputSchema", + 'TOKEN_OBTAIN_SLIDING_REFRESH_INPUT_SCHEMA':"ninja_jwt.schema.TokenRefreshSlidingInputSchema", + + 'TOKEN_BLACKLIST_INPUT_SCHEMA': "ninja_jwt.schema.TokenBlacklistInputSchema", + 'TOKEN_VERIFY_INPUT_SCHEMA': "ninja_jwt.schema.TokenVerifyInputSchema", } ``` @@ -247,4 +258,3 @@ More about this in the "Sliding tokens" section below. The claim name that is used to store the expiration time of a sliding token's refresh period. More about this in the "Sliding tokens" section below. - diff --git a/mkdocs.yml b/mkdocs.yml index 36d9feabe..d298a9efa 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -3,7 +3,7 @@ site_description: Django Ninja JWT - A Simple JWT plugin for Django-Ninja. site_url: https://eadwincode.github.io/django-ninja-jwt/ repo_name: eadwinCode/django-ninja-jwt repo_url: https://github.com/eadwinCode/django-ninja-jwt -edit_uri: '' +edit_uri: 'https://github.com/eadwinCode/django-ninja-jwt/docs' theme: name: material @@ -44,3 +44,8 @@ nav: - Development and Contributing: development_and_contributing.md - Experimental Feature: experimental_features.md #- ninja_jwt package: index.md + +markdown_extensions: +- codehilite +- admonition +- pymdownx.superfences diff --git a/ninja_jwt/__init__.py b/ninja_jwt/__init__.py index c0ebc17bb..91e7eecf6 100644 --- a/ninja_jwt/__init__.py +++ b/ninja_jwt/__init__.py @@ -1,3 +1,3 @@ """Django Ninja JWT - JSON Web Token for Django-Ninja""" -__version__ = "5.2.2" +__version__ = "5.2.4" diff --git a/ninja_jwt/controller.py b/ninja_jwt/controller.py index 00a195e3f..b2d453d72 100644 --- a/ninja_jwt/controller.py +++ b/ninja_jwt/controller.py @@ -1,9 +1,10 @@ import django +from ninja import Schema from ninja_extra import ControllerBase, api_controller, http_post from ninja_extra.permissions import AllowAny -from ninja_schema import Schema -from ninja_jwt import schema +from ninja_jwt.schema_control import SchemaControl +from ninja_jwt.settings import api_settings exports = [ "TokenVerificationController", @@ -29,37 +30,51 @@ __all__ = exports +schema = SchemaControl(api_settings) + + class TokenVerificationController(ControllerBase): auto_import = False - @http_post("/verify", response={200: Schema}, url_name="token_verify") - def verify_token(self, token: schema.TokenVerifySerializer): - return {} + @http_post( + "/verify", + response={200: Schema}, + url_name="token_verify", + ) + def verify_token(self, token: schema.verify_schema): + return token.to_response_schema() class TokenBlackListController(ControllerBase): auto_import = False - @http_post("/blacklist", response={200: Schema}, url_name="token_blacklist") - def blacklist_token(self, refresh: schema.TokenBlacklistSerializer): - return {} + @http_post( + "/blacklist", + response={200: Schema}, + url_name="token_blacklist", + ) + def blacklist_token(self, refresh: schema.blacklist_schema): + return refresh.to_response_schema() class TokenObtainPairController(ControllerBase): auto_import = False @http_post( - "/pair", response=schema.TokenObtainPairOutput, url_name="token_obtain_pair" + "/pair", + response=schema.obtain_pair_schema.get_response_schema(), + url_name="token_obtain_pair", ) - def obtain_token(self, user_token: schema.TokenObtainPairSerializer): - return user_token.output_schema() + def obtain_token(self, user_token: schema.obtain_pair_schema): + return user_token.to_response_schema() @http_post( - "/refresh", response=schema.TokenRefreshSerializer, url_name="token_refresh" + "/refresh", + response=schema.obtain_pair_refresh_schema.get_response_schema(), + url_name="token_refresh", ) - def refresh_token(self, refresh_token: schema.TokenRefreshSchema): - refresh = schema.TokenRefreshSerializer(**refresh_token.dict()) - return refresh + def refresh_token(self, refresh_token: schema.obtain_pair_refresh_schema): + return refresh_token.to_response_schema() class TokenObtainSlidingController(TokenObtainPairController): @@ -67,20 +82,19 @@ class TokenObtainSlidingController(TokenObtainPairController): @http_post( "/sliding", - response=schema.TokenObtainSlidingOutput, + response=schema.obtain_sliding_schema.get_response_schema(), url_name="token_obtain_sliding", ) - def obtain_token(self, user_token: schema.TokenObtainSlidingSerializer): - return user_token.output_schema() + def obtain_token(self, user_token: schema.obtain_sliding_schema): + return user_token.to_response_schema() @http_post( "/sliding/refresh", - response=schema.TokenRefreshSlidingSerializer, + response=schema.obtain_sliding_refresh_schema.get_response_schema(), url_name="token_refresh_sliding", ) - def refresh_token(self, refresh_token: schema.TokenRefreshSlidingSchema): - refresh = schema.TokenRefreshSlidingSerializer(**refresh_token.dict()) - return refresh + def refresh_token(self, refresh_token: schema.obtain_sliding_refresh_schema): + return refresh_token.to_response_schema() @api_controller("/token", permissions=[AllowAny], tags=["token"]) @@ -106,51 +120,61 @@ class NinjaJWTSlidingController( from asgiref.sync import sync_to_async class AsyncTokenVerificationController(TokenVerificationController): - @http_post("/verify", response={200: Schema}, url_name="token_verify") - async def verify_token(self, token: schema.TokenVerifySerializer): - return {} + @http_post( + "/verify", + response={200: Schema}, + url_name="token_verify", + ) + async def verify_token(self, token: schema.verify_schema): + return token.to_response_schema() class AsyncTokenBlackListController(TokenBlackListController): auto_import = False - @http_post("/blacklist", response={200: Schema}, url_name="token_blacklist") - async def blacklist_token(self, refresh: schema.TokenBlacklistSerializer): - return {} + @http_post( + "/blacklist", + response={200: Schema}, + url_name="token_blacklist", + ) + async def blacklist_token(self, refresh: schema.blacklist_schema): + return refresh.to_response_schema() class AsyncTokenObtainPairController(TokenObtainPairController): @http_post( - "/pair", response=schema.TokenObtainPairOutput, url_name="token_obtain_pair" + "/pair", + response=schema.obtain_pair_schema.get_response_schema(), + url_name="token_obtain_pair", ) - async def obtain_token(self, user_token: schema.TokenObtainPairSerializer): - return user_token.output_schema() + async def obtain_token(self, user_token: schema.obtain_pair_schema): + return user_token.to_response_schema() @http_post( - "/refresh", response=schema.TokenRefreshSerializer, url_name="token_refresh" + "/refresh", + response=schema.obtain_pair_refresh_schema.get_response_schema(), + url_name="token_refresh", ) - async def refresh_token(self, refresh_token: schema.TokenRefreshSchema): - refresh = await sync_to_async(schema.TokenRefreshSerializer)( - **refresh_token.dict() - ) + async def refresh_token(self, refresh_token: schema.obtain_pair_refresh_schema): + refresh = await sync_to_async(refresh_token.to_response_schema)() return refresh class AsyncTokenObtainSlidingController(TokenObtainSlidingController): @http_post( "/sliding", - response=schema.TokenObtainSlidingOutput, + response=schema.obtain_sliding_schema.get_response_schema(), url_name="token_obtain_sliding", ) - async def obtain_token(self, user_token: schema.TokenObtainSlidingSerializer): - return user_token.output_schema() + async def obtain_token(self, user_token: schema.obtain_sliding_schema): + return user_token.to_response_schema() @http_post( "/sliding/refresh", - response=schema.TokenRefreshSlidingSerializer, + response=schema.obtain_sliding_refresh_schema.get_response_schema(), url_name="token_refresh_sliding", ) - async def refresh_token(self, refresh_token: schema.TokenRefreshSlidingSchema): - refresh = await sync_to_async(schema.TokenRefreshSlidingSerializer)( - **refresh_token.dict() - ) + async def refresh_token( + self, refresh_token: schema.obtain_sliding_refresh_schema + ): + refresh = await sync_to_async(refresh_token.to_response_schema)() return refresh @api_controller("/token", permissions=[AllowAny], tags=["token"]) diff --git a/ninja_jwt/schema.py b/ninja_jwt/schema.py index 6e8ad72ab..bc65eaa1f 100644 --- a/ninja_jwt/schema.py +++ b/ninja_jwt/schema.py @@ -1,17 +1,18 @@ -from typing import Dict, Optional, Type, cast +import warnings +from typing import Any, Callable, Dict, Optional, Type, Union, cast from django.conf import settings from django.contrib.auth import authenticate, get_user_model from django.contrib.auth.models import AbstractUser, update_last_login from django.utils.translation import gettext_lazy as _ -from ninja_schema import ModelSchema, Schema +from ninja import ModelSchema, Schema from pydantic import root_validator +import ninja_jwt.exceptions as exceptions from ninja_jwt.utils import token_error -from . import exceptions from .settings import api_settings -from .tokens import RefreshToken, SlidingToken, Token, UntypedToken +from .tokens import RefreshToken, SlidingToken, UntypedToken if api_settings.BLACKLIST_AFTER_ROTATION: from .token_blacklist.models import BlacklistedToken @@ -22,22 +23,37 @@ class AuthUserSchema(ModelSchema): class Config: model = get_user_model() - include = [user_name_field] + model_fields = [user_name_field] -class TokenObtainSerializer(ModelSchema): - class Config: - model = get_user_model() - include = ["password", user_name_field] +class InputSchemaMixin: + dict: Callable + + @classmethod + def get_response_schema(cls) -> Type[Schema]: + raise NotImplementedError("Must implement `get_response_schema`") - _user: Optional[Type[AbstractUser]] = None + def to_response_schema(self): + _schema_type = self.get_response_schema() + return _schema_type(**self.dict()) + + +class TokenInputSchemaMixin(InputSchemaMixin): + + _user: Optional[AbstractUser] = None _default_error_messages = { "no_active_account": _("No active account found with the given credentials") } - @root_validator(pre=True) - def validate_inputs(cls, values: Dict) -> dict: + @classmethod + def check_user_authentication_rule( + cls, user: Optional[Union[AbstractUser, Any]], values: Dict + ) -> bool: + return api_settings.USER_AUTHENTICATION_RULE(user) + + @classmethod + def validate_values(cls, values: Dict) -> dict: if user_name_field not in values and "password" not in values: raise exceptions.ValidationError( { @@ -54,79 +70,108 @@ def validate_inputs(cls, values: Dict) -> dict: if not values.get("password"): raise exceptions.ValidationError({"password": "password is required"}) - cls._user = authenticate(**values) + _user = authenticate(**values) - if not api_settings.USER_AUTHENTICATION_RULE(cls._user): + if not cls.check_user_authentication_rule(_user, values): raise exceptions.AuthenticationFailed( cls._default_error_messages["no_active_account"] ) + cls._user = _user + return values - def output_schema(self) -> Type[Schema]: - raise NotImplementedError( - "Must implement `output_schema` method for `TokenObtainSerializer` subclasses" + def output_schema(self) -> Schema: + warnings.warn( + "output_schema() is deprecated in favor of " "to_response_schema()", + DeprecationWarning, + stacklevel=2, ) + return self.to_response_schema() @classmethod - def get_token(cls, user: Type[AbstractUser]) -> Type[Token]: + def get_token(cls, user: AbstractUser) -> Dict: raise NotImplementedError( "Must implement `get_token` method for `TokenObtainSerializer` subclasses" ) -class TokenObtainPairOutput(AuthUserSchema): - refresh: str - access: str +class TokenObtainInputSchemaBase(ModelSchema, TokenInputSchemaMixin): + class Config: + model = get_user_model() + model_fields = ["password", user_name_field] + @root_validator(pre=True) + def validate_inputs(cls, values: Dict) -> dict: + return cls.validate_values(values) + + @root_validator + def post_validate(cls, values: Dict) -> dict: + return cls.post_validate_schema(values) -class TokenObtainPairSerializer(TokenObtainSerializer): @classmethod - def get_token(cls, user: Type[AbstractUser]) -> Type[Token]: - return RefreshToken.for_user(user) + def post_validate_schema(cls, values: Dict) -> dict: + """ + This is a post validate process which is common for any token generating schema. + :param values: + :return: + """ + # get_token can return values that wants to apply to `OutputSchema` + data = cls.get_token(cls._user) - @root_validator - def validate_schema(cls, values: Dict) -> dict: - refresh = cls.get_token(cls._user) - refresh = cast(RefreshToken, refresh) + if not isinstance(data, dict): + raise Exception("`get_token` must return a `typing.Dict` type.") - values["refresh"] = str(refresh) - values["access"] = str(refresh.access_token) + values.update(data) if api_settings.UPDATE_LAST_LOGIN: update_last_login(None, cls._user) return values - def output_schema(self): - return TokenObtainPairOutput(**self.dict(exclude={"password"})) + def to_response_schema(self): + _schema_type = self.get_response_schema() + return _schema_type(**self.dict(exclude={"password"})) -class TokenObtainSlidingOutput(AuthUserSchema): - token: str +class TokenObtainPairOutputSchema(AuthUserSchema): + refresh: str + access: str -class TokenObtainSlidingSerializer(TokenObtainSerializer): +class TokenObtainPairInputSchema(TokenObtainInputSchemaBase): @classmethod - def get_token(cls, user: Type[AbstractUser]) -> Type[Token]: - return SlidingToken.for_user(user) + def get_response_schema(cls) -> Type[Schema]: + return TokenObtainPairOutputSchema - @root_validator - def validate_schema(cls, values: Dict) -> dict: - token = cls.get_token(cls._user) + @classmethod + def get_token(cls, user: AbstractUser) -> Dict: + values = {} + refresh = RefreshToken.for_user(user) + refresh = cast(RefreshToken, refresh) + values["refresh"] = str(refresh) + values["access"] = str(refresh.access_token) + return values - values["token"] = str(token) - if api_settings.UPDATE_LAST_LOGIN and cls._user: - update_last_login(cls, cls._user) +class TokenObtainSlidingOutputSchema(AuthUserSchema): + token: str - return values - def output_schema(self): - return TokenObtainSlidingOutput(**self.dict(exclude={"password"})) +class TokenObtainSlidingInputSchema(TokenObtainInputSchemaBase): + @classmethod + def get_response_schema(cls) -> Type: + return TokenObtainSlidingOutputSchema + + @classmethod + def get_token(cls, user: AbstractUser) -> Dict: + values = {} + slide_token = SlidingToken.for_user(user) + values["token"] = str(slide_token) + return values -class TokenRefreshSchema(Schema): +class TokenRefreshInputSchema(Schema, InputSchemaMixin): refresh: str @root_validator @@ -135,8 +180,12 @@ def validate_schema(cls, values: Dict) -> dict: raise exceptions.ValidationError({"refresh": "token is required"}) return values + @classmethod + def get_response_schema(cls) -> Type[Schema]: + return TokenRefreshOutputSchema + -class TokenRefreshSerializer(Schema): +class TokenRefreshOutputSchema(Schema): refresh: str access: Optional[str] @@ -165,11 +214,11 @@ def validate_schema(cls, values: Dict) -> dict: refresh.set_iat() data["refresh"] = str(refresh) - - return data + values.update(data) + return values -class TokenRefreshSlidingSchema(Schema): +class TokenRefreshSlidingInputSchema(Schema, InputSchemaMixin): token: str @root_validator @@ -178,8 +227,12 @@ def validate_schema(cls, values: Dict) -> dict: raise exceptions.ValidationError({"token": "token is required"}) return values + @classmethod + def get_response_schema(cls) -> Type[Schema]: + return TokenRefreshSlidingOutputSchema + -class TokenRefreshSlidingSerializer(Schema): +class TokenRefreshSlidingOutputSchema(Schema): token: str @root_validator @@ -197,11 +250,11 @@ def validate_schema(cls, values: Dict) -> dict: # Update the "exp" and "iat" claims token.set_exp() token.set_iat() - - return {"token": str(token)} + values.update({"token": str(token)}) + return values -class TokenVerifySerializer(Schema): +class TokenVerifyInputSchema(Schema, InputSchemaMixin): token: str @root_validator @@ -221,8 +274,15 @@ def validate_schema(cls, values: Dict) -> dict: return values + @classmethod + def get_response_schema(cls) -> Type[Schema]: + return Schema + + def to_response_schema(self): + return {} + -class TokenBlacklistSerializer(Schema): +class TokenBlacklistInputSchema(Schema, InputSchemaMixin): refresh: str @root_validator @@ -236,3 +296,37 @@ def validate_schema(cls, values: Dict) -> dict: except AttributeError: pass return values + + @classmethod + def get_response_schema(cls) -> Type[Schema]: + return Schema + + def to_response_schema(self): + return {} + + +__deprecated__ = { + "TokenBlacklistSerializer": TokenBlacklistInputSchema, + "TokenVerifySerializer": TokenVerifyInputSchema, + "TokenRefreshSlidingSerializer": TokenRefreshSlidingOutputSchema, + "TokenRefreshSlidingSchema": TokenRefreshSlidingInputSchema, + "TokenRefreshSerializer": TokenRefreshOutputSchema, + "TokenRefreshSchema": TokenRefreshInputSchema, + "TokenObtainSlidingOutput": TokenObtainSlidingOutputSchema, + "TokenObtainSerializer": TokenObtainInputSchemaBase, + "TokenObtainPairOutput": TokenObtainPairOutputSchema, + "TokenObtainPairSerializer": TokenObtainPairInputSchema, + "TokenObtainSlidingSerializer": TokenObtainSlidingInputSchema, +} + + +def __getattr__(name: str) -> Any: # pragma: no cover + if name in __deprecated__: + value = __deprecated__[name] + warnings.warn( + f"'{name}' is deprecated. Use '{value.__name__}' instead.", + category=DeprecationWarning, + stacklevel=2, + ) + return value + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") diff --git a/ninja_jwt/schema_control.py b/ninja_jwt/schema_control.py new file mode 100644 index 000000000..711d07ecb --- /dev/null +++ b/ninja_jwt/schema_control.py @@ -0,0 +1,94 @@ +from typing import TYPE_CHECKING, Type + +from django.utils.module_loading import import_string + +from ninja_jwt.schema import InputSchemaMixin, TokenInputSchemaMixin + +if TYPE_CHECKING: # pragma: no cover + from ninja_jwt.settings import NinjaJWTSettings + + +class SchemaControl: + """ + A Schema Helper Class that imports Schema from configurations + """ + + def __init__(self, api_settings: "NinjaJWTSettings") -> None: + self._verify_schema = import_string(api_settings.TOKEN_VERIFY_INPUT_SCHEMA) + self.validate_type( + self._verify_schema, InputSchemaMixin, "TOKEN_VERIFY_INPUT_SCHEMA" + ) + + self._blacklist_schema = import_string( + api_settings.TOKEN_BLACKLIST_INPUT_SCHEMA + ) + self.validate_type( + self._blacklist_schema, InputSchemaMixin, "TOKEN_BLACKLIST_INPUT_SCHEMA" + ) + + self._obtain_pair_schema = import_string( + api_settings.TOKEN_OBTAIN_PAIR_INPUT_SCHEMA + ) + self.validate_type( + self._obtain_pair_schema, + TokenInputSchemaMixin, + "TOKEN_OBTAIN_PAIR_INPUT_SCHEMA", + ) + + self._obtain_pair_refresh_schema = import_string( + api_settings.TOKEN_OBTAIN_PAIR_REFRESH_INPUT_SCHEMA + ) + + self.validate_type( + self._obtain_pair_refresh_schema, + InputSchemaMixin, + "TOKEN_OBTAIN_PAIR_REFRESH_INPUT_SCHEMA", + ) + + self._obtain_sliding_schema = import_string( + api_settings.TOKEN_OBTAIN_SLIDING_INPUT_SCHEMA + ) + self.validate_type( + self._obtain_sliding_schema, + TokenInputSchemaMixin, + "TOKEN_OBTAIN_SLIDING_INPUT_SCHEMA", + ) + + self._obtain_sliding_refresh_schema = import_string( + api_settings.TOKEN_OBTAIN_SLIDING_REFRESH_INPUT_SCHEMA + ) + self.validate_type( + self._obtain_sliding_refresh_schema, + InputSchemaMixin, + "TOKEN_OBTAIN_SLIDING_REFRESH_INPUT_SCHEMA", + ) + + def validate_type( + self, schema_type: Type, sub_class: Type, settings_key: str + ) -> None: + if not issubclass(schema_type, sub_class): + raise Exception(f"{settings_key} type must inherit from `{sub_class}`") + + @property + def verify_schema(self) -> "TokenInputSchemaMixin": + return self._verify_schema + + @property + def blacklist_schema(self) -> "TokenInputSchemaMixin": + return self._blacklist_schema + + @property + def obtain_pair_schema(self) -> "TokenInputSchemaMixin": + return self._obtain_pair_schema + + @property + def obtain_pair_refresh_schema(self) -> "TokenInputSchemaMixin": + return self._obtain_pair_refresh_schema + + @property + def obtain_sliding_schema(self) -> "TokenInputSchemaMixin": + return self._obtain_sliding_schema + + @property + def obtain_sliding_refresh_schema(self) -> "TokenInputSchemaMixin": + return self._obtain_sliding_refresh_schema diff --git a/ninja_jwt/settings.py b/ninja_jwt/settings.py index 3a9cd7b56..6883c3853 100644 --- a/ninja_jwt/settings.py +++ b/ninja_jwt/settings.py @@ -3,8 +3,8 @@ from django.conf import settings from django.test.signals import setting_changed +from ninja import Schema from ninja_extra.lazy import LazyStrImport -from ninja_schema import Schema from pydantic import AnyUrl, Field, root_validator @@ -64,6 +64,25 @@ class Config: SLIDING_TOKEN_LIFETIME: timedelta = Field(timedelta(minutes=5)) SLIDING_TOKEN_REFRESH_LIFETIME: timedelta = Field(timedelta(days=1)) + TOKEN_OBTAIN_PAIR_INPUT_SCHEMA: Any = Field( + "ninja_jwt.schema.TokenObtainPairInputSchema" + ) + TOKEN_OBTAIN_PAIR_REFRESH_INPUT_SCHEMA: Any = Field( + "ninja_jwt.schema.TokenRefreshInputSchema" + ) + + TOKEN_OBTAIN_SLIDING_INPUT_SCHEMA: Any = Field( + "ninja_jwt.schema.TokenObtainSlidingInputSchema" + ) + TOKEN_OBTAIN_SLIDING_REFRESH_INPUT_SCHEMA: Any = Field( + "ninja_jwt.schema.TokenRefreshSlidingInputSchema" + ) + + TOKEN_BLACKLIST_INPUT_SCHEMA: Any = Field( + "ninja_jwt.schema.TokenBlacklistInputSchema" + ) + TOKEN_VERIFY_INPUT_SCHEMA: Any = Field("ninja_jwt.schema.TokenVerifyInputSchema") + @root_validator def validate_ninja_jwt_settings(cls, values): for item in NinjaJWT_SETTINGS_DEFAULTS.keys(): diff --git a/ninja_jwt/tokens.py b/ninja_jwt/tokens.py index e870180e4..7617dafac 100644 --- a/ninja_jwt/tokens.py +++ b/ninja_jwt/tokens.py @@ -1,5 +1,5 @@ from datetime import datetime, timedelta -from typing import Any, Optional, Tuple, Type, Union +from typing import Any, Optional, Tuple from uuid import uuid4 from django.conf import settings @@ -181,7 +181,7 @@ def check_exp(self, claim: str = "exp", current_time: Optional[datetime] = None) raise TokenError(format_lazy(_("Token '{}' claim has expired"), claim)) @classmethod - def for_user(cls, user: Type[AbstractBaseUser]) -> Union["Token", Type["Token"]]: + def for_user(cls, user: AbstractBaseUser) -> "Token": """ Returns an authorization token for the given user that will be provided after authenticating the user's credentials. @@ -253,7 +253,7 @@ def blacklist(self) -> BlacklistedToken: return BlacklistedToken.objects.get_or_create(token=token) @classmethod - def for_user(cls, user: Type["AbstractBaseUser"]) -> Type[Token]: + def for_user(cls, user: "AbstractBaseUser") -> Token: """ Adds this token to the outstanding token list. """ diff --git a/pyproject.toml b/pyproject.toml index 0267200f1..1170f0093 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,6 @@ dependencies = [ "Django >= 2.1", "pyjwt>=1.7.1,<3", "pyjwt[crypto]", - "ninja-schema >= 0.12.8", "django-ninja-extra >= 0.14.2", ] @@ -80,6 +79,10 @@ crypto = [ "cryptography>=3.3.1", ] +dev = [ + "pre-commit" +] + doc = [ "mkdocs >=1.1.2,<2.0.0", "mkdocs-material", @@ -88,4 +91,3 @@ doc = [ "mkdocs-markdownextradata-plugin >=0.1.7,<0.3.0", "mkdocstrings" ] - diff --git a/tests/test_custom_schema.py b/tests/test_custom_schema.py new file mode 100644 index 000000000..6c67dff09 --- /dev/null +++ b/tests/test_custom_schema.py @@ -0,0 +1,459 @@ +import importlib +from datetime import timedelta +from typing import Dict, Type +from unittest.mock import patch + +import pytest +from django.contrib.auth import get_user_model +from ninja import Schema +from ninja_extra import api_controller +from ninja_extra.testing import TestClient +from pydantic import Field + +from ninja_jwt import controller +from ninja_jwt.schema import ( + TokenBlacklistInputSchema, + TokenObtainInputSchemaBase, + TokenObtainSlidingInputSchema, + TokenRefreshInputSchema, + TokenRefreshSlidingInputSchema, + TokenRefreshSlidingOutputSchema, + TokenVerifyInputSchema, +) +from ninja_jwt.schema_control import SchemaControl +from ninja_jwt.settings import api_settings +from ninja_jwt.tokens import AccessToken, RefreshToken, SlidingToken +from ninja_jwt.utils import aware_utcnow, datetime_from_epoch, datetime_to_epoch + +User = get_user_model() + + +class MyNewObtainPairTokenSchemaOutput(Schema): + refresh: str + access: str + first_name: str + last_name: str + + +class MyNewObtainTokenSlidingSchemaOutput(Schema): + token: str + first_name: str + last_name: str + + +class MyNewObtainPairSchemaInput(TokenObtainInputSchemaBase): + @classmethod + def get_response_schema(cls): + return MyNewObtainPairTokenSchemaOutput + + @classmethod + def get_token(cls, user) -> Dict: + values = {} + refresh = RefreshToken.for_user(user) + values["refresh"] = str(refresh) + values["access"] = str(refresh.access_token) + values.update( + first_name=user.first_name, last_name=user.last_name + ) # this will be needed when creating output schema + return values + + +class MyNewObtainTokenSlidingSchemaInput(TokenObtainSlidingInputSchema): + my_extra_field: str + + @classmethod + def get_response_schema(cls): + return MyNewObtainTokenSlidingSchemaOutput + + def to_response_schema(self): + return MyNewObtainTokenSlidingSchemaOutput( + first_name=self._user.first_name, + last_name=self._user.last_name, + **self.dict(exclude={"password"}) + ) + + +class MyTokenRefreshSlidingOutputSchema(TokenRefreshSlidingOutputSchema): + ninja_jwt: str = Field(default="Ninja JWT") + + +class MyTokenRefreshInputSchema(TokenRefreshInputSchema): + pass + + +class MyTokenRefreshSlidingInputSchema(TokenRefreshSlidingInputSchema): + @classmethod + def get_response_schema(cls): + return MyTokenRefreshSlidingOutputSchema + + +class MyTokenVerifyInputSchema(TokenVerifyInputSchema): + pass + + +class MyTokenBlacklistInputSchema(TokenBlacklistInputSchema): + pass + + +class InvalidTokenSchema(Schema): + whatever: str + + +@pytest.mark.django_db +class TestTokenObtainPairViewCustomSchema: + @pytest.fixture(autouse=True) + def setUp(self): + self.username = "test_user" + self.password = "test_password" + + self.user = User.objects.create_user( + username=self.username, + password=self.password, + first_name="John", + last_name="Doe", + ) + + def test_success(self, monkeypatch): + with monkeypatch.context() as m: + m.setattr( + api_settings, + "TOKEN_OBTAIN_PAIR_INPUT_SCHEMA", + "tests.test_custom_schema.MyNewObtainPairSchemaInput", + ) + + importlib.reload(controller) + client = TestClient(controller.NinjaJWTDefaultController) + res = client.post( + "/pair", + json={ + User.USERNAME_FIELD: self.username, + "password": self.password, + }, + content_type="application/json", + ) + + assert res.status_code == 200 + data = res.json() + assert "access" in data + assert "refresh" in data + + assert data["first_name"] == "John" + assert data["last_name"] == "Doe" + + +@pytest.mark.django_db +class TestTokenRefreshViewCustomSchema: + @pytest.fixture(autouse=True) + def setUp(self): + self.username = "test_user" + self.password = "test_password" + + self.user = User.objects.create_user( + username=self.username, + password=self.password, + ) + + def test_refresh_works_fine(self, monkeypatch): + refresh = RefreshToken() + refresh["test_claim"] = "arst" + + # View returns 200 + now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2 + with monkeypatch.context() as m: + m.setattr( + api_settings, + "TOKEN_OBTAIN_PAIR_REFRESH_INPUT_SCHEMA", + "tests.test_custom_schema.MyTokenRefreshInputSchema", + ) + importlib.reload(controller) + client = TestClient(controller.NinjaJWTDefaultController) + with patch("ninja_jwt.tokens.aware_utcnow") as fake_aware_utcnow: + fake_aware_utcnow.return_value = now + + res = client.post( + "/refresh", + json={"refresh": str(refresh)}, + content_type="application/json", + ) + + assert res.status_code == 200 + data = res.json() + access = AccessToken(data["access"]) + + assert refresh["test_claim"] == access["test_claim"] + assert access["exp"] == datetime_to_epoch( + now + api_settings.ACCESS_TOKEN_LIFETIME + ) + + +@pytest.mark.django_db +class TestTokenObtainSlidingViewCustomSchema: + @pytest.fixture(autouse=True) + def setUp(self): + self.username = "test_user" + self.password = "test_password" + + self.user = User.objects.create_user( + username=self.username, + password=self.password, + first_name="John", + last_name="Doe", + ) + + def test_incomplete_data(self, monkeypatch): + with monkeypatch.context() as m: + m.setattr( + api_settings, + "TOKEN_OBTAIN_SLIDING_INPUT_SCHEMA", + "tests.test_custom_schema.MyNewObtainTokenSlidingSchemaInput", + ) + importlib.reload(controller) + client = TestClient(controller.NinjaJWTSlidingController) + res = client.post( + "/sliding", + json={ + User.USERNAME_FIELD: self.username, + "password": "test_password", + }, + content_type="application/json", + ) + assert res.status_code == 422 + data = res.json() + assert data == { + "detail": [ + { + "loc": ["body", "user_token", "my_extra_field"], + "msg": "field required", + "type": "value_error.missing", + } + ] + } + + def test_success(self, monkeypatch): + with monkeypatch.context() as m: + m.setattr( + api_settings, + "TOKEN_OBTAIN_SLIDING_INPUT_SCHEMA", + "tests.test_custom_schema.MyNewObtainTokenSlidingSchemaInput", + ) + importlib.reload(controller) + client = TestClient(controller.NinjaJWTSlidingController) + res = client.post( + "/sliding", + json={ + User.USERNAME_FIELD: self.username, + "password": self.password, + "my_extra_field": "some_data", + }, + content_type="application/json", + ) + + assert res.status_code == 200 + data = res.json() + + assert "token" in data + assert data["first_name"] == "John" + assert data["last_name"] == "Doe" + + +@pytest.mark.django_db +class TestTokenRefreshSlidingViewCustomSchema: + @pytest.fixture(autouse=True) + def setUp(self): + self.username = "test_user" + self.password = "test_password" + + self.user = User.objects.create_user( + username=self.username, + password=self.password, + ) + + def test_it_should_update_token_exp_claim_if_everything_ok(self, monkeypatch): + now = aware_utcnow() + + token = SlidingToken() + exp = now + api_settings.SLIDING_TOKEN_LIFETIME - timedelta(seconds=1) + token.set_exp( + from_time=now, + lifetime=api_settings.SLIDING_TOKEN_LIFETIME - timedelta(seconds=1), + ) + with monkeypatch.context() as m: + m.setattr( + api_settings, + "TOKEN_OBTAIN_SLIDING_REFRESH_INPUT_SCHEMA", + "tests.test_custom_schema.MyTokenRefreshSlidingInputSchema", + ) + importlib.reload(controller) + client = TestClient(controller.NinjaJWTSlidingController) + # View returns 200 + res = client.post( + "/sliding/refresh", + json={"token": str(token)}, + content_type="application/json", + ) + assert res.status_code == 200 + data = res.json() + assert data["ninja_jwt"] == "Ninja JWT" + # Expiration claim has moved into future + new_token = SlidingToken(data["token"]) + new_exp = datetime_from_epoch(new_token["exp"]) + + assert exp < new_exp + + +@pytest.mark.django_db +class TestTokenVerifyViewCustomSchema: + @pytest.fixture(autouse=True) + def setUp(self): + self.username = "test_user" + self.password = "test_password" + + self.user = User.objects.create_user( + username=self.username, + password=self.password, + ) + + def test_it_should_return_200_if_everything_okay(self, monkeypatch): + token = RefreshToken() + with monkeypatch.context() as m: + m.setattr( + api_settings, + "TOKEN_VERIFY_INPUT_SCHEMA", + "tests.test_custom_schema.MyTokenVerifyInputSchema", + ) + importlib.reload(controller) + client = TestClient(controller.NinjaJWTDefaultController) + res = client.post( + "/verify", json={"token": str(token)}, content_type="application/json" + ) + assert res.status_code == 200 + assert res.json() == {} + + +@pytest.mark.django_db +class TestTokenBlacklistViewCustomSchema: + @pytest.fixture(autouse=True) + def setUp(self): + self.username = "test_user" + self.password = "test_password" + + self.user = User.objects.create_user( + username=self.username, + password=self.password, + ) + + def test_it_should_return_if_everything_ok(self, monkeypatch): + refresh = RefreshToken() + refresh["test_claim"] = "arst" + + # View returns 200 + now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2 + with monkeypatch.context() as m: + m.setattr( + api_settings, + "TOKEN_BLACKLIST_INPUT_SCHEMA", + "tests.test_custom_schema.MyTokenBlacklistInputSchema", + ) + importlib.reload(controller) + client = TestClient(api_controller()(controller.TokenBlackListController)) + with patch("ninja_jwt.tokens.aware_utcnow") as fake_aware_utcnow: + fake_aware_utcnow.return_value = now + + res = client.post( + "/blacklist", + json={"refresh": str(refresh)}, + content_type="application/json", + ) + + assert res.status_code == 200 + + assert res.json() == {} + + +importlib.reload(controller) + + +class TestSchemaControlExceptions: + def test_verify_schema_exception(self, monkeypatch): + with monkeypatch.context() as m: + m.setattr( + api_settings, + "TOKEN_VERIFY_INPUT_SCHEMA", + "tests.test_custom_schema.InvalidTokenSchema", + ) + with pytest.raises(Exception) as ex: + SchemaControl(api_settings) + assert ( + str(ex.value) + == "TOKEN_VERIFY_INPUT_SCHEMA type must inherit from ``" + ) + + def test_blacklist_schema_exception(self, monkeypatch): + with monkeypatch.context() as m: + m.setattr( + api_settings, + "TOKEN_BLACKLIST_INPUT_SCHEMA", + "tests.test_custom_schema.InvalidTokenSchema", + ) + with pytest.raises(Exception) as ex: + SchemaControl(api_settings) + assert ( + str(ex.value) + == "TOKEN_BLACKLIST_INPUT_SCHEMA type must inherit from ``" + ) + + def test_obtain_pair_schema_exception(self, monkeypatch): + with monkeypatch.context() as m: + m.setattr( + api_settings, + "TOKEN_OBTAIN_PAIR_INPUT_SCHEMA", + "tests.test_custom_schema.InvalidTokenSchema", + ) + with pytest.raises(Exception) as ex: + SchemaControl(api_settings) + assert ( + str(ex.value) + == "TOKEN_OBTAIN_PAIR_INPUT_SCHEMA type must inherit from ``" + ) + + def test_obtain_pair_refresh_schema_exception(self, monkeypatch): + with monkeypatch.context() as m: + m.setattr( + api_settings, + "TOKEN_OBTAIN_PAIR_REFRESH_INPUT_SCHEMA", + "tests.test_custom_schema.InvalidTokenSchema", + ) + with pytest.raises(Exception) as ex: + SchemaControl(api_settings) + assert ( + str(ex.value) + == "TOKEN_OBTAIN_PAIR_REFRESH_INPUT_SCHEMA type must inherit from ``" + ) + + def test_sliding_schema_exception(self, monkeypatch): + with monkeypatch.context() as m: + m.setattr( + api_settings, + "TOKEN_OBTAIN_SLIDING_INPUT_SCHEMA", + "tests.test_custom_schema.InvalidTokenSchema", + ) + with pytest.raises(Exception) as ex: + SchemaControl(api_settings) + assert ( + str(ex.value) + == "TOKEN_OBTAIN_SLIDING_INPUT_SCHEMA type must inherit from ``" + ) + + def test_sliding_refresh_schema_exception(self, monkeypatch): + with monkeypatch.context() as m: + m.setattr( + api_settings, + "TOKEN_OBTAIN_SLIDING_REFRESH_INPUT_SCHEMA", + "tests.test_custom_schema.InvalidTokenSchema", + ) + with pytest.raises(Exception) as ex: + SchemaControl(api_settings) + assert ( + str(ex.value) + == "TOKEN_OBTAIN_SLIDING_REFRESH_INPUT_SCHEMA type must inherit from ``" + ) diff --git a/tests/test_token_blacklist.py b/tests/test_token_blacklist.py index 07fb8aafc..01c9d508c 100644 --- a/tests/test_token_blacklist.py +++ b/tests/test_token_blacklist.py @@ -6,7 +6,7 @@ from django.db.models import BigAutoField from ninja_jwt.exceptions import TokenError -from ninja_jwt.schema import TokenVerifySerializer +from ninja_jwt.schema import TokenVerifyInputSchema from ninja_jwt.settings import api_settings from ninja_jwt.token_blacklist.models import BlacklistedToken, OutstandingToken from ninja_jwt.tokens import AccessToken, RefreshToken, SlidingToken @@ -221,7 +221,7 @@ def test_token_verify_serializer_should_honour_blacklist_if_blacklisting_enabled refresh_token.blacklist() with pytest.raises(Exception): - TokenVerifySerializer(token=str(refresh_token)) + TokenVerifyInputSchema(token=str(refresh_token)) def test_token_verify_serializer_should_not_honour_blacklist_if_blacklisting_not_enabled( self, monkeypatch @@ -231,7 +231,7 @@ def test_token_verify_serializer_should_not_honour_blacklist_if_blacklisting_not refresh_token = RefreshToken.for_user(self.user) refresh_token.blacklist() - serializer = TokenVerifySerializer(token=str(refresh_token)) + serializer = TokenVerifyInputSchema(token=str(refresh_token)) assert serializer.token == str(refresh_token)