Skip to content

Commit

Permalink
fix typehints and missing docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
mathiasertl committed Dec 29, 2024
1 parent 75bc303 commit e110880
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 22 deletions.
4 changes: 3 additions & 1 deletion ca/django_ca/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ def for_certificate_revocation_list(

def get_by_serial_or_cn(self, identifier: str) -> X509CertMixinTypeVar: ...

async def aget_by_serial_or_cn(self, identifier: str) -> X509CertMixinTypeVar: ...

def valid(self) -> QuerySetTypeVar: ...


Expand Down Expand Up @@ -883,7 +885,7 @@ def create_certificate_revocation_list(
)

# Create database object (as late as possible so any exception above would not hit the database)
obj: CertificateAuthority = self.create(
obj: CertificateRevocationList = self.create(
ca=ca,
number=Coalesce(models.Subquery(number_subquery, default=1), 0),
only_contains_ca_certs=only_contains_ca_certs,
Expand Down
5 changes: 3 additions & 2 deletions ca/django_ca/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ def get_revocation(self) -> x509.RevokedCertificate:
@contextmanager
def _revoke(
self, reason: ReasonFlags = ReasonFlags.unspecified, compromised: Optional[datetime] = None
) -> None:
) -> Iterator[None]:
pre_revoke_cert.send(sender=self.__class__, cert=self, reason=reason)

self.revoked = True
Expand Down Expand Up @@ -485,6 +485,7 @@ def revoke(
async def arevoke(
self, reason: ReasonFlags = ReasonFlags.unspecified, compromised: Optional[datetime] = None
) -> None:
"""Asynchronous version of ``revoke()``."""
with self._revoke(reason, compromised):
await self.asave()

Expand Down Expand Up @@ -1257,7 +1258,7 @@ def _cache_data(self) -> Iterator[tuple[str, bytes, int]]:

now = datetime.now(tz=tz.utc)
if self.loaded.next_update_utc is not None:
expires_seconds = (self.loaded.next_update_utc - now).total_seconds()
expires_seconds = int((self.loaded.next_update_utc - now).total_seconds())
else: # pragma: no cover # we never generate CRLs without a next_update flag.
expires_seconds = 86400

Expand Down
9 changes: 8 additions & 1 deletion ca/django_ca/querysets.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,23 @@ def filter(self: X509CertMixinQuerySetProtocol) -> X509CertMixinQuerySetProtocol

model: X509CertMixinTypeVar

async def aget(self, *args: Any, **kwargs: Any) -> X509CertMixinTypeVar: ...

def filter(self, *args: Any, **kwargs: Any) -> "Self": ...

def get(self, *args: Any, **kwargs: Any) -> X509CertMixinTypeVar: ...

def _serial_or_cn_query(self, identifier: str) -> tuple[Q, Q]: ...

def revoked(self) -> "Self": ...


class DjangoCAMixin(Generic[X509CertMixinTypeVar], metaclass=abc.ABCMeta):
"""Mixin with common methods for CertificateAuthority and Certificate models."""

def _serial_or_cn_query(self: X509CertMixinQuerySetProtocol[X509CertMixinTypeVar], identifier: str):
def _serial_or_cn_query(
self: X509CertMixinQuerySetProtocol[X509CertMixinTypeVar], identifier: str
) -> tuple[Q, Q]:
identifier = identifier.strip()
exact_query = startswith_query = Q(cn=identifier)

Expand Down Expand Up @@ -130,6 +136,7 @@ def get_by_serial_or_cn(
async def aget_by_serial_or_cn(
self: X509CertMixinQuerySetProtocol[X509CertMixinTypeVar], identifier: str
) -> X509CertMixinTypeVar:
"""Asynchronous version of :py:fucn:`~django_ca.querysets.DjangoCAMixin.get_by_serial_or_cn()`."""
exact_query, startswith_query = self._serial_or_cn_query(identifier)

try:
Expand Down
6 changes: 3 additions & 3 deletions ca/django_ca/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,11 +558,11 @@ def test_wrong_values(self) -> None:

for key_type in ("Ed448", "Ed25519"):
with pytest.raises(ValueError, match=rf"^Key size is not supported for {key_type} keys\.$"):
validate_private_key_parameters(key_type, key_size, None)
validate_private_key_parameters(key_type, key_size, None) # type: ignore[call-overload]
with pytest.raises(
ValueError, match=rf"^Elliptic curves are not supported for {key_type} keys\.$"
):
validate_private_key_parameters(key_type, None, elliptic_curve)
validate_private_key_parameters(key_type, None, elliptic_curve) # type: ignore[call-overload]


class ValidatePublicKeyParametersTest(TestCase):
Expand All @@ -572,7 +572,7 @@ def test_valid_parameters(self) -> None:
"""Test valid parameters."""
for key_type in ("RSA", "DSA", "EC"):
for algorithm in (hashes.SHA256(), hashes.SHA512()):
validate_public_key_parameters(key_type, algorithm)
validate_public_key_parameters(key_type, algorithm) # type: ignore[arg-type]
for key_type in ("Ed448", "Ed25519"):
validate_public_key_parameters(key_type, None) # type: ignore[arg-type]

Expand Down
29 changes: 14 additions & 15 deletions ca/django_ca/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ async def fetch_crl(

# CRL is not cached, try to retrieve it from the database.
if encoded_crl is None:
crl_qs: Optional[CertificateRevocationList] = (
crl_qs = (
CertificateRevocationList.objects.scope(
ca=ca,
only_contains_ca_certs=only_contains_ca_certs,
Expand All @@ -187,7 +187,7 @@ async def fetch_crl(
.filter(data__isnull=False) # only objects that have CRL data associated with it
.select_related("ca")
)
crl_obj = await crl_qs.anewest()
crl_obj: Optional[CertificateRevocationList] = await crl_qs.anewest()

# CRL was not found in the database either, so we try to regenerate it.
if crl_obj is None:
Expand All @@ -214,7 +214,7 @@ async def fetch_crl(

return encoded_crl

async def get(self, request: HttpRequest, serial: str) -> HttpResponse: # pylint: disable=unused-argument
async def get(self, request: HttpRequest, serial: str) -> HttpResponse:
# pylint: disable=missing-function-docstring; standard Django view function
if get_encoding := request.GET.get("encoding"):
if get_encoding not in CERTIFICATE_REVOCATION_LIST_ENCODING_TYPES:
Expand Down Expand Up @@ -275,6 +275,8 @@ class OCSPView(View):
ca_ocsp = False
"""If set to ``True``, validate child CAs instead."""

loaded_ca: CertificateAuthority

async def get(self, request: HttpRequest, data: str) -> HttpResponse:
# pylint: disable=missing-function-docstring; standard Django view function
try:
Expand Down Expand Up @@ -405,7 +407,7 @@ async def process_ocsp_request(self, data: bytes) -> HttpResponse:

# Get CA and certificate
try:
ca = await self.get_ca()
ca = self.loaded_ca = await self.get_ca()
except CertificateAuthority.DoesNotExist:
log.error("%s: Certificate Authority could not be found.", self.ca)
return self.fail()
Expand Down Expand Up @@ -474,34 +476,31 @@ class GenericOCSPView(OCSPView):
argument must be the serial for this CA.
"""

auto_ca: CertificateAuthority

# NOINSPECTION NOTE: It's okay to be more specific here
# noinspection PyMethodOverriding
async def dispatch(self, request: HttpRequest, serial: str, **kwargs: Any) -> "HttpResponseBase":
def dispatch(self, request: HttpRequest, serial: str, **kwargs: Any) -> "HttpResponseBase":
if request.method == "GET" and "data" not in kwargs:
return await self.http_method_not_allowed(request, serial, **kwargs)
return self.http_method_not_allowed(request, serial, **kwargs)
if request.method == "POST" and "data" in kwargs:
return await self.http_method_not_allowed(request, serial, **kwargs)
return self.http_method_not_allowed(request, serial, **kwargs)

# COVERAGE NOTE: Checking just for safety here.
if not isinstance(serial, str): # pragma: no cover
raise ImproperlyConfigured("View expects a str for a serial")

self.auto_ca = await CertificateAuthority.objects.aget(serial=serial)
return await super().dispatch(request, **kwargs)
return super().dispatch(request, **kwargs)

async def get_ca(self) -> CertificateAuthority:
return self.auto_ca
return await CertificateAuthority.objects.aget(serial=self.kwargs["serial"])

def get_expires(self, now: datetime) -> datetime:
return now + timedelta(seconds=self.auto_ca.ocsp_response_validity)
return now + timedelta(seconds=self.loaded_ca.ocsp_response_validity)

async def get_ocsp_response(self, builder: OCSPResponseBuilder) -> Union[HttpResponse, OCSPResponse]:
"""Sign the OCSP request using cryptography keys."""
# Load public key
try:
responder_pem = self.auto_ca.ocsp_key_backend_options["certificate"]["pem"]
responder_pem = self.loaded_ca.ocsp_key_backend_options["certificate"]["pem"]
except KeyError:
# The OCSP responder certificate has never been created. `manage.py init_ca` usually creates them,
# so this can only happen if the system is misconfigured (e.g. Celery task is never acted upon),
Expand All @@ -525,7 +524,7 @@ async def get_ocsp_response(self, builder: OCSPResponseBuilder) -> Union[HttpRes
# TYPEHINT NOTE: Certificates are always generated with a supported algorithm, so we do not check.
algorithm = cast(Optional[AllowedHashTypes], responder_certificate.signature_hash_algorithm)

return self.auto_ca.ocsp_key_backend.sign_ocsp_response(self.auto_ca, builder, algorithm)
return self.loaded_ca.ocsp_key_backend.sign_ocsp_response(self.loaded_ca, builder, algorithm)


class GenericCAIssuersView(View):
Expand Down

0 comments on commit e110880

Please sign in to comment.