From 39a6da98d823131bd71dc0d6cb9a64cb0b4b53c5 Mon Sep 17 00:00:00 2001 From: Mathias Ertl Date: Mon, 30 Dec 2024 11:35:00 +0100 Subject: [PATCH] migrate ACME views to async --- ca/django_ca/acme/views.py | 258 +++++++++++++++++++++++-------------- ca/django_ca/managers.py | 27 +++- ca/django_ca/models.py | 12 ++ ca/django_ca/querysets.py | 57 +++++++- 4 files changed, 248 insertions(+), 106 deletions(-) diff --git a/ca/django_ca/acme/views.py b/ca/django_ca/acme/views.py index 8e5fd7ade..fd94b7229 100644 --- a/ca/django_ca/acme/views.py +++ b/ca/django_ca/acme/views.py @@ -31,6 +31,7 @@ import acme.jws import josepy as jose from acme import messages +from asgiref.sync import sync_to_async from cryptography import x509 from cryptography.hazmat.primitives import hashes, serialization @@ -79,6 +80,7 @@ CertificateAuthority, ) from django_ca.pydantic.validators import email_validator +from django_ca.querysets import AcmeAccountQuerySet from django_ca.tasks import acme_issue_certificate, acme_validate_challenge, run_task from django_ca.utils import check_name, int_to_hex @@ -144,7 +146,7 @@ class AcmeDirectory(View): def _url(self, request: HttpRequest, name: str, ca: CertificateAuthority) -> str: return request.build_absolute_uri(reverse(f"django_ca:{name}", kwargs={"serial": ca.serial})) - def get(self, request: HttpRequest, serial: Optional[str] = None) -> HttpResponse: + async def get(self, request: HttpRequest, serial: Optional[str] = None) -> HttpResponse: # pylint: disable=missing-function-docstring; standard Django view function if not model_settings.CA_ENABLE_ACME: raise Http404("Page not found.") @@ -152,13 +154,13 @@ def get(self, request: HttpRequest, serial: Optional[str] = None) -> HttpRespons if serial is None: try: # NOTE: default() already calls usable() - ca = CertificateAuthority.objects.acme().default() + ca = await CertificateAuthority.objects.acme().adefault() except ImproperlyConfigured: return AcmeResponseNotFound(message="No (usable) default CA configured.") else: try: # NOTE: Serial is already sanitized by URL converter - ca = CertificateAuthority.objects.acme().usable().get(serial=serial) + ca = await CertificateAuthority.objects.acme().usable().aget(serial=serial) except CertificateAuthority.DoesNotExist: return AcmeResponseNotFound(message=f"{serial}: CA not found.") @@ -204,17 +206,17 @@ def get_cache_key(self, nonce: str) -> str: """Get the cache key for the given request and nonce.""" return f"acme-nonce-{self.kwargs['serial']}-{nonce}" - def get_nonce(self) -> str: + async def get_nonce(self) -> str: """Get a random Nonce and add it to the cache.""" data = secrets.token_bytes(self.nonce_length) nonce = jose.json_util.encode_b64jose(data) - cache.set(self.get_cache_key(nonce), 0) + await cache.aset(self.get_cache_key(nonce), 0) return nonce - def validate_nonce(self, nonce: str) -> bool: + async def validate_nonce(self, nonce: str) -> bool: """Validate that the given nonce was issued and was not used before.""" try: - count = cache.incr(self.get_cache_key(nonce)) + count = await cache.aincr(self.get_cache_key(nonce)) except ValueError: # raised if cache_key is not set return False @@ -236,7 +238,7 @@ class AcmeBaseView(AcmeGetNonceViewMixin, View, metaclass=abc.ABCMeta): jws: acme.jws.JWS @abc.abstractmethod - def process_acme_request(self, slug: Optional[str]) -> AcmeResponse: + async def process_acme_request(self, slug: Optional[str]) -> AcmeResponse: """Abstract method expected to implement processing a message. The `slug` argument is the URL slug that identifies an ACME object and is None for requests that @@ -297,9 +299,12 @@ def set_link_relations(self, response: "HttpResponseBase", **kwargs: str) -> Non # with open(prepared_path, 'w') as stream: # json.dump(prepared_data, stream, indent=4) - # NOINSPECTION NOTE: It's okay to be more specific here - # noinspection PyMethodOverriding - def dispatch(self, request: HttpRequest, serial: str, slug: Optional[str] = None) -> "HttpResponseBase": + # TYPEHINT NOTE: Django does not officially support dispatch() being an async method, but in async views, + # it returns a coroutine, just like an async view would. Thus declaring it async should not make much + # difference. + async def dispatch( # type: ignore[override] # pylint: disable=invalid-overridden-method + self, request: HttpRequest, serial: str, slug: Optional[str] = None + ) -> "HttpResponseBase": if not model_settings.CA_ENABLE_ACME: raise Http404("Page not found.") @@ -310,7 +315,12 @@ def dispatch(self, request: HttpRequest, serial: str, slug: Optional[str] = None raise ImproperlyConfigured("View expects a str for a slug") try: - response = super().dispatch(request, serial=serial, slug=slug) + # TYPEHINT NOTE: super().dispatch() is not an async method but still returns a coroutine, because + # post() is async. We await it and cast it to AcmeResponse (since post already returns that). + response = cast( + AcmeResponse, + await super().dispatch(request, serial=serial, slug=slug), # type:ignore[misc] + ) except AcmeException as ex: response = ex.get_response() except Exception as ex: # pylint: disable=broad-except @@ -325,10 +335,10 @@ def dispatch(self, request: HttpRequest, serial: str, slug: Optional[str] = None # An ACME server provides nonces to clients using the HTTP Replay-Nonce header field, as specified in # Section 6.5.1. The server MUST include a Replay-Nonce header field in every successful response to # a POST request and SHOULD provide it in error responses as well. - response["replay-nonce"] = self.get_nonce() + response["replay-nonce"] = await self.get_nonce() return response - def post( # noqa: PLR0911 + async def post( # noqa: PLR0911 self, request: HttpRequest, serial: str, slug: Optional[str] = None ) -> AcmeResponse: # pylint: disable=missing-function-docstring; standard Django view function @@ -337,10 +347,10 @@ def post( # noqa: PLR0911 # TODO: RFC 8555, 6.2 has a nice list of things to check here that we don't yet fully cover if request.content_type != "application/jose+json": # RFC 8555, 6.2: - # "Because client requests in ACME carry JWS objects in the Flattened JSON Serialization, they + # Because client requests in ACME carry JWS objects in the Flattened JSON Serialization, they # must have the Content-Type header field set to "application/jose+json". If a request does not # meet this requirement, then the server MUST return a response with status code 415 (Unsupported - # Media Type)." + # Media Type). return AcmeResponseUnsupportedMediaType() # self.prepared['body'] = json.loads(request.body.decode('utf-8')) @@ -358,7 +368,7 @@ def post( # noqa: PLR0911 # Get certificate authority for this request try: - self.ca = CertificateAuthority.objects.acme().usable().get(serial=serial) + self.ca = await CertificateAuthority.objects.acme().usable().aget(serial=serial) except CertificateAuthority.DoesNotExist: return AcmeResponseNotFound(message="The requested CA cannot be found.") @@ -373,7 +383,8 @@ def post( # noqa: PLR0911 # combined.kid is a full URL pointing to the account. try: - account = AcmeAccount.objects.viewable().get(ca=self.ca, kid=combined.kid) + account_qs = AcmeAccount.objects.viewable().url() + account = await account_qs.aget(ca=self.ca, kid=combined.kid) except AcmeAccount.DoesNotExist: return AcmeResponseUnauthorized(message="Account not found.") @@ -411,7 +422,9 @@ def post( # noqa: PLR0911 return AcmeResponseMalformed(message="JWS signature invalid.") # self.prepared['nonce'] = jose.encode_b64jose(combined.nonce) - if combined.nonce is None or not self.validate_nonce(jose.json_util.encode_b64jose(combined.nonce)): + if combined.nonce is None or not await self.validate_nonce( + jose.json_util.encode_b64jose(combined.nonce) + ): # ... "nonce" return AcmeResponseBadNonce() @@ -421,7 +434,7 @@ def post( # noqa: PLR0911 # match, then the server MUST reject the request as unauthorized." return AcmeResponseUnauthorized(message="URL does not match.") - return self.process_acme_request(slug=slug) + return await self.process_acme_request(slug=slug) class AcmePostAsGetView(AcmeBaseView, metaclass=abc.ABCMeta): @@ -430,7 +443,7 @@ class AcmePostAsGetView(AcmeBaseView, metaclass=abc.ABCMeta): ignore_body = False # True if we want to ignore the message body @abc.abstractmethod - def acme_request(self, slug: str) -> AcmeResponse: + async def acme_request(self, slug: str) -> AcmeResponse: """Abstract method to process an ACME post-as-get request. Actual view subclasses are expected to implement this function. @@ -439,13 +452,13 @@ def acme_request(self, slug: str) -> AcmeResponse: contain no information. """ - def process_acme_request(self, slug: Optional[str]) -> AcmeResponse: + async def process_acme_request(self, slug: Optional[str]) -> AcmeResponse: if self.ignore_body is False and self.jws.payload != b"": return AcmeResponseMalformed(message="Non-empty payload in get-as-post request.") if slug is None: # pragma: no cover; just a safety measure return AcmeResponseError(message="PostAsGet view called with slug.") - return self.acme_request(slug=slug) + return await self.acme_request(slug=slug) class AcmeMessageBaseView(AcmeBaseView, Generic[MessageTypeVar], metaclass=abc.ABCMeta): @@ -454,20 +467,20 @@ class AcmeMessageBaseView(AcmeBaseView, Generic[MessageTypeVar], metaclass=abc.A message_cls: type[MessageTypeVar] @abc.abstractmethod - def acme_request(self, message: MessageTypeVar, slug: Optional[str]) -> AcmeResponse: + async def acme_request(self, message: MessageTypeVar, slug: Optional[str]) -> AcmeResponse: """Process ACME request. Actual view subclasses are expected to implement this function. """ - def process_acme_request(self, slug: Optional[str]) -> AcmeResponse: + async def process_acme_request(self, slug: Optional[str]) -> AcmeResponse: try: message = self.message_cls.json_loads(self.jws.payload) log.debug("ACME message: %s", message) except jose.errors.DeserializationError as e: return AcmeResponseMalformedPayload(message=", ".join(e.args)) - return self.acme_request(message, slug) + return await self.acme_request(message, slug) class AcmeNewNonceView(AcmeGetNonceViewMixin, View): @@ -491,29 +504,34 @@ def dispatch(self, request: HttpRequest, serial: str) -> "HttpResponseBase": if not isinstance(serial, str): # pragma: no cover raise ImproperlyConfigured("View expects a str for a serial") - response = super().dispatch(request, serial) - response["replay-nonce"] = self.get_nonce() + return super().dispatch(request, serial) - # RFC 8555, section 7.2: - # - # The server MUST include a Cache-Control header field with the "no-store" directive in responses - response["cache-control"] = "no-store" - return response + async def get_headers(self) -> dict[str, str]: + """Get headers used in responses to both HEAD and GET requests.""" + return { + "replay-nonce": await self.get_nonce(), + # RFC 8555, section 7.2: + # + # The server MUST include a Cache-Control header field with the "no-store" directive in + # responses + "cache-control": "no-store", + } - def head(self, request: HttpRequest, serial: str) -> HttpResponse: + async def head(self, request: HttpRequest, serial: str) -> HttpResponse: """Get a new Nonce with a HEAD request.""" # pylint: disable=method-hidden; false positive - View.setup() sets head as property if not defined # pylint: disable=unused-argument; false positive - really used by AcmeGetNonceViewMixin - return HttpResponse() + return HttpResponse(headers=await self.get_headers()) - def get(self, request: HttpRequest, serial: str) -> HttpResponse: + async def get(self, request: HttpRequest, serial: str) -> HttpResponse: """Get a new Nonce with a GET request. Note that certbot always does a HEAD request, but RFC 8555, section 7.2 mandates support for GET requests. """ # pylint: disable=unused-argument; false positive - really used by AcmeGetNonceViewMixin - return HttpResponse(status=HTTPStatus.NO_CONTENT) # 204, unlike HEAD, which has 200 + headers = await self.get_headers() + return HttpResponse(status=HTTPStatus.NO_CONTENT, headers=headers) # 204, unlike HEAD, which has 200 class AcmeNewAccountView(ContactValidationMixin, AcmeMessageBaseView[messages.Registration]): @@ -527,7 +545,7 @@ class AcmeNewAccountView(ContactValidationMixin, AcmeMessageBaseView[messages.Re message_cls = messages.Registration requires_key = True - def acme_request(self, message: messages.Registration, slug: Optional[str]) -> AcmeResponseAccount: + async def acme_request(self, message: messages.Registration, slug: Optional[str]) -> AcmeResponseAccount: """Process ACME request.""" pem = ( self.jwk.key.public_bytes( @@ -538,6 +556,9 @@ def acme_request(self, message: messages.Registration, slug: Optional[str]) -> A ) thumbprint = jose.json_util.encode_b64jose(self.jwk.thumbprint()) + # Queryset used for fetching accounts. The CA is used for loading its serial, which is used in URLs. + account_qs: AcmeAccountQuerySet = AcmeAccount.objects.url() + # RFC 8555, section 7.3: # # If this field is present with the value "true", then the server MUST NOT create a new account if @@ -545,7 +566,7 @@ def acme_request(self, message: messages.Registration, slug: Optional[str]) -> A # key (see Section 7.3.1). if message.only_return_existing: try: - account = AcmeAccount.objects.get(thumbprint=thumbprint, pem=pem) + account = await account_qs.aget(thumbprint=thumbprint, pem=pem) return AcmeResponseAccount(self.request, account) except AcmeAccount.DoesNotExist as ex: # RFC 8555, section 7.3: @@ -561,7 +582,7 @@ def acme_request(self, message: messages.Registration, slug: Optional[str]) -> A # and provide the URL of that account in the Location header field. try: # NOTE: Filter for thumbprint too b/c index for the field should speed up lookups. - account = AcmeAccount.objects.get(ca=self.ca, thumbprint=thumbprint, pem=pem) + account = await account_qs.aget(ca=self.ca, thumbprint=thumbprint, pem=pem) return AcmeResponseAccount(self.request, account) except AcmeAccount.DoesNotExist: pass @@ -596,8 +617,9 @@ def acme_request(self, message: messages.Registration, slug: Optional[str]) -> A # Call full_clean() so that model validation can do its magic try: - account.full_clean() - account.save() + # NOTE: full_clean() does not have an async version yet + await sync_to_async(account.full_clean)() + await account.asave() except ValidationError as ex: # Add a pretty list of validation errors to the detail field in the response subproblems = ", ".join( @@ -642,27 +664,33 @@ def is_account_usable(self, account: AcmeAccount) -> bool: return account.status == AcmeAccount.STATUS_VALID @transaction.atomic - def acme_request(self, message: messages.Registration, slug: Optional[str] = None) -> AcmeResponseAccount: - account = AcmeAccount.objects.get(slug=slug) + def _deactivate_account(self, account: AcmeAccount) -> None: + # RFC 8555, section 7.3.6 - Account Deactivation + log.info("Deactivating account %s", account.slug) + account.status = AcmeAccount.STATUS_DEACTIVATED + account.save() + + # Cancel all pending operations + account.orders.filter(status=AcmeOrder.STATUS_PENDING).update(status=AcmeOrder.STATUS_INVALID) + AcmeAuthorization.objects.filter( + order__account=account, status=AcmeAuthorization.STATUS_PENDING + ).update(status=AcmeAuthorization.STATUS_DEACTIVATED) + + async def acme_request( + self, message: messages.Registration, slug: Optional[str] = None + ) -> AcmeResponseAccount: + # TODO: does this allow updating other peoples accounts!? + account = await AcmeAccount.objects.url().aget(slug=slug) if message.status == AcmeAccount.STATUS_DEACTIVATED: - # RFC 8555, section 7.3.6 - Account Deactivation - log.info("Deactivating account %s", account.slug) - account.status = AcmeAccount.STATUS_DEACTIVATED - account.save() - - # Cancel all pending operations - account.orders.filter(status=AcmeOrder.STATUS_PENDING).update(status=AcmeOrder.STATUS_INVALID) - AcmeAuthorization.objects.filter( - order__account=account, status=AcmeAuthorization.STATUS_PENDING - ).update(status=AcmeAuthorization.STATUS_DEACTIVATED) + await sync_to_async(self._deactivate_account)(account) elif message.contact: self.validate_contacts(message) account.contact = "\n".join(message.contact) - account.save() + await account.asave() elif message.terms_of_service_agreed is not None: account.terms_of_service_agreed = message.terms_of_service_agreed - account.save() + await account.asave() else: raise AcmeMalformed(message="Only contact information can be updated.") @@ -673,7 +701,7 @@ class AcmeAccountOrdersView(AcmeBaseView): """View showing orders for an account (not yet implemented).""" # TODO: implement this view - def process_acme_request(self, slug: Optional[str]) -> AcmeResponse: # pragma: no cover + async def process_acme_request(self, slug: Optional[str]) -> AcmeResponse: # pragma: no cover raise AcmeException(message="Not Implemented.") @@ -694,7 +722,19 @@ class AcmeNewOrderView(AcmeMessageBaseView[NewOrder]): message_cls = NewOrder @transaction.atomic - def acme_request(self, message: NewOrder, slug: Optional[str] = None) -> AcmeResponseOrderCreated: + def _create_order( + self, + not_before: Optional[datetime], + not_after: Optional[datetime], + identifiers: list[messages.Identifier], + ) -> tuple[AcmeOrder, list[str]]: + order = AcmeOrder.objects.create(account=self.account, not_before=not_before, not_after=not_after) + authorizations = [ + self.request.build_absolute_uri(authz.acme_url) for authz in order.add_authorizations(identifiers) + ] + return order, authorizations + + async def acme_request(self, message: NewOrder, slug: Optional[str] = None) -> AcmeResponseOrderCreated: """Process ACME request.""" now = datetime.now(tz.utc) @@ -721,10 +761,7 @@ def acme_request(self, message: NewOrder, slug: Optional[str] = None) -> AcmeRes if not_after is not None and timezone.is_aware(not_after): not_after = timezone.make_naive(not_after) - order = AcmeOrder.objects.create(account=self.account, not_before=not_before, not_after=not_after) - authorizations = [ - self.request.build_absolute_uri(authz.acme_url) for authz in order.add_authorizations(identifiers) - ] + order, authorizations = await sync_to_async(self._create_order)(not_before, not_after, identifiers) expires = order.expires if expires.tzinfo is None: # acme.messages.Order requires a timezone-aware object @@ -755,9 +792,10 @@ class AcmeOrderView(AcmePostAsGetView): .. seealso:: `RFC 8555, 7.4 `_ """ - def acme_request(self, slug: str) -> AcmeResponseOrder: + async def acme_request(self, slug: str) -> AcmeResponseOrder: try: - order = AcmeOrder.objects.viewable().account(self.account).get(slug=slug) + order_qs = AcmeOrder.objects.viewable().account(self.account).url() + order = await order_qs.aget(slug=slug) except AcmeOrder.DoesNotExist as ex: # RFC 8555, section 10.5: Avoid leaking info that this slug does not exist by # return a normal unauthorized message. @@ -768,7 +806,7 @@ def acme_request(self, slug: str) -> AcmeResponseOrder: if expires.tzinfo is None: # acme.messages.Order requires a timezone-aware object expires = expires.replace(tzinfo=tz.utc) - authorizations = order.authorizations.all() + authorizations = order.authorizations.url() # type: ignore[attr-defined] if order.status in [AcmeOrder.STATUS_VALID, AcmeOrder.STATUS_INVALID]: # RFC 8555, section 7.1.3: # @@ -777,7 +815,7 @@ def acme_request(self, slug: str) -> AcmeResponseOrder: cert_url = None try: - cert = AcmeCertificate.objects.get(order=order) + cert = await AcmeCertificate.objects.select_related("cert").url().aget(order=order) if cert.cert and order.status == AcmeOrder.STATUS_VALID: # WARNING: certbot (at least version 0.31.0) will try to fetch the certificate immediately if # we return the URL. That view will fail if the certificate is not yet issued, and certbot @@ -788,6 +826,9 @@ def acme_request(self, slug: str) -> AcmeResponseOrder: except AcmeCertificate.DoesNotExist: pass + # Asynchronously fetch authorizations + authorizations = [a async for a in authorizations] + response = AcmeResponseOrder( status=order.status, expires=expires, @@ -864,10 +905,24 @@ def validate_csr(self, message: CertificateRequest, authorizations: Iterable[Acm return csr.public_bytes(Encoding.PEM).decode("utf-8") - def acme_request(self, message: CertificateRequest, slug: Optional[str]) -> AcmeResponseOrder: + @transaction.atomic + def create_certificate(self, order: AcmeOrder, csr: str) -> None: + """Create certificate and update order in a transaction.""" + # Create AcmeCertificate object (at this point without cert, as it hasn't been issued yet) + cert = AcmeCertificate.objects.create(order=order, csr=csr) + + # Update the status of the order to "processing" + order.status = AcmeOrder.STATUS_PROCESSING + order.save() + + # start task only after commit, see: + # https://docs.djangoproject.com/en/dev/topics/db/transactions/#django.db.transaction.on_commit + transaction.on_commit(lambda: run_task(acme_issue_certificate, acme_certificate_pk=cert.pk)) + + async def acme_request(self, message: CertificateRequest, slug: Optional[str]) -> AcmeResponseOrder: """Process ACME request.""" try: - order = AcmeOrder.objects.viewable().account(account=self.account).get(slug=slug) + order = await AcmeOrder.objects.viewable().account(account=self.account).url().aget(slug=slug) except AcmeOrder.DoesNotExist as ex: # RFC 8555, section 10.5: Avoid leaking info that this slug does not exist by # return a normal unauthorized message. @@ -895,8 +950,8 @@ def acme_request(self, message: CertificateRequest, slug: Optional[str]) -> Acme if expires.tzinfo is None: # acme.messages.Order requires a timezone-aware object expires = expires.replace(tzinfo=tz.utc) - authorizations = order.authorizations.all() - for auth in authorizations: + authorizations = order.authorizations.url().all() # type: ignore[attr-defined] + async for auth in authorizations: if auth.status != AcmeAuthorization.STATUS_VALID: # This is a state that should never happen in practice, because the order is only marked as # ready once all authorizations are valid. @@ -907,16 +962,7 @@ def acme_request(self, message: CertificateRequest, slug: Optional[str]) -> Acme # Parse and validate the CSR csr = self.validate_csr(message, authorizations) - # Create AcmeCertificate object (at this point without cert, as it hasn't been issued yet) - cert = AcmeCertificate.objects.create(order=order, csr=csr) - - # Update the status of the order to "processing" - order.status = AcmeOrder.STATUS_PROCESSING - order.save() - - # start task only after commit, see: - # https://docs.djangoproject.com/en/dev/topics/db/transactions/#django.db.transaction.on_commit - transaction.on_commit(lambda: run_task(acme_issue_certificate, acme_certificate_pk=cert.pk)) + await sync_to_async(self.create_certificate)(order, csr) response = AcmeResponseOrder( status=order.status, @@ -938,9 +984,10 @@ class AcmeCertificateView(AcmePostAsGetView): # This is the only view that does not return JSON, thus acme_request() returns the superclass # HttpResponse, and not an AcmeResponse (which is always JSON). - def acme_request(self, slug: str) -> HttpResponse: # type: ignore[override] + async def acme_request(self, slug: str) -> HttpResponse: # type: ignore[override] + cert_qs = AcmeCertificate.objects.viewable().account(self.account).select_related("cert__ca") try: - cert = AcmeCertificate.objects.viewable().account(self.account).get(slug=slug) + cert = await cert_qs.aget(slug=slug) except AcmeCertificate.DoesNotExist as ex: raise AcmeUnauthorized() from ex @@ -974,11 +1021,12 @@ class AcmeAuthorizationView(AcmePostAsGetView): .. seealso:: `RFC 8555, 7.5 `_ """ - def acme_request(self, slug: str) -> AcmeResponseAuthorization: + async def acme_request(self, slug: str) -> AcmeResponseAuthorization: # TODO: implement deactivating an authorization (section 7.5.2) + auth_qs = AcmeAuthorization.objects.viewable().account(account=self.account).url() try: - auth = AcmeAuthorization.objects.viewable().account(account=self.account).url().get(slug=slug) + auth = await auth_qs.aget(slug=slug) except AcmeAuthorization.DoesNotExist as ex: # RFC 8555, section 10.5: Avoid leaking info that this slug does not exist by # return a normal unauthorized message. @@ -986,7 +1034,7 @@ def acme_request(self, slug: str) -> AcmeResponseAuthorization: # self.prepared['order'] = auth.order.slug # self.prepared['auth'] = auth.slug - challenges = auth.get_challenges() + challenges = await auth.aget_challenges() expires = auth.expires if expires.tzinfo is None: # acme.Order requires a timezone-aware object @@ -1041,9 +1089,20 @@ def set_link_relations(self, response: "HttpResponseBase", **kwargs: str) -> Non kwargs["up"] = self.auth.acme_url super().set_link_relations(response, **kwargs) - def acme_request(self, slug: str) -> AcmeResponseChallenge: + @transaction.atomic + def save_challenge(self, challenge: AcmeChallenge) -> None: + """Save challenge and launch Celery task.""" + challenge.save() + + # Actually perform challenge validation asynchronously + # start task only after commit, see: + # https://docs.djangoproject.com/en/2.2/topics/db/transactions/#django.db.transaction.on_commit + transaction.on_commit(lambda: run_task(acme_validate_challenge, challenge.pk)) + + async def acme_request(self, slug: str) -> AcmeResponseChallenge: + challenge_qs = AcmeChallenge.objects.viewable().account(self.account).url() try: - challenge = AcmeChallenge.objects.viewable().account(self.account).url().get(slug=slug) + challenge = await challenge_qs.aget(slug=slug) except AcmeChallenge.DoesNotExist as ex: # RFC 8555, section 10.5: Avoid leaking info that this slug does not exist by # return a normal unauthorized message. @@ -1061,12 +1120,8 @@ def acme_request(self, slug: str) -> AcmeResponseChallenge: # # They transition to the "processing" state when the client responds to the challenge challenge.status = AcmeChallenge.STATUS_PROCESSING - challenge.save() - # Actually perform challenge validation asynchronously - # start task only after commit, see: - # https://docs.djangoproject.com/en/2.2/topics/db/transactions/#django.db.transaction.on_commit - transaction.on_commit(lambda: run_task(acme_validate_challenge, challenge.pk)) + await sync_to_async(self.save_challenge)(challenge) return AcmeResponseChallenge( chall=challenge.acme_challenge, @@ -1087,7 +1142,7 @@ class AcmeCertificateRevocationView(AcmeMessageBaseView[messages.Revocation]): accepts_kid_or_jwk = True message_cls = messages.Revocation - def get_certificate(self, serial: str) -> Certificate: + async def get_certificate(self, serial: str) -> Certificate: """Get the certificate that is to be revoked by this request. This function handles the special authorization requirements for this request (they can be signed by @@ -1106,7 +1161,7 @@ def get_certificate(self, serial: str) -> Certificate: # This implies that the account used to request the certificate may even be revoked or invalid, as # long as the private key is used to sign the request. So we don't look at the ACME account at all # here. - cert: Certificate = certs.get(serial=serial) + cert: Certificate = await certs.aget(serial=serial) jwk = cert.jwk @@ -1123,7 +1178,9 @@ def get_certificate(self, serial: str) -> Certificate: else: # Get the certificate by serial if it *has* an ACME account. # NOTE: The base class already makes sure that the account is currently valid. - cert = certs.filter(acmecertificate__order__account__isnull=False).get(serial=serial) + certs = certs.filter(acmecertificate__order__account__isnull=False) + certs = certs.select_related("acmecertificate__order__account") + cert = await certs.aget(serial=serial) # If the request is from the account that issued the certificate, the certificate can be revoked. # NOTE: self.account is **only** set if the request has no JWK. @@ -1132,7 +1189,10 @@ def get_certificate(self, serial: str) -> Certificate: # If the account holds authorizations for all the identifiers in the certificate, it can also # be revoked, so get a list of all currently valid authorizations that the account holds - authz = set(AcmeAuthorization.objects.dns().valid().account(account=self.account).names()) + authz_qs = AcmeAuthorization.objects.dns().valid().account(account=self.account).names() + + # PYLINT NOTE: set() does not allow an async generator, pylint does not detect this. + authz = set([auth async for auth in authz_qs]) # pylint: disable=consider-using-set-comprehension # Get names from the certificate, first from the CommonName... # NOTE: returns empty list if subject does not have a CommonName. @@ -1155,7 +1215,7 @@ def get_certificate(self, serial: str) -> Certificate: return cert - def acme_request(self, message: messages.Revocation, slug: Optional[str]) -> AcmeResponse: + async def acme_request(self, message: messages.Revocation, slug: Optional[str]) -> AcmeResponse: reason_code = message.reason if reason_code is None: reason_code = 0 @@ -1175,7 +1235,7 @@ def acme_request(self, message: messages.Revocation, slug: Optional[str]) -> Acm raise AcmeMalformed(message="Request did not contain a certificate.") try: - cert = self.get_certificate(int_to_hex(cg_cert.serial_number)) + cert = await self.get_certificate(int_to_hex(cg_cert.serial_number)) except Certificate.DoesNotExist as ex: raise AcmeUnauthorized(message="Certificate not found.") from ex @@ -1191,5 +1251,5 @@ def acme_request(self, message: messages.Revocation, slug: Optional[str]) -> Acm raise AcmeMalformed(typ="alreadyRevoked", message="Certificate was already revoked.") # Finally actually revoke the certificate - cert.revoke(reason) + await cert.arevoke(reason) return AcmeResponse({}) # No response specified in RFC 8555! diff --git a/ca/django_ca/managers.py b/ca/django_ca/managers.py index 193590bb8..6c9f791a8 100644 --- a/ca/django_ca/managers.py +++ b/ca/django_ca/managers.py @@ -16,7 +16,7 @@ import typing from collections.abc import Iterable from datetime import datetime, timedelta, timezone as tz -from typing import Any, Generic, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union from asgiref.sync import sync_to_async from pydantic import BaseModel @@ -41,6 +41,7 @@ from django_ca.openssh import SshHostCaExtension, SshUserCaExtension from django_ca.profiles import Profile, profiles from django_ca.pydantic.validators import crl_scope_validator +from django_ca.querysets import AcmeCertificateQuerySet from django_ca.signals import post_create_ca, post_issue_cert, pre_create_ca from django_ca.typehints import ( AllowedHashTypes, @@ -67,6 +68,7 @@ from django_ca.querysets import ( AcmeAccountQuerySet, AcmeAuthorizationQuerySet, + AcmeOrderQuerySet, CertificateAuthorityQuerySet, CertificateQuerySet, CertificateRevocationListQuerySet, @@ -930,12 +932,25 @@ class AcmeAccountManager(AcmeAccountManagerBase): # # pylint: disable=missing-function-docstring; just defining stubs here + def url(self) -> "AcmeAccountQuerySet": ... + def viewable(self) -> "AcmeAccountQuerySet": ... class AcmeOrderManager(AcmeOrderManagerBase): """Model manager for :py:class:`~django_ca.models.AcmeOrder`.""" + if TYPE_CHECKING: + # See CertificateManagerMixin for description on this branch + # + # pylint: disable=missing-function-docstring,unused-argument; just defining stubs here + + def account(self, account: "AcmeAccount") -> "AcmeOrderQuerySet": ... + + def url(self) -> "AcmeOrderQuerySet": ... + + def viewable(self) -> "AcmeOrderQuerySet": ... + class AcmeAuthorizationManager(AcmeAuthorizationManagerBase): """Model manager for :py:class:`~django_ca.models.AcmeAuthorization`.""" @@ -962,3 +977,13 @@ class AcmeChallengeManager(AcmeChallengeManagerBase): class AcmeCertificateManager(AcmeCertificateManagerBase): """Model manager for :py:class:`~django_ca.models.AcmeCertificate`.""" + + if typing.TYPE_CHECKING: + # See CertificateManagerMixin for description on this branch + # + # pylint: disable=missing-function-docstring; just defining stubs here + def account(self) -> "AcmeCertificateQuerySet": ... + + def url(self) -> "AcmeCertificateQuerySet": ... + + def viewalbe(self) -> "AcmeCertificateQuerySet": ... diff --git a/ca/django_ca/models.py b/ca/django_ca/models.py index 7b093dc16..9f7e8f7fa 100644 --- a/ca/django_ca/models.py +++ b/ca/django_ca/models.py @@ -1644,6 +1644,18 @@ def get_challenges(self) -> list["AcmeChallenge"]: AcmeChallenge.objects.get_or_create(auth=self, type=AcmeChallenge.TYPE_DNS_01)[0], ] + async def aget_challenges(self) -> list["AcmeChallenge"]: + """Get list of :py:class:`~django_ca.models.AcmeChallenge` objects for this authorization. + + Note that challenges will be created if they don't exist. + """ + qs = AcmeChallenge.objects.url() + return [ + (await qs.aget_or_create(auth=self, type=AcmeChallenge.TYPE_HTTP_01))[0], + # AcmeChallenge.objects.get_or_create(auth=self, type=AcmeChallenge.TYPE_TLS_ALPN_01)[0], + (await qs.aget_or_create(auth=self, type=AcmeChallenge.TYPE_DNS_01))[0], + ] + @property def usable(self) -> bool: """Boolean defining if an authentication can still can be used in order validation. diff --git a/ca/django_ca/querysets.py b/ca/django_ca/querysets.py index 53059bf7e..698fba85c 100644 --- a/ca/django_ca/querysets.py +++ b/ca/django_ca/querysets.py @@ -23,7 +23,7 @@ from django.core.exceptions import ImproperlyConfigured from django.db import models -from django.db.models import Q +from django.db.models import Q, QuerySet from django.utils import timezone from django_ca.acme.constants import Status @@ -190,8 +190,6 @@ def default(self) -> "CertificateAuthority": or not currently valid. Or, if the setting is not set, no CA is currently usable. """ if (serial := model_settings.CA_DEFAULT_CA) is not None: - now = timezone.now() - try: # NOTE: Don't prefilter queryset so that we can provide more specialized error messages below. ca = self.get(serial=serial) @@ -200,6 +198,8 @@ def default(self) -> "CertificateAuthority": if ca.enabled is False: raise ImproperlyConfigured(f"CA_DEFAULT_CA: {serial} is disabled.") + + now = timezone.now() if ca.not_after < now: raise ImproperlyConfigured(f"CA_DEFAULT_CA: {serial} is expired.") if ca.not_before > now: # OK, how could this ever happen? ;-) @@ -213,6 +213,43 @@ def default(self) -> "CertificateAuthority": raise ImproperlyConfigured("No CA is currently usable.") return first_ca + async def adefault(self) -> "CertificateAuthority": + """Return the default CA to use when no CA is selected. + + This function honors the :ref:`CA_DEFAULT_CA `. If no usable CA can be + returned, raises :py:exc:`~django:django.core.exceptions.ImproperlyConfigured`. + + Raises + ------ + :py:exc:`~django:django.core.exceptions.ImproperlyConfigured` + When the CA named by :ref:`CA_DEFAULT_CA ` is either not found, disabled + or not currently valid. Or, if the setting is not set, no CA is currently usable. + """ + if (serial := model_settings.CA_DEFAULT_CA) is not None: + try: + # NOTE: Don't prefilter queryset so that we can provide more specialized error messages below. + ca = await self.aget(serial=serial) + except self.model.DoesNotExist as ex: + raise ImproperlyConfigured(f"CA_DEFAULT_CA: {serial}: CA not found.") from ex + + if ca.enabled is False: + raise ImproperlyConfigured(f"CA_DEFAULT_CA: {serial} is disabled.") + + now = timezone.now() + if ca.not_after < now: + raise ImproperlyConfigured(f"CA_DEFAULT_CA: {serial} is expired.") + if ca.not_before > now: # OK, how could this ever happen? ;-) + raise ImproperlyConfigured(f"CA_DEFAULT_CA: {serial} is not yet valid.") + return ca + + # NOTE: We add the serial to sorting make *sure* we have deterministic behavior. In many cases, users + # will just create several CAs that all actually expire on the same day. + first_ca_qs = self.usable().order_by("-not_after", "serial") # usable == enabled and valid + first_ca = await first_ca_qs.afirst() + if first_ca is None: + raise ImproperlyConfigured("No CA is currently usable.") + return first_ca + def disabled(self) -> "CertificateAuthorityQuerySet": """Return CAs that are disabled.""" return self.filter(enabled=False) @@ -311,6 +348,10 @@ def scope( class AcmeAccountQuerySet(AcmeAccountQuerySetBase): """QuerySet for :py:class:`~django_ca.models.AcmeAccount`.""" + def url(self) -> "AcmeAccountQuerySet": + """Assure that returned models can build an ACME URL without additional database queries.""" + return self.select_related("ca") + def viewable(self) -> "AcmeAccountQuerySet": """Filter ACME accounts that can be viewed via the ACME API. @@ -330,6 +371,10 @@ def account(self, account: "AcmeAccount") -> "AcmeOrderQuerySet": """Filter orders belonging to the given account.""" return self.filter(account=account) + def url(self) -> "AcmeOrderQuerySet": + """Assure that returned models can build an ACME URL without additional database queries.""" + return self.select_related("account__ca") + def viewable(self) -> "AcmeOrderQuerySet": """Filter ACME orders that can be viewed via the ACME API. @@ -355,9 +400,9 @@ def dns(self) -> "AcmeAuthorizationQuerySet": """Get all authorizations of type DNS.""" return self.filter(type=self.model.TYPE_DNS) - def names(self) -> list[str]: + def names(self) -> QuerySet["AcmeAuthorization", str]: """Get a flat list of names identified by the current queryset.""" - return list(self.values_list("value", flat=True)) + return self.values_list("value", flat=True) def url(self) -> "AcmeAuthorizationQuerySet": """Prepare queryset to get the ACME URL of objects without subsequent database lookups.""" @@ -414,7 +459,7 @@ def account(self, account: "AcmeAccount") -> "AcmeCertificateQuerySet": return self.filter(order__account=account) def url(self) -> "AcmeCertificateQuerySet": - """Prepare queryset to get the ACME URL of objects without subsequent database lookups.""" + """Assure that returned models can build an ACME URL without additional database queries.""" return self.select_related("order__account__ca") def viewable(self) -> "AcmeCertificateQuerySet":