Skip to content

Commit

Permalink
migrate API endpoints to async
Browse files Browse the repository at this point in the history
  • Loading branch information
mathiasertl committed Jan 1, 2025
1 parent dbd9188 commit f888bb6
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 41 deletions.
25 changes: 21 additions & 4 deletions ca/django_ca/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from typing import Literal, Union

from asgiref.sync import sync_to_async
from ninja.security import HttpBasicAuth

from django.contrib.auth import get_user_model
Expand All @@ -41,12 +42,28 @@ def __init__(self, permission: str) -> None:
self.permission = permission
super().__init__()

def authenticate(
# TODO: async implement call against warnings?

# PYLINT NOTE: documented in django-ninja docs that this can be async
async def authenticate( # pylint: disable=invalid-overridden-method
self, request: HttpRequest, username: str, password: str
) -> Union[Literal[False], AbstractUser]:
user = User.objects.get(username=username)
if user.check_password(password) is False:
user = await User.objects.aget(username=username)

if hasattr(user, "acheck_password"): # pragma: only django>=5.1
# Django 5.0 introduced acheckpassword().
if await user.acheck_password(password) is False:
return False
elif user.check_password(password) is False: # pragma: only django<5.1
return False
if user.has_perm(self.permission) is False:

# NOTE: ahas_perm() is introduced in Django 5.2.
if hasattr(user, "ahasperm"): # pragma: only django>5.1
# TYPEHINT NOTE: mypy will complain on currently released django==5.1.
has_perm = await user.ahas_perm(self.permission) # type: ignore[attr-defined]
else: # pragma: only django<5.2
has_perm = await sync_to_async(user.has_perm)(self.permission)

if has_perm is False:
raise Forbidden(self.permission)
return user
60 changes: 33 additions & 27 deletions ca/django_ca/api/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from http import HTTPStatus

from asgiref.sync import sync_to_async
from ninja import NinjaAPI, Query
from ninja.errors import HttpError

Expand All @@ -40,7 +41,6 @@
from django_ca.api.utils import get_certificate_authority
from django_ca.models import Certificate, CertificateAuthority, CertificateOrder
from django_ca.pydantic.messages import SignCertificateMessage
from django_ca.querysets import CertificateAuthorityQuerySet, CertificateQuerySet
from django_ca.tasks import api_sign_certificate as sign_certificate_task, run_task

api = NinjaAPI(title="django-ca API", version=__version__, urls_namespace="django_ca:api")
Expand All @@ -59,15 +59,15 @@ def forbidden(request: WSGIRequest, exc: Exception) -> HttpResponse: # pylint:
summary="List available certificate authorities",
tags=["Certificate authorities"],
)
def list_certificate_authorities(
async def list_certificate_authorities(
request: WSGIRequest,
filters: CertificateAuthorityFilterSchema = Query(...), # type: ignore[type-arg] # noqa: B008
) -> CertificateAuthorityQuerySet:
) -> list[CertificateAuthority]:
"""Retrieve a list of currently usable certificate authorities."""
qs = CertificateAuthority.objects.enabled().exclude(api_enabled=False)
if filters.expired is False:
qs = qs.valid()
return qs
return [ca async for ca in qs]


@api.get(
Expand All @@ -77,9 +77,9 @@ def list_certificate_authorities(
summary="View certificate authority",
tags=["Certificate authorities"],
)
def view_certificate_authority(request: WSGIRequest, serial: str) -> CertificateAuthority:
async def view_certificate_authority(request: WSGIRequest, serial: str) -> CertificateAuthority:
"""Retrieve details of the certificate authority with the given serial."""
return get_certificate_authority(serial, expired=True) # You can *view* expired CAs
return await get_certificate_authority(serial, expired=True) # You can *view* expired CAs


@api.put(
Expand All @@ -89,14 +89,14 @@ def view_certificate_authority(request: WSGIRequest, serial: str) -> Certificate
summary="Update certificate authority",
tags=["Certificate authorities"],
)
def update_certificate_authority(
async def update_certificate_authority(
request: WSGIRequest, serial: str, data: CertificateAuthorityUpdateSchema
) -> CertificateAuthority:
"""Update a certificate authority.
All request body fields are optional, so you can also update only individual fields.
"""
ca = get_certificate_authority(serial, expired=True)
ca = await get_certificate_authority(serial, expired=True)

# sign_certificate_policies is a django_ca.pydantic.extensions.ExtensionModel, so we can generate the
# cryptography instance directly
Expand All @@ -112,11 +112,11 @@ def update_certificate_authority(
setattr(ca, attr, value)

try:
ca.full_clean()
await sync_to_async(ca.full_clean)()
except ValidationError as ex:
raise HttpError(HTTPStatus.BAD_REQUEST, str(ex)) from ex

ca.save()
await ca.asave()
return ca


Expand All @@ -127,7 +127,9 @@ def update_certificate_authority(
summary="Sign a certificate",
tags=["Certificates"],
)
def sign_certificate(request: WSGIRequest, serial: str, data: SignCertificateMessage) -> CertificateOrder:
async def sign_certificate(
request: WSGIRequest, serial: str, data: SignCertificateMessage
) -> CertificateOrder:
"""Sign a certificate.
The `extensions` value is optional and allows you to add additional extensions to the certificate. Usually
Expand All @@ -139,10 +141,10 @@ def sign_certificate(request: WSGIRequest, serial: str, data: SignCertificateMes
except ValueError as ex:
raise HttpError(HTTPStatus.BAD_REQUEST, "Unable to parse CSR.") from ex

ca = get_certificate_authority(serial)
ca = await get_certificate_authority(serial)

# TYPEHINT NOTE: django-ninja sets the user as `request.auth` and mypy does not know about it
order = CertificateOrder.objects.create(
order = await CertificateOrder.objects.acreate(
certificate_authority=ca,
user=request.auth, # type: ignore[attr-defined]
)
Expand All @@ -151,7 +153,9 @@ def sign_certificate(request: WSGIRequest, serial: str, data: SignCertificateMes

# start task only after commit, see:
# https://docs.djangoproject.com/en/dev/topics/db/transactions/#django.db.transaction.on_commit
transaction.on_commit(lambda: run_task(sign_certificate_task, order_pk=order.pk, **parameters))
await sync_to_async(transaction.on_commit)(
lambda: run_task(sign_certificate_task, order_pk=order.pk, **parameters)
)

return order

Expand All @@ -163,9 +167,10 @@ def sign_certificate(request: WSGIRequest, serial: str, data: SignCertificateMes
summary="Retrieve certificate order",
tags=["Certificates"],
)
def get_certificate_order(request: WSGIRequest, serial: str, slug: str) -> CertificateOrder:
async def get_certificate_order(request: WSGIRequest, serial: str, slug: str) -> CertificateOrder:
"""Retrieve information about the certificate order identified by `slug`."""
return CertificateOrder.objects.get(
order_queryset = CertificateOrder.objects.select_related("user", "certificate")
return await order_queryset.aget(
certificate_authority__serial=serial, certificate_authority__api_enabled=True, slug=slug
)

Expand All @@ -177,13 +182,13 @@ def get_certificate_order(request: WSGIRequest, serial: str, slug: str) -> Certi
summary="List certificates",
tags=["Certificates"],
)
def list_certificates(
async def list_certificates(
request: WSGIRequest,
serial: str,
filters: CertificateFilterSchema = Query(...), # type: ignore[type-arg] # noqa: B008
) -> CertificateQuerySet:
) -> list[Certificate]:
"""Retrieve certificates signed by the certificate authority named by `serial`."""
ca = get_certificate_authority(serial, expired=True) # You can list certificates of expired CAs
ca = await get_certificate_authority(serial, expired=True) # You can list certificates of expired CAs
qs = Certificate.objects.filter(ca=ca)

if filters.expired is False:
Expand All @@ -195,7 +200,7 @@ def list_certificates(
if filters.profile is not None:
qs = qs.filter(profile=filters.profile)

return qs
return [cert async for cert in qs]


@api.get(
Expand All @@ -205,10 +210,10 @@ def list_certificates(
summary="View certificate",
tags=["Certificates"],
)
def view_certificate(request: WSGIRequest, serial: str, certificate_serial: str) -> Certificate:
async def view_certificate(request: WSGIRequest, serial: str, certificate_serial: str) -> Certificate:
"""Retrieve details of the certificate with the given certificate serial."""
ca = get_certificate_authority(serial, expired=True) # You can view certificates of expired CAs
return Certificate.objects.get(ca=ca, serial=certificate_serial)
ca = await get_certificate_authority(serial, expired=True) # You can view certificates of expired CAs
return await Certificate.objects.aget(ca=ca, serial=certificate_serial)


@api.post(
Expand All @@ -218,21 +223,22 @@ def view_certificate(request: WSGIRequest, serial: str, certificate_serial: str)
summary="Revoke certificate",
tags=["Certificates"],
)
def revoke_certificate(
async def revoke_certificate(
request: WSGIRequest, serial: str, certificate_serial: str, revocation: RevokeCertificateSchema
) -> Certificate:
"""Revoke a certificate with the given serial.
Both `reason` and `compromised` fields are optional.
"""
ca = get_certificate_authority(serial)
ca = await get_certificate_authority(serial)
try:
cert = Certificate.objects.currently_valid().get(ca=ca, serial=certificate_serial)
cert_qs = Certificate.objects.currently_valid()
cert: Certificate = await cert_qs.aget(ca=ca, serial=certificate_serial)
except Certificate.DoesNotExist as ex:
raise Http404(f"{certificate_serial}: Certificate not found.") from ex

if cert.revoked is True:
raise HttpError(HTTPStatus.BAD_REQUEST, "The certificate is already revoked.")

cert.revoke(revocation.reason, revocation.compromised)
await cert.arevoke(revocation.reason, revocation.compromised)
return cert
4 changes: 2 additions & 2 deletions ca/django_ca/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@
User = get_user_model()


def get_certificate_authority(serial: str, expired: bool = False) -> CertificateAuthority:
async def get_certificate_authority(serial: str, expired: bool = False) -> CertificateAuthority:
"""Get a certificate authority from the given serial."""
qs = CertificateAuthority.objects.enabled().exclude(api_enabled=False)
if expired is False:
qs = qs.valid()

try:
return qs.get(serial=serial)
return await qs.aget(serial=serial)
except CertificateAuthority.DoesNotExist as ex:
raise Http404(f"{serial}: Certificate authority not found.") from ex

Expand Down
31 changes: 23 additions & 8 deletions ca/django_ca/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import re
import typing
from collections.abc import Iterable, Iterator
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone as tz
from typing import Literal, Optional, Union

Expand Down Expand Up @@ -450,6 +451,20 @@ def get_revocation(self) -> x509.RevokedCertificate:

return revoked_cert.build()

@contextmanager
def _revoke(
self, reason: ReasonFlags = ReasonFlags.unspecified, compromised: Optional[datetime] = None
) -> Iterator[None]:
pre_revoke_cert.send(sender=self.__class__, cert=self, reason=reason)

self.revoked = True
self.revoked_date = timezone.now()
self.revoked_reason = reason.name
self.compromised = compromised
yield

post_revoke_cert.send(sender=self.__class__, cert=self)

def revoke(
self, reason: ReasonFlags = ReasonFlags.unspecified, compromised: Optional[datetime] = None
) -> None:
Expand All @@ -464,15 +479,15 @@ def revoke(
compromised : datetime, optional
When this certificate was compromised.
"""
pre_revoke_cert.send(sender=self.__class__, cert=self, reason=reason)

self.revoked = True
self.revoked_date = timezone.now()
self.revoked_reason = reason.name
self.compromised = compromised
self.save()
with self._revoke(reason, compromised):
self.save()

post_revoke_cert.send(sender=self.__class__, cert=self)
async def arevoke(
self, reason: ReasonFlags = ReasonFlags.unspecified, compromised: Optional[datetime] = None
) -> None:
"""Revoke the current certificate (async version)."""
with self._revoke(reason, compromised):
await self.asave()


class CertificateAuthority(X509CertMixin): # type: ignore[django-manager-missing]
Expand Down

0 comments on commit f888bb6

Please sign in to comment.