Skip to content

Commit

Permalink
make views synchronous again, performance is actually better this way
Browse files Browse the repository at this point in the history
  • Loading branch information
mathiasertl committed Jan 13, 2025
1 parent 2d7df2a commit cec9339
Show file tree
Hide file tree
Showing 25 changed files with 176 additions and 449 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,4 @@ WORKDIR /usr/src/django-ca/ca/
ENV DJANGO_CA_SETTINGS=conf/
ENV DJANGO_CA_SECRET_KEY_FILE=/var/lib/django-ca/certs/ca/shared/secret_key

CMD ./gunicorn.sh
CMD ./uwsgi.sh
146 changes: 63 additions & 83 deletions ca/django_ca/acme/views.py

Large diffs are not rendered by default.

23 changes: 4 additions & 19 deletions ca/django_ca/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

from typing import Literal, Union

from asgiref.sync import sync_to_async
from ninja.security import HttpBasicAuth

from django.contrib.auth import get_user_model
Expand All @@ -42,28 +41,14 @@ def __init__(self, permission: str) -> None:
self.permission = permission
super().__init__()

# TODO: async implement call against warnings?

# PYLINT NOTE: documented in django-ninja docs that this can be async
async def authenticate( # pylint: disable=invalid-overridden-method
def authenticate(
self, request: HttpRequest, username: str, password: str
) -> Union[Literal[False], AbstractUser]:
user = await User.objects.aget(username=username)
user = User.objects.get(username=username)

if hasattr(user, "acheck_password"): # pragma: only django>=5.1
# Django 5.0 introduced acheckpassword().
if await user.acheck_password(password) is False:
return False
elif user.check_password(password) is False: # pragma: only django<5.1
if user.check_password(password) is False:
return False

# NOTE: ahas_perm() is introduced in Django 5.2.
if hasattr(user, "ahasperm"): # pragma: only django>5.1
# TYPEHINT NOTE: mypy will complain on currently released django==5.1.
has_perm = await user.ahas_perm(self.permission) # type: ignore[attr-defined]
else: # pragma: only django<5.2
has_perm = await sync_to_async(user.has_perm)(self.permission)

if has_perm is False:
if user.has_perm(self.permission) is False:
raise Forbidden(self.permission)
return user
58 changes: 27 additions & 31 deletions ca/django_ca/api/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

from http import HTTPStatus

from asgiref.sync import sync_to_async
from ninja import NinjaAPI, Query
from ninja.errors import HttpError

Expand All @@ -41,6 +40,7 @@
from django_ca.api.utils import get_certificate_authority
from django_ca.models import Certificate, CertificateAuthority, CertificateOrder
from django_ca.pydantic.messages import SignCertificateMessage
from django_ca.querysets import CertificateAuthorityQuerySet, CertificateQuerySet
from django_ca.tasks import api_sign_certificate as sign_certificate_task, run_task

api = NinjaAPI(title="django-ca API", version=__version__, urls_namespace="django_ca:api")
Expand All @@ -59,15 +59,15 @@ def forbidden(request: WSGIRequest, exc: Exception) -> HttpResponse: # pylint:
summary="List available certificate authorities",
tags=["Certificate authorities"],
)
async def list_certificate_authorities(
def list_certificate_authorities(
request: WSGIRequest,
filters: CertificateAuthorityFilterSchema = Query(...), # type: ignore[type-arg] # noqa: B008
) -> list[CertificateAuthority]:
) -> CertificateAuthorityQuerySet:
"""Retrieve a list of currently usable certificate authorities."""
qs = CertificateAuthority.objects.enabled().exclude(api_enabled=False)
if filters.expired is False:
qs = qs.valid()
return [ca async for ca in qs]
return qs


@api.get(
Expand All @@ -77,9 +77,9 @@ async def list_certificate_authorities(
summary="View certificate authority",
tags=["Certificate authorities"],
)
async def view_certificate_authority(request: WSGIRequest, serial: str) -> CertificateAuthority:
def view_certificate_authority(request: WSGIRequest, serial: str) -> CertificateAuthority:
"""Retrieve details of the certificate authority with the given serial."""
return await get_certificate_authority(serial, expired=True) # You can *view* expired CAs
return get_certificate_authority(serial, expired=True) # You can *view* expired CAs


@api.put(
Expand All @@ -89,14 +89,14 @@ async def view_certificate_authority(request: WSGIRequest, serial: str) -> Certi
summary="Update certificate authority",
tags=["Certificate authorities"],
)
async def update_certificate_authority(
def update_certificate_authority(
request: WSGIRequest, serial: str, data: CertificateAuthorityUpdateSchema
) -> CertificateAuthority:
"""Update a certificate authority.
All request body fields are optional, so you can also update only individual fields.
"""
ca = await get_certificate_authority(serial, expired=True)
ca = get_certificate_authority(serial, expired=True)

# sign_certificate_policies is a django_ca.pydantic.extensions.ExtensionModel, so we can generate the
# cryptography instance directly
Expand All @@ -112,11 +112,11 @@ async def update_certificate_authority(
setattr(ca, attr, value)

try:
await sync_to_async(ca.full_clean)()
ca.full_clean()
except ValidationError as ex:
raise HttpError(HTTPStatus.BAD_REQUEST, str(ex)) from ex

await ca.asave()
ca.save()
return ca


Expand All @@ -127,9 +127,7 @@ async def update_certificate_authority(
summary="Sign a certificate",
tags=["Certificates"],
)
async def sign_certificate(
request: WSGIRequest, serial: str, data: SignCertificateMessage
) -> CertificateOrder:
def sign_certificate(request: WSGIRequest, serial: str, data: SignCertificateMessage) -> CertificateOrder:
"""Sign a certificate.
The `extensions` value is optional and allows you to add additional extensions to the certificate. Usually
Expand All @@ -141,10 +139,10 @@ async def sign_certificate(
except ValueError as ex:
raise HttpError(HTTPStatus.BAD_REQUEST, "Unable to parse CSR.") from ex

ca = await get_certificate_authority(serial)
ca = get_certificate_authority(serial)

# TYPEHINT NOTE: django-ninja sets the user as `request.auth` and mypy does not know about it
order = await CertificateOrder.objects.acreate(
order = CertificateOrder.objects.create(
certificate_authority=ca,
user=request.auth, # type: ignore[attr-defined]
)
Expand All @@ -153,9 +151,7 @@ async def sign_certificate(

# start task only after commit, see:
# https://docs.djangoproject.com/en/dev/topics/db/transactions/#django.db.transaction.on_commit
await sync_to_async(transaction.on_commit)(
lambda: run_task(sign_certificate_task, order_pk=order.pk, **parameters)
)
transaction.on_commit(lambda: run_task(sign_certificate_task, order_pk=order.pk, **parameters))

return order

Expand All @@ -167,10 +163,10 @@ async def sign_certificate(
summary="Retrieve certificate order",
tags=["Certificates"],
)
async def get_certificate_order(request: WSGIRequest, serial: str, slug: str) -> CertificateOrder:
def get_certificate_order(request: WSGIRequest, serial: str, slug: str) -> CertificateOrder:
"""Retrieve information about the certificate order identified by `slug`."""
order_queryset = CertificateOrder.objects.select_related("user", "certificate")
return await order_queryset.aget(
return order_queryset.get(
certificate_authority__serial=serial, certificate_authority__api_enabled=True, slug=slug
)

Expand All @@ -182,13 +178,13 @@ async def get_certificate_order(request: WSGIRequest, serial: str, slug: str) ->
summary="List certificates",
tags=["Certificates"],
)
async def list_certificates(
def list_certificates(
request: WSGIRequest,
serial: str,
filters: CertificateFilterSchema = Query(...), # type: ignore[type-arg] # noqa: B008
) -> list[Certificate]:
) -> CertificateQuerySet:
"""Retrieve certificates signed by the certificate authority named by `serial`."""
ca = await get_certificate_authority(serial, expired=True) # You can list certificates of expired CAs
ca = get_certificate_authority(serial, expired=True) # You can list certificates of expired CAs
qs = Certificate.objects.filter(ca=ca)

if filters.expired is False:
Expand All @@ -200,7 +196,7 @@ async def list_certificates(
if filters.profile is not None:
qs = qs.filter(profile=filters.profile)

return [cert async for cert in qs]
return qs


@api.get(
Expand All @@ -210,10 +206,10 @@ async def list_certificates(
summary="View certificate",
tags=["Certificates"],
)
async def view_certificate(request: WSGIRequest, serial: str, certificate_serial: str) -> Certificate:
def view_certificate(request: WSGIRequest, serial: str, certificate_serial: str) -> Certificate:
"""Retrieve details of the certificate with the given certificate serial."""
ca = await get_certificate_authority(serial, expired=True) # You can view certificates of expired CAs
return await Certificate.objects.aget(ca=ca, serial=certificate_serial)
ca = get_certificate_authority(serial, expired=True) # You can view certificates of expired CAs
return Certificate.objects.get(ca=ca, serial=certificate_serial)


@api.post(
Expand All @@ -223,22 +219,22 @@ async def view_certificate(request: WSGIRequest, serial: str, certificate_serial
summary="Revoke certificate",
tags=["Certificates"],
)
async def revoke_certificate(
def revoke_certificate(
request: WSGIRequest, serial: str, certificate_serial: str, revocation: RevokeCertificateSchema
) -> Certificate:
"""Revoke a certificate with the given serial.
Both `reason` and `compromised` fields are optional.
"""
ca = await get_certificate_authority(serial)
ca = get_certificate_authority(serial)
try:
cert_qs = Certificate.objects.currently_valid()
cert: Certificate = await cert_qs.aget(ca=ca, serial=certificate_serial)
cert: Certificate = cert_qs.get(ca=ca, serial=certificate_serial)
except Certificate.DoesNotExist as ex:
raise Http404(f"{certificate_serial}: Certificate not found.") from ex

if cert.revoked is True:
raise HttpError(HTTPStatus.BAD_REQUEST, "The certificate is already revoked.")

await cert.arevoke(revocation.reason, revocation.compromised)
cert.revoke(revocation.reason, revocation.compromised)
return cert
4 changes: 2 additions & 2 deletions ca/django_ca/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@
User = get_user_model()


async def get_certificate_authority(serial: str, expired: bool = False) -> CertificateAuthority:
def get_certificate_authority(serial: str, expired: bool = False) -> CertificateAuthority:
"""Get a certificate authority from the given serial."""
qs = CertificateAuthority.objects.enabled().exclude(api_enabled=False)
if expired is False:
qs = qs.valid()

try:
return await qs.aget(serial=serial)
return qs.get(serial=serial)
except CertificateAuthority.DoesNotExist as ex:
raise Http404(f"{serial}: Certificate authority not found.") from ex

Expand Down
4 changes: 2 additions & 2 deletions ca/django_ca/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,6 @@ def for_certificate_revocation_list(

def get_by_serial_or_cn(self, identifier: str) -> X509CertMixinTypeVar: ...

async def aget_by_serial_or_cn(self, identifier: str) -> X509CertMixinTypeVar: ...

def valid(self) -> QuerySetTypeVar: ...


Expand Down Expand Up @@ -726,6 +724,8 @@ class CertificateRevocationListManager(CertificateRevocationListManagerBase):
#
# pylint: disable=missing-function-docstring,unused-argument; just defining stubs here

def newest(self) -> Optional["CertificateRevocationList"]: ...

def reasons(
self, only_some_reasons: Optional[frozenset[x509.ReasonFlags]]
) -> "CertificateRevocationListQuerySet": ...
Expand Down
Loading

0 comments on commit cec9339

Please sign in to comment.