From e110880b0f79bfc855dbe25e670b30bb0f7aa075 Mon Sep 17 00:00:00 2001 From: Mathias Ertl Date: Sun, 29 Dec 2024 11:31:10 +0100 Subject: [PATCH] fix typehints and missing docstrings --- ca/django_ca/managers.py | 4 +++- ca/django_ca/models.py | 5 +++-- ca/django_ca/querysets.py | 9 ++++++++- ca/django_ca/tests/test_utils.py | 6 +++--- ca/django_ca/views.py | 29 ++++++++++++++--------------- 5 files changed, 31 insertions(+), 22 deletions(-) diff --git a/ca/django_ca/managers.py b/ca/django_ca/managers.py index ed46cd33d..193590bb8 100644 --- a/ca/django_ca/managers.py +++ b/ca/django_ca/managers.py @@ -122,6 +122,8 @@ def for_certificate_revocation_list( def get_by_serial_or_cn(self, identifier: str) -> X509CertMixinTypeVar: ... + async def aget_by_serial_or_cn(self, identifier: str) -> X509CertMixinTypeVar: ... + def valid(self) -> QuerySetTypeVar: ... @@ -883,7 +885,7 @@ def create_certificate_revocation_list( ) # Create database object (as late as possible so any exception above would not hit the database) - obj: CertificateAuthority = self.create( + obj: CertificateRevocationList = self.create( ca=ca, number=Coalesce(models.Subquery(number_subquery, default=1), 0), only_contains_ca_certs=only_contains_ca_certs, diff --git a/ca/django_ca/models.py b/ca/django_ca/models.py index 246473c73..1b3d1e749 100644 --- a/ca/django_ca/models.py +++ b/ca/django_ca/models.py @@ -454,7 +454,7 @@ def get_revocation(self) -> x509.RevokedCertificate: @contextmanager def _revoke( self, reason: ReasonFlags = ReasonFlags.unspecified, compromised: Optional[datetime] = None - ) -> None: + ) -> Iterator[None]: pre_revoke_cert.send(sender=self.__class__, cert=self, reason=reason) self.revoked = True @@ -485,6 +485,7 @@ def revoke( async def arevoke( self, reason: ReasonFlags = ReasonFlags.unspecified, compromised: Optional[datetime] = None ) -> None: + """Asynchronous version of ``revoke()``.""" with self._revoke(reason, compromised): await self.asave() @@ -1257,7 +1258,7 @@ def _cache_data(self) -> Iterator[tuple[str, bytes, int]]: now = datetime.now(tz=tz.utc) if self.loaded.next_update_utc is not None: - expires_seconds = (self.loaded.next_update_utc - now).total_seconds() + expires_seconds = int((self.loaded.next_update_utc - now).total_seconds()) else: # pragma: no cover # we never generate CRLs without a next_update flag. expires_seconds = 86400 diff --git a/ca/django_ca/querysets.py b/ca/django_ca/querysets.py index 90d7516b2..9833e918b 100644 --- a/ca/django_ca/querysets.py +++ b/ca/django_ca/querysets.py @@ -85,17 +85,23 @@ def filter(self: X509CertMixinQuerySetProtocol) -> X509CertMixinQuerySetProtocol model: X509CertMixinTypeVar + async def aget(self, *args: Any, **kwargs: Any) -> X509CertMixinTypeVar: ... + def filter(self, *args: Any, **kwargs: Any) -> "Self": ... def get(self, *args: Any, **kwargs: Any) -> X509CertMixinTypeVar: ... + def _serial_or_cn_query(self, identifier: str) -> tuple[Q, Q]: ... + def revoked(self) -> "Self": ... class DjangoCAMixin(Generic[X509CertMixinTypeVar], metaclass=abc.ABCMeta): """Mixin with common methods for CertificateAuthority and Certificate models.""" - def _serial_or_cn_query(self: X509CertMixinQuerySetProtocol[X509CertMixinTypeVar], identifier: str): + def _serial_or_cn_query( + self: X509CertMixinQuerySetProtocol[X509CertMixinTypeVar], identifier: str + ) -> tuple[Q, Q]: identifier = identifier.strip() exact_query = startswith_query = Q(cn=identifier) @@ -130,6 +136,7 @@ def get_by_serial_or_cn( async def aget_by_serial_or_cn( self: X509CertMixinQuerySetProtocol[X509CertMixinTypeVar], identifier: str ) -> X509CertMixinTypeVar: + """Asynchronous version of :py:fucn:`~django_ca.querysets.DjangoCAMixin.get_by_serial_or_cn()`.""" exact_query, startswith_query = self._serial_or_cn_query(identifier) try: diff --git a/ca/django_ca/tests/test_utils.py b/ca/django_ca/tests/test_utils.py index c4c104409..4a34f09ef 100644 --- a/ca/django_ca/tests/test_utils.py +++ b/ca/django_ca/tests/test_utils.py @@ -558,11 +558,11 @@ def test_wrong_values(self) -> None: for key_type in ("Ed448", "Ed25519"): with pytest.raises(ValueError, match=rf"^Key size is not supported for {key_type} keys\.$"): - validate_private_key_parameters(key_type, key_size, None) + validate_private_key_parameters(key_type, key_size, None) # type: ignore[call-overload] with pytest.raises( ValueError, match=rf"^Elliptic curves are not supported for {key_type} keys\.$" ): - validate_private_key_parameters(key_type, None, elliptic_curve) + validate_private_key_parameters(key_type, None, elliptic_curve) # type: ignore[call-overload] class ValidatePublicKeyParametersTest(TestCase): @@ -572,7 +572,7 @@ def test_valid_parameters(self) -> None: """Test valid parameters.""" for key_type in ("RSA", "DSA", "EC"): for algorithm in (hashes.SHA256(), hashes.SHA512()): - validate_public_key_parameters(key_type, algorithm) + validate_public_key_parameters(key_type, algorithm) # type: ignore[arg-type] for key_type in ("Ed448", "Ed25519"): validate_public_key_parameters(key_type, None) # type: ignore[arg-type] diff --git a/ca/django_ca/views.py b/ca/django_ca/views.py index d1c0b2e34..c2602fa39 100644 --- a/ca/django_ca/views.py +++ b/ca/django_ca/views.py @@ -176,7 +176,7 @@ async def fetch_crl( # CRL is not cached, try to retrieve it from the database. if encoded_crl is None: - crl_qs: Optional[CertificateRevocationList] = ( + crl_qs = ( CertificateRevocationList.objects.scope( ca=ca, only_contains_ca_certs=only_contains_ca_certs, @@ -187,7 +187,7 @@ async def fetch_crl( .filter(data__isnull=False) # only objects that have CRL data associated with it .select_related("ca") ) - crl_obj = await crl_qs.anewest() + crl_obj: Optional[CertificateRevocationList] = await crl_qs.anewest() # CRL was not found in the database either, so we try to regenerate it. if crl_obj is None: @@ -214,7 +214,7 @@ async def fetch_crl( return encoded_crl - async def get(self, request: HttpRequest, serial: str) -> HttpResponse: # pylint: disable=unused-argument + async def get(self, request: HttpRequest, serial: str) -> HttpResponse: # pylint: disable=missing-function-docstring; standard Django view function if get_encoding := request.GET.get("encoding"): if get_encoding not in CERTIFICATE_REVOCATION_LIST_ENCODING_TYPES: @@ -275,6 +275,8 @@ class OCSPView(View): ca_ocsp = False """If set to ``True``, validate child CAs instead.""" + loaded_ca: CertificateAuthority + async def get(self, request: HttpRequest, data: str) -> HttpResponse: # pylint: disable=missing-function-docstring; standard Django view function try: @@ -405,7 +407,7 @@ async def process_ocsp_request(self, data: bytes) -> HttpResponse: # Get CA and certificate try: - ca = await self.get_ca() + ca = self.loaded_ca = await self.get_ca() except CertificateAuthority.DoesNotExist: log.error("%s: Certificate Authority could not be found.", self.ca) return self.fail() @@ -474,34 +476,31 @@ class GenericOCSPView(OCSPView): argument must be the serial for this CA. """ - auto_ca: CertificateAuthority - # NOINSPECTION NOTE: It's okay to be more specific here # noinspection PyMethodOverriding - async def dispatch(self, request: HttpRequest, serial: str, **kwargs: Any) -> "HttpResponseBase": + def dispatch(self, request: HttpRequest, serial: str, **kwargs: Any) -> "HttpResponseBase": if request.method == "GET" and "data" not in kwargs: - return await self.http_method_not_allowed(request, serial, **kwargs) + return self.http_method_not_allowed(request, serial, **kwargs) if request.method == "POST" and "data" in kwargs: - return await self.http_method_not_allowed(request, serial, **kwargs) + return self.http_method_not_allowed(request, serial, **kwargs) # COVERAGE NOTE: Checking just for safety here. if not isinstance(serial, str): # pragma: no cover raise ImproperlyConfigured("View expects a str for a serial") - self.auto_ca = await CertificateAuthority.objects.aget(serial=serial) - return await super().dispatch(request, **kwargs) + return super().dispatch(request, **kwargs) async def get_ca(self) -> CertificateAuthority: - return self.auto_ca + return await CertificateAuthority.objects.aget(serial=self.kwargs["serial"]) def get_expires(self, now: datetime) -> datetime: - return now + timedelta(seconds=self.auto_ca.ocsp_response_validity) + return now + timedelta(seconds=self.loaded_ca.ocsp_response_validity) async def get_ocsp_response(self, builder: OCSPResponseBuilder) -> Union[HttpResponse, OCSPResponse]: """Sign the OCSP request using cryptography keys.""" # Load public key try: - responder_pem = self.auto_ca.ocsp_key_backend_options["certificate"]["pem"] + responder_pem = self.loaded_ca.ocsp_key_backend_options["certificate"]["pem"] except KeyError: # The OCSP responder certificate has never been created. `manage.py init_ca` usually creates them, # so this can only happen if the system is misconfigured (e.g. Celery task is never acted upon), @@ -525,7 +524,7 @@ async def get_ocsp_response(self, builder: OCSPResponseBuilder) -> Union[HttpRes # TYPEHINT NOTE: Certificates are always generated with a supported algorithm, so we do not check. algorithm = cast(Optional[AllowedHashTypes], responder_certificate.signature_hash_algorithm) - return self.auto_ca.ocsp_key_backend.sign_ocsp_response(self.auto_ca, builder, algorithm) + return self.loaded_ca.ocsp_key_backend.sign_ocsp_response(self.loaded_ca, builder, algorithm) class GenericCAIssuersView(View):