Skip to content

Commit

Permalink
Pydantic v2 upgrade (#45)
Browse files Browse the repository at this point in the history
* refactoring

* fixed failing tests

* removed extra attribute from input schema

* fixed failing test

* NinjaExtra Upgrade: Bumped up Ninja Extra version to 0.20.0

* removed ninja-schema from ci test
  • Loading branch information
eadwinCode authored Nov 19, 2023
1 parent 7ca48bd commit 15c06f0
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 89 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_full.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
- name: Install core
run: pip install "Django${{ matrix.django-version }}"
- name: Install tests
run: pip install pytest pytest-asyncio ninja-schema pytest-django django-ninja-extra python-jose==3.3.0 pyjwt[crypto]
run: pip install pytest pytest-asyncio pytest-django django-ninja-extra python-jose==3.3.0 pyjwt[crypto]
- name: Test
run: pytest
codestyle:
Expand Down
192 changes: 113 additions & 79 deletions ninja_jwt/schema.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import typing
import warnings
from typing import Any, Callable, Dict, Optional, Type, cast
from typing import Any, Dict, Optional, Type, cast

from django.conf import settings
from django.contrib.auth import authenticate, get_user_model
from django.contrib.auth.models import AbstractUser, update_last_login
from django.utils.translation import gettext_lazy as _
from ninja import ModelSchema, Schema
from pydantic import root_validator
from ninja.schema import DjangoGetter
from pydantic import model_validator

import ninja_jwt.exceptions as exceptions
from ninja_jwt.utils import token_error
Expand All @@ -27,15 +29,13 @@ class Config:


class InputSchemaMixin:
dict: Callable

@classmethod
def get_response_schema(cls) -> Type[Schema]:
raise NotImplementedError("Must implement `get_response_schema`")

def to_response_schema(self):
_schema_type = self.get_response_schema()
return _schema_type(**self.dict())
return _schema_type(**self.model_dump())


class TokenInputSchemaMixin(InputSchemaMixin):
Expand All @@ -52,7 +52,7 @@ def check_user_authentication_rule(self) -> None:
)

@classmethod
def validate_values(cls, values: Dict) -> dict:
def validate_values(cls, values: Dict) -> Dict:
if user_name_field not in values and "password" not in values:
raise exceptions.ValidationError(
{
Expand Down Expand Up @@ -96,14 +96,19 @@ def get_token(cls, user: AbstractUser) -> Dict:

class TokenObtainInputSchemaBase(ModelSchema, TokenInputSchemaMixin):
class Config:
# extra = "allow"
model = get_user_model()
model_fields = ["password", user_name_field]

@root_validator(pre=True)
def validate_inputs(cls, values: Dict) -> dict:
return cls.validate_values(values)
@model_validator(mode="before")
def validate_inputs(cls, values: DjangoGetter) -> DjangoGetter:
input_values = values._obj
if isinstance(input_values, dict):
values._obj.update(cls.validate_values(input_values))
return values
return values

@root_validator
@model_validator(mode="after")
def post_validate(cls, values: Dict) -> dict:
return cls.post_validate_schema(values)

Expand All @@ -121,16 +126,23 @@ def post_validate_schema(cls, values: Dict) -> dict:
if not isinstance(data, dict):
raise Exception("`get_token` must return a `typing.Dict` type.")

values.update(data)
# a workaround for extra attributes since adding extra=allow in modelconfig adds `addition_props`
# field to the schema
values.__dict__.update(token_data=data)

if api_settings.UPDATE_LAST_LOGIN:
update_last_login(None, cls._user)

return values

def get_response_schema_init_kwargs(self) -> dict:
return dict(
self.dict(exclude={"password"}), **self.__dict__.get("token_data", {})
)

def to_response_schema(self):
_schema_type = self.get_response_schema()
return _schema_type(**self.dict(exclude={"password"}))
return _schema_type(**self.get_response_schema_init_kwargs())


class TokenObtainPairOutputSchema(AuthUserSchema):
Expand Down Expand Up @@ -173,10 +185,13 @@ def get_token(cls, user: AbstractUser) -> Dict:
class TokenRefreshInputSchema(Schema, InputSchemaMixin):
refresh: str

@root_validator
def validate_schema(cls, values: Dict) -> dict:
if not values.get("refresh"):
raise exceptions.ValidationError({"refresh": "token is required"})
@model_validator(mode="before")
def validate_schema(cls, values: DjangoGetter) -> dict:
values = values._obj

if isinstance(values, dict):
if not values.get("refresh"):
raise exceptions.ValidationError({"refresh": "token is required"})
return values

@classmethod
Expand All @@ -188,42 +203,50 @@ class TokenRefreshOutputSchema(Schema):
refresh: str
access: Optional[str]

@root_validator
@model_validator(mode="before")
@token_error
def validate_schema(cls, values: Dict) -> dict:
if not values.get("refresh"):
raise exceptions.ValidationError({"refresh": "refresh token is required"})

refresh = RefreshToken(values["refresh"])

data = {"access": str(refresh.access_token)}

if api_settings.ROTATE_REFRESH_TOKENS:
if api_settings.BLACKLIST_AFTER_ROTATION:
try:
# Attempt to blacklist the given refresh token
refresh.blacklist()
except AttributeError:
# If blacklist app not installed, `blacklist` method will
# not be present
pass

refresh.set_jti()
refresh.set_exp()
refresh.set_iat()

data["refresh"] = str(refresh)
values.update(data)
def validate_schema(cls, values: DjangoGetter) -> typing.Any:
values = values._obj

if isinstance(values, dict):
if not values.get("refresh"):
raise exceptions.ValidationError(
{"refresh": "refresh token is required"}
)

refresh = RefreshToken(values["refresh"])

data = {"access": str(refresh.access_token)}

if api_settings.ROTATE_REFRESH_TOKENS:
if api_settings.BLACKLIST_AFTER_ROTATION:
try:
# Attempt to blacklist the given refresh token
refresh.blacklist()
except AttributeError:
# If blacklist app not installed, `blacklist` method will
# not be present
pass

refresh.set_jti()
refresh.set_exp()
refresh.set_iat()

data["refresh"] = str(refresh)
values.update(data)
return values


class TokenRefreshSlidingInputSchema(Schema, InputSchemaMixin):
token: str

@root_validator
def validate_schema(cls, values: Dict) -> dict:
if not values.get("token"):
raise exceptions.ValidationError({"token": "token is required"})
@model_validator(mode="before")
def validate_schema(cls, values: DjangoGetter) -> dict:
values = values._obj

if isinstance(values, dict):
if not values.get("token"):
raise exceptions.ValidationError({"token": "token is required"})
return values

@classmethod
Expand All @@ -234,42 +257,48 @@ def get_response_schema(cls) -> Type[Schema]:
class TokenRefreshSlidingOutputSchema(Schema):
token: str

@root_validator
@model_validator(mode="before")
@token_error
def validate_schema(cls, values: Dict) -> dict:
if not values.get("token"):
raise exceptions.ValidationError({"token": "token is required"})
def validate_schema(cls, values: DjangoGetter) -> dict:
values = values._obj

if isinstance(values, dict):
if not values.get("token"):
raise exceptions.ValidationError({"token": "token is required"})

token = SlidingToken(values["token"])
token = SlidingToken(values["token"])

# Check that the timestamp in the "refresh_exp" claim has not
# passed
token.check_exp(api_settings.SLIDING_TOKEN_REFRESH_EXP_CLAIM)
# Check that the timestamp in the "refresh_exp" claim has not
# passed
token.check_exp(api_settings.SLIDING_TOKEN_REFRESH_EXP_CLAIM)

# Update the "exp" and "iat" claims
token.set_exp()
token.set_iat()
values.update({"token": str(token)})
# Update the "exp" and "iat" claims
token.set_exp()
token.set_iat()
values.update({"token": str(token)})
return values


class TokenVerifyInputSchema(Schema, InputSchemaMixin):
token: str

@root_validator
@model_validator(mode="before")
@token_error
def validate_schema(cls, values: Dict) -> dict:
if not values.get("token"):
raise exceptions.ValidationError({"token": "token is required"})
token = UntypedToken(values["token"])

if (
api_settings.BLACKLIST_AFTER_ROTATION
and "ninja_jwt.token_blacklist" in settings.INSTALLED_APPS
):
jti = token.get(api_settings.JTI_CLAIM)
if BlacklistedToken.objects.filter(token__jti=jti).exists():
raise exceptions.ValidationError("Token is blacklisted")
def validate_schema(cls, values: DjangoGetter) -> Dict:
values = values._obj

if isinstance(values, dict):
if not values.get("token"):
raise exceptions.ValidationError({"token": "token is required"})
token = UntypedToken(values["token"])

if (
api_settings.BLACKLIST_AFTER_ROTATION
and "ninja_jwt.token_blacklist" in settings.INSTALLED_APPS
):
jti = token.get(api_settings.JTI_CLAIM)
if BlacklistedToken.objects.filter(token__jti=jti).exists():
raise exceptions.ValidationError("Token is blacklisted")

return values

Expand All @@ -284,16 +313,21 @@ def to_response_schema(self):
class TokenBlacklistInputSchema(Schema, InputSchemaMixin):
refresh: str

@root_validator
@model_validator(mode="before")
@token_error
def validate_schema(cls, values: Dict) -> dict:
if not values.get("refresh"):
raise exceptions.ValidationError({"refresh": "refresh token is required"})
refresh = RefreshToken(values["refresh"])
try:
refresh.blacklist()
except AttributeError:
pass
def validate_schema(cls, values: DjangoGetter) -> dict:
values = values._obj

if isinstance(values, dict):
if not values.get("refresh"):
raise exceptions.ValidationError(
{"refresh": "refresh token is required"}
)
refresh = RefreshToken(values["refresh"])
try:
refresh.blacklist()
except AttributeError:
pass
return values

@classmethod
Expand Down
5 changes: 2 additions & 3 deletions ninja_jwt/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@

from django.conf import settings
from django.test.signals import setting_changed
from ninja import Schema
from ninja_extra.lazy import LazyStrImport
from pydantic import AnyUrl, Field, root_validator
from pydantic.v1 import AnyUrl, BaseModel, Field, root_validator


class NinjaJWTUserDefinedSettingsMapper:
Expand All @@ -28,7 +27,7 @@ def __init__(self, data: dict) -> None:
)


class NinjaJWTSettings(Schema):
class NinjaJWTSettings(BaseModel):
class Config:
orm_mode = True
validate_assignment = True
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ dependencies = [
"Django >= 2.1",
"pyjwt>=1.7.1,<3",
"pyjwt[crypto]",
"django-ninja-extra >= 0.14.2",
"django-ninja-extra >= 0.20.0",
]


Expand All @@ -66,8 +66,8 @@ test = [
"pytest-cov",
"pytest-django",
"pytest-asyncio",
"ruff ==0.1.3",
"black == 23.10.1",
"black ==23.10.1",
"ruff ==0.1.4",
"django-stubs",
"python-jose==3.3.0",
"click==8.1.7"
Expand Down
6 changes: 3 additions & 3 deletions tests/test_custom_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def to_response_schema(self):
return MyNewObtainTokenSlidingSchemaOutput(
first_name=self._user.first_name,
last_name=self._user.last_name,
**self.dict(exclude={"password"}),
**self.get_response_schema_init_kwargs(),
)


Expand Down Expand Up @@ -226,8 +226,8 @@ def test_incomplete_data(self, monkeypatch):
"detail": [
{
"loc": ["body", "user_token", "my_extra_field"],
"msg": "field required",
"type": "value_error.missing",
"msg": "Field required",
"type": "missing",
}
]
}
Expand Down

0 comments on commit 15c06f0

Please sign in to comment.