diff --git a/ca/django_ca/tests/acme/views/base.py b/ca/django_ca/tests/acme/views/base.py
index bdeee2fe1..0574c602e 100644
--- a/ca/django_ca/tests/acme/views/base.py
+++ b/ca/django_ca/tests/acme/views/base.py
@@ -15,7 +15,6 @@
import abc
import typing
-from collections.abc import Iterator
from http import HTTPStatus
from typing import Optional, Union
from unittest import mock
@@ -255,7 +254,7 @@ class AcmeWithAccountViewTestCaseMixin(
"""Mixin that also adds accounts to the database."""
@pytest.fixture
- def main_account(self, account: AcmeAccount) -> Iterator[AcmeAccount]:
+ def main_account(self, account: AcmeAccount) -> AcmeAccount:
"""Return the main account to be used for this test case.
This is overwritten by the revocation test case.
diff --git a/ca/django_ca/tests/acme/views/conftest.py b/ca/django_ca/tests/acme/views/conftest.py
index 6c2bb947b..967fbaa47 100644
--- a/ca/django_ca/tests/acme/views/conftest.py
+++ b/ca/django_ca/tests/acme/views/conftest.py
@@ -15,8 +15,6 @@
# pylint: disable=redefined-outer-name
-from collections.abc import Iterator
-
from django.test import Client
import pytest
@@ -36,32 +34,32 @@
@pytest.fixture
-def account_slug() -> Iterator[str]:
+def account_slug() -> str:
"""Fixture for an account slug."""
return acme_slug()
@pytest.fixture
-def order_slug() -> Iterator[str]:
+def order_slug() -> str:
"""Fixture for an order slug."""
return acme_slug()
@pytest.fixture
-def acme_cert_slug() -> Iterator[str]:
+def acme_cert_slug() -> str:
"""Fixture for an ACME certificate slug."""
return acme_slug()
@pytest.fixture
-def client(client: Client) -> Iterator[Client]:
+def client(client: Client) -> Client:
"""Override client fixture to set the default server name."""
client.defaults["SERVER_NAME"] = SERVER_NAME
return client
@pytest.fixture
-def account(root: CertificateAuthority, account_slug: str, kid: str) -> Iterator[AcmeAccount]:
+def account(root: CertificateAuthority, account_slug: str, kid: str) -> AcmeAccount:
"""Fixture for an account."""
return AcmeAccount.objects.create(
ca=root,
@@ -75,25 +73,25 @@ def account(root: CertificateAuthority, account_slug: str, kid: str) -> Iterator
@pytest.fixture
-def kid(root: CertificateAuthority, account_slug: str) -> Iterator[str]:
+def kid(root: CertificateAuthority, account_slug: str) -> str:
"""Fixture for a full KID."""
return absolute_acme_uri(":acme-account", serial=root.serial, slug=account_slug)
@pytest.fixture
-def order(account: AcmeAccount, order_slug: str) -> Iterator[AcmeOrder]:
+def order(account: AcmeAccount, order_slug: str) -> AcmeOrder:
"""Fixture for an order."""
return AcmeOrder.objects.create(account=account, slug=order_slug)
@pytest.fixture
-def authz(order: AcmeOrder) -> Iterator[AcmeAuthorization]:
+def authz(order: AcmeOrder) -> AcmeAuthorization:
"""Fixture for an authorization."""
return AcmeAuthorization.objects.create(order=order, value=HOST_NAME)
@pytest.fixture
-def challenge(authz: AcmeAuthorization) -> Iterator[AcmeChallenge]:
+def challenge(authz: AcmeAuthorization) -> AcmeChallenge:
"""Fixture for a challenge."""
challenge = authz.get_challenges()[0]
challenge.token = "foobar"
@@ -102,6 +100,6 @@ def challenge(authz: AcmeAuthorization) -> Iterator[AcmeChallenge]:
@pytest.fixture
-def acme_cert(root_cert: Certificate, order: AcmeOrder, acme_cert_slug: str) -> Iterator[AcmeCertificate]:
+def acme_cert(root_cert: Certificate, order: AcmeOrder, acme_cert_slug: str) -> AcmeCertificate:
"""Fixture for an ACME certificate."""
return AcmeCertificate.objects.create(order=order, cert=root_cert, slug=acme_cert_slug)
diff --git a/ca/django_ca/tests/acme/views/test_authorization.py b/ca/django_ca/tests/acme/views/test_authorization.py
index dc657e270..1e1912c51 100644
--- a/ca/django_ca/tests/acme/views/test_authorization.py
+++ b/ca/django_ca/tests/acme/views/test_authorization.py
@@ -15,7 +15,6 @@
# pylint: disable=redefined-outer-name # because of fixtures
-from collections.abc import Iterator
from http import HTTPStatus
import josepy as jose
@@ -41,13 +40,13 @@
@pytest.fixture
-def url(authz: AcmeAuthorization) -> Iterator[str]:
+def url(authz: AcmeAuthorization) -> str:
"""URL under test."""
return root_reverse("acme-authz", slug=authz.slug)
@pytest.fixture
-def message() -> Iterator[bytes]:
+def message() -> bytes:
"""Yield an empty bytestring, since this is a POST-AS-GET request."""
return b""
diff --git a/ca/django_ca/tests/acme/views/test_challenge.py b/ca/django_ca/tests/acme/views/test_challenge.py
index 504113207..283fd87c1 100644
--- a/ca/django_ca/tests/acme/views/test_challenge.py
+++ b/ca/django_ca/tests/acme/views/test_challenge.py
@@ -16,7 +16,6 @@
# pylint: disable=redefined-outer-name # because of fixtures
import unittest
-from collections.abc import Iterator
from http import HTTPStatus
from typing import Optional
from unittest import mock
@@ -42,13 +41,13 @@
@pytest.fixture
-def url(challenge: AcmeChallenge) -> Iterator[str]:
+def url(challenge: AcmeChallenge) -> str:
"""URL under test."""
return root_reverse("acme-challenge", slug=challenge.slug)
@pytest.fixture
-def message() -> Iterator[bytes]:
+def message() -> bytes:
"""Yield an empty bytestring, since this is a POST-AS-GET request."""
return b""
diff --git a/ca/django_ca/tests/acme/views/test_new_account.py b/ca/django_ca/tests/acme/views/test_new_account.py
index d29b114bc..c06f3420e 100644
--- a/ca/django_ca/tests/acme/views/test_new_account.py
+++ b/ca/django_ca/tests/acme/views/test_new_account.py
@@ -15,7 +15,6 @@
# pylint: disable=redefined-outer-name # because of fixtures
-from collections.abc import Iterator
from http import HTTPStatus
from unittest import mock
@@ -53,19 +52,19 @@
@pytest.fixture
-def url() -> Iterator[str]:
+def url() -> str:
"""URL under test."""
return root_reverse("acme-new-account")
@pytest.fixture
-def message() -> Iterator[Registration]:
+def message() -> Registration:
"""Default message sent to the server."""
return Registration(contact=(CONTACT,), terms_of_service_agreed=True)
@pytest.fixture
-def kid() -> Iterator[None]:
+def kid() -> None:
"""Request requires no kid, yield None."""
return
@@ -228,7 +227,7 @@ def test_unsupported_contact(client: Client, url: str, root: CertificateAuthorit
@pytest.mark.parametrize(
- "value,expected",
+ ("value", "expected"),
(
('mailto:"with spaces"@example.com', "Quoted local part in email is not allowed."),
("mailto:user@example.com,user@example.net", "More than one addr-spec is not allowed."),
diff --git a/ca/django_ca/tests/acme/views/test_new_order.py b/ca/django_ca/tests/acme/views/test_new_order.py
index 984d30bcf..986ab4c49 100644
--- a/ca/django_ca/tests/acme/views/test_new_order.py
+++ b/ca/django_ca/tests/acme/views/test_new_order.py
@@ -15,7 +15,6 @@
# pylint: disable=redefined-outer-name # because of fixtures
-from collections.abc import Iterator
from datetime import timedelta, timezone as tz
from http import HTTPStatus
from typing import Any
@@ -46,13 +45,13 @@
@pytest.fixture
-def url() -> Iterator[str]:
+def url() -> str:
"""URL under test."""
return root_reverse("acme-new-order")
@pytest.fixture
-def message() -> Iterator[NewOrder]:
+def message() -> NewOrder:
"""Default message sent to the server."""
return NewOrder(identifiers=[{"type": "dns", "value": SERVER_NAME}])
@@ -187,7 +186,7 @@ def test_no_identifiers(client: Client, url: str, root: CertificateAuthority, ki
@pytest.mark.usefixtures("account")
@pytest.mark.parametrize(
- "values,expected",
+ ("values", "expected"),
(
({"not_before": now - timedelta(days=1)}, "Certificate cannot be valid before now."),
({"not_after": now + timedelta(days=3650)}, "Certificate cannot be valid that long."),
@@ -201,7 +200,6 @@ def test_invalid_not_before_after(
client: Client, url: str, root: CertificateAuthority, kid: str, values: dict[str, Any], expected: str
) -> None:
"""Test invalid not_before/not_after dates."""
- print(values)
message = NewOrder(identifiers=[{"type": "dns", "value": SERVER_NAME}], **values)
resp = acme_request(client, url, root, message, kid=kid)
assert_malformed(resp, root, expected)
diff --git a/ca/django_ca/tests/acme/views/test_order.py b/ca/django_ca/tests/acme/views/test_order.py
index c113c582b..3eafd57fd 100644
--- a/ca/django_ca/tests/acme/views/test_order.py
+++ b/ca/django_ca/tests/acme/views/test_order.py
@@ -15,7 +15,6 @@
# pylint: disable=redefined-outer-name # because of fixtures
-from collections.abc import Iterator
from http import HTTPStatus
from typing import Optional
from unittest import mock
@@ -51,13 +50,13 @@
@pytest.fixture
-def url(order: AcmeOrder) -> Iterator[str]:
+def url(order: AcmeOrder) -> str:
"""URL under test."""
return root_reverse("acme-order", slug=order.slug)
@pytest.fixture
-def message() -> Iterator[bytes]:
+def message() -> bytes:
"""Yield an empty bytestring, since this is a POST-AS-GET request."""
return b""
diff --git a/ca/django_ca/tests/acme/views/test_order_finalize.py b/ca/django_ca/tests/acme/views/test_order_finalize.py
index 0fc30e9f0..78d929e00 100644
--- a/ca/django_ca/tests/acme/views/test_order_finalize.py
+++ b/ca/django_ca/tests/acme/views/test_order_finalize.py
@@ -15,7 +15,6 @@
# pylint: disable=redefined-outer-name
-from collections.abc import Iterator
from http import HTTPStatus
from typing import Optional
from unittest import mock
@@ -64,7 +63,7 @@
@pytest.fixture
-def order(order: AcmeOrder) -> Iterator[AcmeOrder]:
+def order(order: AcmeOrder) -> AcmeOrder:
"""Override the module-level fixture to set the status to ready."""
order.status = AcmeOrder.STATUS_READY
order.save()
@@ -72,7 +71,7 @@ def order(order: AcmeOrder) -> Iterator[AcmeOrder]:
@pytest.fixture
-def authz(authz: AcmeAuthorization) -> Iterator[AcmeAuthorization]:
+def authz(authz: AcmeAuthorization) -> AcmeAuthorization:
"""Override the module-level fixture to set the status to valid."""
authz.status = AcmeAuthorization.STATUS_VALID
authz.save()
@@ -80,13 +79,13 @@ def authz(authz: AcmeAuthorization) -> Iterator[AcmeAuthorization]:
@pytest.fixture
-def url(order: AcmeOrder) -> Iterator[str]:
+def url(order: AcmeOrder) -> str:
"""URL under test."""
return root_reverse("acme-order-finalize", slug=order.slug)
@pytest.fixture
-def message() -> Iterator[CertificateRequest]:
+def message() -> CertificateRequest:
"""Default message sent to the server."""
req = X509Req.from_cryptography(CSR)
return CertificateRequest(csr=jose.util.ComparableX509(req))
diff --git a/ca/django_ca/tests/acme/views/test_revocation.py b/ca/django_ca/tests/acme/views/test_revocation.py
index 2945db3a6..172757d6a 100644
--- a/ca/django_ca/tests/acme/views/test_revocation.py
+++ b/ca/django_ca/tests/acme/views/test_revocation.py
@@ -16,7 +16,6 @@
# pylint: disable=redefined-outer-name # because of fixtures
import unittest
-from collections.abc import Iterator
from datetime import datetime
from http import HTTPStatus
from typing import Any, Optional, Union
@@ -58,13 +57,13 @@
@pytest.fixture
-def url() -> Iterator[str]:
+def url() -> str:
"""URL under test."""
return root_reverse("acme-revoke")
@pytest.fixture
-def message() -> Iterator[Revocation]:
+def message() -> Revocation:
"""Default message sent to the server."""
default_certificate = CERT_DATA["root-cert"]["pub"]["parsed"]
return Revocation(certificate=jose.util.ComparableX509(X509.from_cryptography(default_certificate)))
@@ -107,7 +106,7 @@ def acme(
return acme_request(client, url, ca, message, kid=kid)
@pytest.mark.parametrize(
- "use_tz, timestamp",
+ ("use_tz", "timestamp"),
((True, TIMESTAMPS["everything_valid"]), (False, TIMESTAMPS["everything_valid_naive"])),
)
def test_basic(
@@ -218,7 +217,7 @@ class TestAcmeCertificateRevocationWithAuthorizationsView(TestAcmeCertificateRev
CHILD_SLUG = acme_slug()
@pytest.fixture
- def child_kid_fixture(self, root: CertificateAuthority) -> Iterator[str]:
+ def child_kid_fixture(self, root: CertificateAuthority) -> str:
"""Fixture to set compute the child KID."""
return self.absolute_uri(":acme-account", serial=root.serial, slug=self.CHILD_SLUG)
@@ -287,7 +286,7 @@ def test_wrong_url(self) -> None: # type: ignore[override]
pass
@pytest.fixture
- def kid(self, child_kid_fixture: str) -> Iterator[Optional[str]]:
+ def kid(self, child_kid_fixture: str) -> Optional[str]:
"""Override kid to return the child kid."""
return child_kid_fixture
diff --git a/ca/django_ca/tests/acme/views/test_update_account.py b/ca/django_ca/tests/acme/views/test_update_account.py
index f7d649145..1a854299f 100644
--- a/ca/django_ca/tests/acme/views/test_update_account.py
+++ b/ca/django_ca/tests/acme/views/test_update_account.py
@@ -16,7 +16,6 @@
# pylint: disable=redefined-outer-name # because of fixtures
import unittest
-from collections.abc import Iterator
from http import HTTPStatus
from acme.messages import IDENTIFIER_FQDN, Identifier, Registration
@@ -39,13 +38,13 @@
@pytest.fixture
-def url(account_slug: str) -> Iterator[str]:
+def url(account_slug: str) -> str:
"""URL under test."""
return root_reverse("acme-account", slug=account_slug)
@pytest.fixture
-def message() -> Iterator[Registration]:
+def message() -> Registration:
"""Default message sent to the server."""
return Registration()
diff --git a/ca/django_ca/tests/acme/views/test_view_cert.py b/ca/django_ca/tests/acme/views/test_view_cert.py
index 6138418a5..70cea026b 100644
--- a/ca/django_ca/tests/acme/views/test_view_cert.py
+++ b/ca/django_ca/tests/acme/views/test_view_cert.py
@@ -15,7 +15,6 @@
# pylint: disable=redefined-outer-name # for to fixtures
-from collections.abc import Iterator
from http import HTTPStatus
from typing import Optional
@@ -37,7 +36,7 @@
@pytest.fixture
-def order(order: AcmeOrder) -> Iterator[AcmeOrder]:
+def order(order: AcmeOrder) -> AcmeOrder:
"""Override to set status to valid."""
order.status = AcmeOrder.STATUS_VALID
order.save()
@@ -45,13 +44,13 @@ def order(order: AcmeOrder) -> Iterator[AcmeOrder]:
@pytest.fixture
-def url(acme_cert_slug: str) -> Iterator[str]:
+def url(acme_cert_slug: str) -> str:
"""URL under test."""
return root_reverse("acme-cert", slug=acme_cert_slug)
@pytest.fixture
-def message() -> Iterator[bytes]:
+def message() -> bytes:
"""Yield an empty bytestring, since this is a POST-AS-GET request."""
return b""
diff --git a/ca/django_ca/tests/admin/base.py b/ca/django_ca/tests/admin/base.py
index afe5b7b06..893157f7d 100644
--- a/ca/django_ca/tests/admin/base.py
+++ b/ca/django_ca/tests/admin/base.py
@@ -60,16 +60,16 @@ def setUp(self) -> None:
def assertModified(self) -> None: # pylint: disable=invalid-name
"""Assert that the field was modified."""
- self.assertEqual(self.key_value_field.get_attribute("data-modified"), "true")
+ assert self.key_value_field.get_attribute("data-modified") == "true"
def assertNotModified(self) -> None: # pylint: disable=invalid-name
"""Assert that the field was not modified."""
- self.assertNotEqual(self.key_value_field.get_attribute("data-modified"), "true")
+ assert self.key_value_field.get_attribute("data-modified") != "true"
def assertChapterHasValue(self, chapter: WebElement, value: Any) -> None: # pylint: disable=invalid-name
"""Assert that the given chapter has the given value."""
loaded_value = json.loads(chapter.get_attribute("data-value")) # type: ignore[arg-type]
- self.assertEqual(loaded_value, value)
+ assert loaded_value == value
def initialize(self) -> None:
"""Load the page and find core elements.
@@ -94,7 +94,7 @@ def displayed_value(self) -> list[dict[str, str]]:
"""Load the currently displayed value from the key/value list."""
selects = self.key_value_list.find_elements(By.CSS_SELECTOR, "select")
inputs = self.key_value_list.find_elements(By.CSS_SELECTOR, "input")
- self.assertEqual(len(selects), len(inputs))
+ assert len(selects) == len(inputs)
return [
{
diff --git a/ca/django_ca/tests/admin/conftest.py b/ca/django_ca/tests/admin/conftest.py
index d5ab1d0e9..8b91e1974 100644
--- a/ca/django_ca/tests/admin/conftest.py
+++ b/ca/django_ca/tests/admin/conftest.py
@@ -13,8 +13,6 @@
"""Extra fixtures for tests for the admin interface."""
-from collections.abc import Iterator
-
from django.test import Client
from django.urls import reverse
@@ -25,13 +23,13 @@
@pytest.fixture(params=["name_to_rfc4514"])
-def extra_view_url(request: "SubRequest") -> Iterator[str]:
+def extra_view_url(request: "SubRequest") -> str:
"""Parametrized fixture providing reversed extra view URLs."""
return reverse(f"admin:django_ca_certificate_{request.param}")
@pytest.fixture
-def staff_client(user: "User", user_client: Client) -> Iterator[Client]:
+def staff_client(user: "User", user_client: Client) -> Client:
"""Client with a staff user with no extra permissions."""
user.is_staff = True
user.save()
diff --git a/ca/django_ca/tests/admin/test_actions.py b/ca/django_ca/tests/admin/test_actions.py
index 00dcdec8a..c5582c50e 100644
--- a/ca/django_ca/tests/admin/test_actions.py
+++ b/ca/django_ca/tests/admin/test_actions.py
@@ -39,7 +39,7 @@
from django_ca.models import Certificate, X509CertMixin
from django_ca.pydantic.general_name import GeneralNameModelList
from django_ca.signals import post_issue_cert, post_revoke_cert, pre_revoke_cert, pre_sign_cert
-from django_ca.tests.base.assertions import assert_revoked
+from django_ca.tests.base.assertions import assert_extension_equal, assert_revoked
from django_ca.tests.base.constants import TIMESTAMPS
from django_ca.tests.base.mixins import AdminTestCaseMixin
from django_ca.tests.base.mocks import mock_signal
@@ -80,7 +80,7 @@ def test_user_is_staff_only(self) -> None:
for obj in self.get_objects():
response = self.client.post(self.changelist_url, self.data)
- self.assertEqual(response.status_code, HTTPStatus.FORBIDDEN)
+ assert response.status_code == HTTPStatus.FORBIDDEN
self.assertFailedRequest(response, obj)
def test_insufficient_permissions(self) -> None:
@@ -117,7 +117,7 @@ def test_insufficient_permissions(self) -> None:
for obj in self.get_objects():
response = self.client.post(self.changelist_url, self.data)
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert response.status_code == HTTPStatus.OK
self.assertFailedRequest(response, obj)
def test_required_permissions(self) -> None:
@@ -166,7 +166,7 @@ def assertForbidden( # pylint: disable=invalid-name
self, response: "HttpResponse", obj: Optional[DjangoCAModelTypeVar] = None
) -> None:
"""Assert that the action returned HTTP 403 (Forbidden)."""
- self.assertEqual(response.status_code, HTTPStatus.FORBIDDEN)
+ assert response.status_code == HTTPStatus.FORBIDDEN
self.assertFailedRequest(response, obj=obj)
@contextmanager
@@ -206,7 +206,7 @@ def test_get(self) -> None:
for obj in self.get_objects():
with self.assertNoSignals():
response = self.client.get(self.get_url(obj=obj))
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert response.status_code == HTTPStatus.OK
def test_anonymous(self) -> None:
"""Test performing action as anonymous user."""
@@ -303,9 +303,9 @@ def assertFormValidationError( # pylint: disable=invalid-name
) -> None:
"""Assert that the form validation failed with the given errors."""
self.assertNotRevoked(cert)
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert response.status_code == HTTPStatus.OK
self.assertTemplateUsed("admin/django_ca/certificate/revoke_form.html")
- self.assertEqual(response.context["form"].errors, errors)
+ assert response.context["form"].errors == errors
def assertSuccessfulRequest(
self,
@@ -399,7 +399,7 @@ class ResignChangeActionTestCase(AdminChangeActionTestCaseMixin[Certificate], We
def assertFailedRequest(self, response: "HttpResponse", obj: Optional[Certificate] = None) -> None:
obj = obj or self.cert
- self.assertEqual(self.model.objects.filter(cn=obj.cn).count(), 1)
+ assert self.model.objects.filter(cn=obj.cn).count() == 1
def assertSuccessfulRequest(
self,
@@ -410,13 +410,13 @@ def assertSuccessfulRequest(
obj.refresh_from_db()
resigned = Certificate.objects.filter(cn=obj.cn).exclude(pk=obj.pk).get()
- self.assertFalse(resigned.revoked)
- self.assertFalse(obj.revoked)
- self.assertEqual(obj.cn, resigned.cn)
- self.assertEqual(obj.csr, resigned.csr)
- self.assertEqual(obj.profile, resigned.profile)
- self.assertEqual(obj.cn, resigned.cn)
- self.assertEqual(obj.algorithm, resigned.algorithm)
+ assert not resigned.revoked
+ assert not obj.revoked
+ assert obj.cn == resigned.cn
+ assert obj.csr == resigned.csr
+ assert obj.profile == resigned.profile
+ assert obj.cn == resigned.cn
+ assert obj.algorithm == resigned.algorithm
for oid in [
ExtensionOID.EXTENDED_KEY_USAGE,
@@ -424,11 +424,11 @@ def assertSuccessfulRequest(
ExtensionOID.KEY_USAGE,
ExtensionOID.SUBJECT_ALTERNATIVE_NAME,
]:
- self.assertEqual(obj.extensions.get(oid), resigned.extensions.get(oid))
+ assert_extension_equal(obj.extensions.get(oid), resigned.extensions.get(oid))
# Some properties are obviously *not* equal
- self.assertNotEqual(obj.pub, resigned.pub)
- self.assertNotEqual(obj.serial, resigned.serial)
+ assert obj.pub != resigned.pub
+ assert obj.serial != resigned.serial
@property
def data(self) -> dict[str, Any]: # type: ignore[override]
@@ -488,7 +488,7 @@ def test_no_profile(self) -> None:
form.submit().follow()
resigned = Certificate.objects.filter(cn=self.cert.cn).exclude(pk=self.cert.pk).get()
- self.assertEqual(resigned.profile, model_settings.CA_DEFAULT_PROFILE)
+ assert resigned.profile == model_settings.CA_DEFAULT_PROFILE
@override_tmpcadir()
def test_webtest_basic(self) -> None:
diff --git a/ca/django_ca/tests/admin/test_add_cert.py b/ca/django_ca/tests/admin/test_add_cert.py
index eb0d0d53e..3c1584469 100644
--- a/ca/django_ca/tests/admin/test_add_cert.py
+++ b/ca/django_ca/tests/admin/test_add_cert.py
@@ -55,6 +55,7 @@
from django_ca.tests.admin.base import AddCertificateSeleniumTestCase, CertificateModelAdminTestCaseMixin
from django_ca.tests.base.assertions import (
assert_authority_key_identifier,
+ assert_count_equal,
assert_create_cert_signals,
assert_extensions,
assert_post_issue_cert,
@@ -174,16 +175,13 @@ def add_cert(self, cname: str, ca: CertificateAuthority, algorithm: str = "SHA-2
cert = Certificate.objects.get(cn=cname)
assert_post_issue_cert(post, cert)
- self.assertEqual(
- cert.pub.loaded.subject,
- x509.Name(
- [
- x509.NameAttribute(oid=NameOID.COUNTRY_NAME, value="US"),
- x509.NameAttribute(oid=NameOID.COMMON_NAME, value=cname),
- ]
- ),
+ assert cert.pub.loaded.subject == x509.Name(
+ [
+ x509.NameAttribute(oid=NameOID.COUNTRY_NAME, value="US"),
+ x509.NameAttribute(oid=NameOID.COMMON_NAME, value=cname),
+ ]
)
- self.assertIssuer(ca, cert)
+ assert cert.issuer == ca.subject
assert_extensions(
cert,
[
@@ -203,24 +201,24 @@ def add_cert(self, cname: str, ca: CertificateAuthority, algorithm: str = "SHA-2
),
],
)
- self.assertEqual(cert.ca, ca)
- self.assertEqual(cert.csr.pem, CSR)
- self.assertEqual(cert.profile, "webserver")
+ assert cert.ca == ca
+ assert cert.csr.pem == CSR
+ assert cert.profile == "webserver"
# Some extensions are NOT set
- self.assertNotIn(ExtensionOID.ISSUER_ALTERNATIVE_NAME, cert.extensions)
+ assert ExtensionOID.ISSUER_ALTERNATIVE_NAME not in cert.extensions
# Test that we can view the certificate
response = self.client.get(cert.admin_change_url)
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert response.status_code == HTTPStatus.OK
def _test_get(self) -> "HttpResponse":
"""Do a basic get request (to test CSS etc)."""
response = self.client.get(self.add_url)
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert response.status_code == HTTPStatus.OK
templates = [t.name for t in response.templates]
- self.assertIn("admin/django_ca/certificate/change_form.html", templates)
- self.assertIn("admin/change_form.html", templates)
+ assert "admin/django_ca/certificate/change_form.html" in templates
+ assert "admin/change_form.html" in templates
assert_css(response, "django_ca/admin/css/base.css")
assert_css(response, "django_ca/admin/css/certificateadmin.css")
return response
@@ -247,14 +245,14 @@ def test_default_ca_key_does_not_exist(self) -> None:
"""View add form when the ca key does not exist."""
storages["django-ca"].delete(self.ca.key_backend_options["path"])
response = self.client.get(self.add_url)
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert response.status_code == HTTPStatus.OK
form = response.context_data["adminform"].form # type: ignore[attr-defined] # false positive
field = form.fields["ca"]
bound_field = field.get_bound_field(form, "ca")
- self.assertNotEqual(bound_field.initial, self.ca)
- self.assertIsInstance(bound_field.initial, CertificateAuthority)
+ assert bound_field.initial != self.ca
+ assert isinstance(bound_field.initial, CertificateAuthority)
@override_tmpcadir(CA_DEFAULT_CA=CERT_DATA["child"]["serial"])
def test_cas_expired(self) -> None:
@@ -284,7 +282,7 @@ def test_get_profiles(self) -> None:
field = form.fields["ocsp_no_check"]
bound_field = field.get_bound_field(form, "ocsp_no_check")
- self.assertEqual(bound_field.initial, ocsp_no_check(critical=True))
+ assert bound_field.initial == ocsp_no_check(critical=True)
@override_tmpcadir(CA_DEFAULT_SUBJECT=tuple())
def test_add(self) -> None:
@@ -306,10 +304,9 @@ def test_empty_subject(self) -> None:
cert: Certificate = Certificate.objects.get(cn="")
assert_post_issue_cert(post, cert)
- self.assertEqual(cert.subject, x509.Name([]))
- self.assertEqual(
- cert.extensions[ExtensionOID.SUBJECT_ALTERNATIVE_NAME],
- subject_alternative_name(dns(self.hostname)),
+ assert cert.subject == x509.Name([])
+ assert cert.extensions[ExtensionOID.SUBJECT_ALTERNATIVE_NAME] == subject_alternative_name(
+ dns(self.hostname)
)
@override_tmpcadir(CA_DEFAULT_SUBJECT=tuple())
@@ -335,16 +332,13 @@ def test_subject_with_multiple_org_units(self) -> None:
cert: Certificate = Certificate.objects.get(cn=self.hostname)
assert_post_issue_cert(post, cert)
- self.assertEqual(
- cert.subject,
- x509.Name(
- [
- x509.NameAttribute(oid=NameOID.COUNTRY_NAME, value="US"),
- x509.NameAttribute(oid=NameOID.ORGANIZATIONAL_UNIT_NAME, value="OU-1"),
- x509.NameAttribute(oid=NameOID.ORGANIZATIONAL_UNIT_NAME, value="OU-2"),
- x509.NameAttribute(oid=NameOID.COMMON_NAME, value=self.hostname),
- ]
- ),
+ assert cert.subject == x509.Name(
+ [
+ x509.NameAttribute(oid=NameOID.COUNTRY_NAME, value="US"),
+ x509.NameAttribute(oid=NameOID.ORGANIZATIONAL_UNIT_NAME, value="OU-1"),
+ x509.NameAttribute(oid=NameOID.ORGANIZATIONAL_UNIT_NAME, value="OU-2"),
+ x509.NameAttribute(oid=NameOID.COMMON_NAME, value=self.hostname),
+ ]
)
@override_tmpcadir(CA_DEFAULT_SUBJECT=tuple())
@@ -363,17 +357,14 @@ def test_add_no_common_name_and_no_subject_alternative_name(self) -> None:
"subject_alternative_name_1": True,
},
)
- self.assertEqual(response.status_code, HTTPStatus.OK)
- self.assertFalse(response.context["adminform"].form.is_valid())
- self.assertEqual(
- response.context["adminform"].form.errors,
- {
- "subject_alternative_name": [
- "Subject Alternative Name is required if the subject does not contain a Common Name."
- ]
- },
- )
- self.assertEqual(cert_count, Certificate.objects.all().count())
+ assert response.status_code == HTTPStatus.OK
+ assert not response.context["adminform"].form.is_valid()
+ assert response.context["adminform"].form.errors == {
+ "subject_alternative_name": [
+ "Subject Alternative Name is required if the subject does not contain a Common Name."
+ ]
+ }
+ assert cert_count == Certificate.objects.all().count()
@override_tmpcadir(CA_DEFAULT_SUBJECT=tuple())
def test_subject_with_multiple_country_codes(self) -> None:
@@ -394,10 +385,10 @@ def test_subject_with_multiple_country_codes(self) -> None:
),
},
)
- self.assertFalse(response.context["adminform"].form.is_valid())
+ assert not response.context["adminform"].form.is_valid()
msg = "Value error, attribute of type countryName must not occur more then once in a name."
- self.assertEqual(response.context["adminform"].form.errors, {"subject": [msg]})
+ assert response.context["adminform"].form.errors == {"subject": [msg]}
@override_tmpcadir(CA_DEFAULT_SUBJECT=tuple())
def test_subject_with_invalid_country_code(self) -> None:
@@ -417,12 +408,11 @@ def test_subject_with_invalid_country_code(self) -> None:
),
},
)
- self.assertEqual(response.status_code, HTTPStatus.OK)
- self.assertFalse(response.context["adminform"].form.is_valid())
- self.assertEqual(
- response.context["adminform"].form.errors,
- {"subject": ["Value error, FOO: Must have exactly two characters"]},
- )
+ assert response.status_code == HTTPStatus.OK
+ assert not response.context["adminform"].form.is_valid()
+ assert response.context["adminform"].form.errors == {
+ "subject": ["Value error, FOO: Must have exactly two characters"]
+ }
@override_tmpcadir(CA_DEFAULT_SUBJECT=tuple())
def test_add_no_key_usage(self) -> None:
@@ -437,11 +427,11 @@ def test_add_no_key_usage(self) -> None:
self.assertRedirects(response, self.changelist_url)
cert = Certificate.objects.get(cn=self.hostname)
- self.assertNotIn(ExtensionOID.KEY_USAGE, cert.extensions) # KeyUsage is not set!
+ assert ExtensionOID.KEY_USAGE not in cert.extensions # KeyUsage is not set!
# Test that we can view the certificate
response = self.client.get(cert.admin_change_url)
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert response.status_code == HTTPStatus.OK
@override_tmpcadir(CA_DEFAULT_SUBJECT=tuple(), CA_PASSWORDS={})
def test_add_with_password(self) -> None:
@@ -531,14 +521,13 @@ def test_invalid_csr(self) -> None:
with assert_create_cert_signals(False, False):
response = self.client.post(self.add_url, data=self.form_data("whatever", ca))
- self.assertEqual(response.status_code, HTTPStatus.OK)
- self.assertFalse(response.context["adminform"].form.is_valid())
- self.assertEqual(
- response.context["adminform"].form.errors,
- {"csr": [CertificateSigningRequestField.simple_validation_error]},
- )
+ assert response.status_code == HTTPStatus.OK
+ assert not response.context["adminform"].form.is_valid()
+ assert response.context["adminform"].form.errors == {
+ "csr": [CertificateSigningRequestField.simple_validation_error]
+ }
- with self.assertRaises(Certificate.DoesNotExist):
+ with pytest.raises(Certificate.DoesNotExist):
Certificate.objects.get(cn=self.hostname)
@override_tmpcadir()
@@ -557,16 +546,16 @@ def test_unparsable_csr(self) -> None:
"-----BEGIN CERTIFICATE REQUEST-----\nwrong-----END CERTIFICATE REQUEST-----", ca
),
)
- self.assertEqual(response.status_code, HTTPStatus.OK, response.content)
- self.assertFalse(response.context["adminform"].form.is_valid())
+ assert response.status_code == HTTPStatus.OK, response.content
+ assert not response.context["adminform"].form.is_valid()
# Not testing exact error message here, as it the one from cryptography. Instead, just check that
# there is exactly one message for the "csr" field.
form = response.context["adminform"].form
- self.assertEqual(len(form.errors), 1, form.errors)
- self.assertEqual(len(form.errors["csr"]), 1, form.errors["csr"])
+ assert len(form.errors) == 1, form.errors
+ assert len(form.errors["csr"]) == 1, form.errors["csr"]
- with self.assertRaises(Certificate.DoesNotExist):
+ with pytest.raises(Certificate.DoesNotExist):
Certificate.objects.get(cn=self.hostname)
@override_tmpcadir()
@@ -579,15 +568,14 @@ def test_not_after_in_the_past(self) -> None:
response = self.client.post(
self.add_url, data={**self.form_data(CSR, ca), "not_after": expires.strftime("%Y-%m-%d")}
)
- self.assertEqual(response.status_code, HTTPStatus.OK)
- self.assertIn("Certificate cannot expire in the past.", response.content.decode("utf-8"))
- self.assertFalse(response.context["adminform"].form.is_valid())
- self.assertEqual(
- response.context["adminform"].form.errors,
- {"not_after": ["Certificate cannot expire in the past."]},
- )
+ assert response.status_code == HTTPStatus.OK
+ assert "Certificate cannot expire in the past." in response.content.decode("utf-8")
+ assert not response.context["adminform"].form.is_valid()
+ assert response.context["adminform"].form.errors == {
+ "not_after": ["Certificate cannot expire in the past."]
+ }
- with self.assertRaises(Certificate.DoesNotExist):
+ with pytest.raises(Certificate.DoesNotExist):
Certificate.objects.get(cn=self.hostname)
@override_tmpcadir()
@@ -602,12 +590,12 @@ def test_expires_too_late(self) -> None:
response = self.client.post(
self.add_url, data={**self.form_data(CSR, ca), "not_after": expires.strftime("%Y-%m-%d")}
)
- self.assertEqual(response.status_code, HTTPStatus.OK)
- self.assertIn(error, response.content.decode("utf-8"))
- self.assertFalse(response.context["adminform"].form.is_valid())
- self.assertEqual(response.context["adminform"].form.errors, {"not_after": [error]})
+ assert response.status_code == HTTPStatus.OK
+ assert error in response.content.decode("utf-8")
+ assert not response.context["adminform"].form.is_valid()
+ assert response.context["adminform"].form.errors == {"not_after": [error]}
- with self.assertRaises(Certificate.DoesNotExist):
+ with pytest.raises(Certificate.DoesNotExist):
Certificate.objects.get(cn=self.hostname)
@override_tmpcadir()
@@ -629,11 +617,10 @@ def test_invalid_signature_hash_algorithm(self) -> None:
"not_after": self.default_expires,
},
)
- self.assertFalse(response.context["adminform"].form.is_valid(), response)
- self.assertEqual(
- response.context["adminform"].form.errors,
- {"algorithm": ["Ed448-based certificate authorities do not use a signature hash algorithm."]},
- )
+ assert not response.context["adminform"].form.is_valid(), response
+ assert response.context["adminform"].form.errors == {
+ "algorithm": ["Ed448-based certificate authorities do not use a signature hash algorithm."]
+ }
# Test with Ed25519 CA
csr = CERT_DATA["ed25519-cert"]["csr"]["parsed"].public_bytes(Encoding.PEM).decode("utf-8")
@@ -651,11 +638,10 @@ def test_invalid_signature_hash_algorithm(self) -> None:
"not_after": self.default_expires,
},
)
- self.assertFalse(response.context["adminform"].form.is_valid(), response)
- self.assertEqual(
- response.context["adminform"].form.errors,
- {"algorithm": ["Ed25519-based certificate authorities do not use a signature hash algorithm."]},
- )
+ assert not response.context["adminform"].form.is_valid(), response
+ assert response.context["adminform"].form.errors == {
+ "algorithm": ["Ed25519-based certificate authorities do not use a signature hash algorithm."]
+ }
# Test with DSA CA
csr = CERT_DATA["dsa-cert"]["csr"]["parsed"].public_bytes(Encoding.PEM).decode("utf-8")
@@ -673,11 +659,10 @@ def test_invalid_signature_hash_algorithm(self) -> None:
"not_after": self.default_expires,
},
)
- self.assertFalse(response.context["adminform"].form.is_valid(), response)
- self.assertEqual(
- response.context["adminform"].form.errors,
- {"algorithm": ["DSA-based certificate authorities require a SHA-256 signature hash algorithm."]},
- )
+ assert not response.context["adminform"].form.is_valid(), response
+ assert response.context["adminform"].form.errors == {
+ "algorithm": ["DSA-based certificate authorities require a SHA-256 signature hash algorithm."]
+ }
# Test with RSA CA
with assert_create_cert_signals(False, False):
@@ -694,11 +679,10 @@ def test_invalid_signature_hash_algorithm(self) -> None:
"not_after": self.default_expires,
},
)
- self.assertFalse(response.context["adminform"].form.is_valid(), response)
- self.assertEqual(
- response.context["adminform"].form.errors,
- {"algorithm": ["RSA-based certificate authorities require a signature hash algorithm."]},
- )
+ assert not response.context["adminform"].form.is_valid(), response
+ assert response.context["adminform"].form.errors == {
+ "algorithm": ["RSA-based certificate authorities require a signature hash algorithm."]
+ }
@override_tmpcadir(CA_DEFAULT_SUBJECT=tuple())
def test_certificate_policies_with_invalid_oid(self) -> None:
@@ -725,13 +709,12 @@ def test_certificate_policies_with_invalid_oid(self) -> None:
"certificate_policies_0": "abc",
},
)
- self.assertEqual(response.status_code, HTTPStatus.OK)
- self.assertFalse(response.context["adminform"].form.is_valid())
- self.assertEqual(
- response.context["adminform"].form.errors,
- {"certificate_policies": ["abc: The given OID is invalid."]},
- )
- self.assertEqual(cert_count, Certificate.objects.all().count())
+ assert response.status_code == HTTPStatus.OK
+ assert not response.context["adminform"].form.is_valid()
+ assert response.context["adminform"].form.errors == {
+ "certificate_policies": ["abc: The given OID is invalid."]
+ }
+ assert cert_count == Certificate.objects.all().count()
def test_add_no_cas(self) -> None:
"""Test adding when all CAs are disabled."""
@@ -843,18 +826,18 @@ def assertProfile( # pylint: disable=invalid-name
ku_expected = self.get_expected(profile, ExtensionOID.KEY_USAGE, [])
ku_selected = [o.get_attribute("value") for o in ku_select.all_selected_options]
- self.assertCountEqual(ku_expected["value"], ku_selected)
- self.assertEqual(ku_expected["critical"], ku_critical.is_selected())
+ assert_count_equal(ku_expected["value"], ku_selected)
+ assert ku_expected["critical"] == ku_critical.is_selected()
eku_expected = self.get_expected(profile, ExtensionOID.EXTENDED_KEY_USAGE, [])
eku_selected = [o.get_attribute("value") for o in eku_select.all_selected_options]
- self.assertCountEqual(eku_expected["value"], eku_selected)
- self.assertEqual(eku_expected["critical"], eku_critical.is_selected())
+ assert_count_equal(eku_expected["value"], eku_selected)
+ assert eku_expected["critical"] == eku_critical.is_selected()
tf_selected = [o.get_attribute("value") for o in tf_select.all_selected_options]
tf_expected = self.get_expected(profile, ExtensionOID.TLS_FEATURE, [])
- self.assertCountEqual(tf_expected.get("value", []), tf_selected)
- self.assertEqual(tf_expected.get("critical", False), tf_critical.is_selected())
+ assert_count_equal(tf_expected.get("value", []), tf_selected)
+ assert tf_expected.get("critical", False) == tf_critical.is_selected()
def clear_form(
self,
@@ -901,10 +884,9 @@ def test_select_profile(self) -> None:
tf_critical = self.find("input#id_tls_feature_1")
# test that the default profile is preselected
- self.assertEqual(
- [model_settings.CA_DEFAULT_PROFILE],
- [o.get_attribute("value") for o in select.all_selected_options],
- )
+ assert [model_settings.CA_DEFAULT_PROFILE] == [
+ o.get_attribute("value") for o in select.all_selected_options
+ ]
# assert that the values from the default profile are preloaded
self.assertProfile(
@@ -993,22 +975,22 @@ def test_subject_field(self) -> None:
]
# Test the initial state
- self.assertEqual(self.value, expected_initial_subject)
- self.assertEqual(self.displayed_value, expected_initial_subject)
+ assert self.value == expected_initial_subject
+ assert self.displayed_value == expected_initial_subject
# Add a row and confirm that it's initially empty and the field is thus not yet modified
self.key_value_field.find_element(By.CLASS_NAME, "add-row-btn").click()
self.assertNotModified()
new_select = Select(self.key_value_list.find_elements(By.CSS_SELECTOR, "select")[-1])
new_input = self.key_value_list.find_elements(By.CSS_SELECTOR, "input")[-1]
- self.assertEqual(new_select.all_selected_options, [])
- self.assertEqual(new_input.get_attribute("value"), "")
+ assert new_select.all_selected_options == []
+ assert new_input.get_attribute("value") == ""
# Enter a value. This marks the field as modified, but the hidden input is *not* updated, as there is
# no key/OID selected yet
new_input.send_keys(self.hostname)
self.assertModified()
- self.assertEqual(self.value, expected_initial_subject)
+ assert self.value == expected_initial_subject
# Now select common name, and the subject is also updated
new_select.select_by_value(NameOID.COMMON_NAME.dotted_string)
@@ -1016,15 +998,15 @@ def test_subject_field(self) -> None:
*expected_initial_subject,
{"oid": NameOID.COMMON_NAME.dotted_string, "value": self.hostname},
]
- self.assertEqual(self.value, new_subject)
- self.assertEqual(self.displayed_value, new_subject) # just to be sure
+ assert self.value == new_subject
+ assert self.displayed_value == new_subject # just to be sure
# Remove the second row, check the update
self.key_value_list.find_elements(By.CSS_SELECTOR, ".remove-row-btn")[1].click()
new_subject.pop(1)
- self.assertEqual(len(new_subject), 2)
- self.assertEqual(self.value, new_subject)
- self.assertEqual(self.displayed_value, new_subject)
+ assert len(new_subject) == 2
+ assert self.value == new_subject
+ assert self.displayed_value == new_subject
@override_tmpcadir()
def test_csr_integration(self) -> None:
@@ -1123,16 +1105,16 @@ def test_paste_csr_no_subject(self) -> None:
)
# Check that the right parts of the CSR chapter is displayed
- self.assertIs(no_csr.is_displayed(), False)
- self.assertIs(has_content.is_displayed(), False)
- self.assertIs(no_content.is_displayed(), True)
+ assert no_csr.is_displayed() is False
+ assert has_content.is_displayed() is False
+ assert no_content.is_displayed() is True
self.assertNotModified()
# Click the clear button and validate that the subject is cleared
csr_chapter.find_element(By.CSS_SELECTOR, ".clear-button").click()
self.assertModified()
- self.assertEqual(self.value, [])
- self.assertEqual(self.displayed_value, [])
+ assert self.value == []
+ assert self.displayed_value == []
@override_tmpcadir()
def test_paste_csr_missing_delimiters(self) -> None:
@@ -1152,9 +1134,9 @@ def test_paste_csr_missing_delimiters(self) -> None:
csr_field.send_keys(csr.public_bytes(Encoding.PEM).decode("ascii")[1:])
# Check that the right parts of the CSR chapter is displayed
- self.assertIs(no_csr.is_displayed(), True) # this is displayed as we haven't pasted a CSR
- self.assertIs(has_content.is_displayed(), False)
- self.assertIs(no_content.is_displayed(), False)
+ assert no_csr.is_displayed() is True # this is displayed as we haven't pasted a CSR
+ assert has_content.is_displayed() is False
+ assert no_content.is_displayed() is False
self.assertNotModified()
@override_tmpcadir()
@@ -1174,9 +1156,9 @@ def test_paste_invalid_csr(self) -> None:
csr.send_keys("-----BEGIN CERTIFICATE REQUEST-----\nXXX\n-----END CERTIFICATE REQUEST-----")
# Check that the right parts of the CSR chapter is displayed
- self.assertIs(no_csr.is_displayed(), True) # this is displayed as we haven't pasted a CSR
- self.assertIs(has_content.is_displayed(), False)
- self.assertIs(no_content.is_displayed(), False)
+ assert no_csr.is_displayed() is True # this is displayed as we haven't pasted a CSR
+ assert has_content.is_displayed() is False
+ assert no_content.is_displayed() is False
self.assertNotModified()
@override_tmpcadir(
@@ -1224,10 +1206,10 @@ def test_profile_integration(self) -> None:
# Test the initial state (webserver subject, since it's the default profile
self.assertNotModified()
self.assertChapterHasValue(chapter, webserver_subject)
- self.assertEqual(self.value, webserver_subject)
- self.assertEqual(self.displayed_value, webserver_subject)
- self.assertIs(has_content.is_displayed(), True)
- self.assertIs(no_content.is_displayed(), False)
+ assert self.value == webserver_subject
+ assert self.displayed_value == webserver_subject
+ assert has_content.is_displayed() is True
+ assert no_content.is_displayed() is False
profile_select = Select(self.selenium.find_element(By.ID, "id_profile"))
@@ -1235,10 +1217,10 @@ def test_profile_integration(self) -> None:
profile_select.select_by_value("client")
self.assertNotModified()
self.assertChapterHasValue(chapter, client_subject)
- self.assertEqual(self.value, client_subject)
- self.assertEqual(self.displayed_value, client_subject)
- self.assertIs(has_content.is_displayed(), True)
- self.assertIs(no_content.is_displayed(), False)
+ assert self.value == client_subject
+ assert self.displayed_value == client_subject
+ assert has_content.is_displayed() is True
+ assert no_content.is_displayed() is False
# Change one field and check modification
st_input = self.key_value_list.find_elements(By.CSS_SELECTOR, "input")[1]
@@ -1247,24 +1229,24 @@ def test_profile_integration(self) -> None:
new_subject = deepcopy(client_subject)
new_subject[1]["value"] = "Styria"
self.assertModified()
- self.assertEqual(self.value, new_subject)
- self.assertEqual(self.displayed_value, new_subject)
+ assert self.value == new_subject
+ assert self.displayed_value == new_subject
# Switch back to the old profile. Since you made changes, it's not automatically updated
profile_select.select_by_value("webserver")
self.assertModified()
self.assertChapterHasValue(chapter, webserver_subject)
- self.assertEqual(self.value, new_subject)
- self.assertEqual(self.displayed_value, new_subject)
+ assert self.value == new_subject
+ assert self.displayed_value == new_subject
# Copy the profile subject and check the state
chapter.find_element(By.CLASS_NAME, "copy-button").click()
self.assertNotModified()
self.assertChapterHasValue(chapter, webserver_subject)
- self.assertEqual(self.value, webserver_subject)
- self.assertEqual(self.displayed_value, webserver_subject)
- self.assertIs(has_content.is_displayed(), True)
- self.assertIs(no_content.is_displayed(), False)
+ assert self.value == webserver_subject
+ assert self.displayed_value == webserver_subject
+ assert has_content.is_displayed() is True
+ assert no_content.is_displayed() is False
# Modify subject again (so that we can check the modified flag of the clear button)
st_input = self.key_value_list.find_elements(By.CSS_SELECTOR, "input")[1]
@@ -1272,24 +1254,24 @@ def test_profile_integration(self) -> None:
st_input.send_keys("Styria")
new_subject[2]["value"] = "webserver"
self.assertModified()
- self.assertEqual(self.value, new_subject)
- self.assertEqual(self.displayed_value, new_subject)
+ assert self.value == new_subject
+ assert self.displayed_value == new_subject
# Switch to the profile with no subject and check the state
profile_select.select_by_value("no-subject")
self.assertModified()
self.assertChapterHasValue(chapter, [])
- self.assertEqual(self.value, new_subject)
- self.assertEqual(self.displayed_value, new_subject)
- self.assertIs(has_content.is_displayed(), False)
- self.assertIs(no_content.is_displayed(), True)
+ assert self.value == new_subject
+ assert self.displayed_value == new_subject
+ assert has_content.is_displayed() is False
+ assert no_content.is_displayed() is True
# Click the clear button
chapter.find_element(By.CLASS_NAME, "clear-button").click()
self.assertNotModified()
self.assertChapterHasValue(chapter, [])
- self.assertEqual(self.value, [])
- self.assertEqual(self.displayed_value, [])
+ assert self.value == []
+ assert self.displayed_value == []
@freeze_time(TIMESTAMPS["everything_valid"])
@@ -1318,7 +1300,7 @@ def test_empty_form_and_empty_cert(self) -> None:
form["authority_information_access_0"] = "[]"
form["authority_information_access_1"] = "[]"
response = form.submit()
- self.assertEqual(response.status_code, 200)
+ assert response.status_code == 200
# Fill in the bare minimum fields
form = response.forms["certificate_form"]
@@ -1331,18 +1313,15 @@ def test_empty_form_and_empty_cert(self) -> None:
# Submit the form
response = form.submit().follow()
- self.assertEqual(response.status_code, 200)
+ assert response.status_code == 200
cert = Certificate.objects.get(cn="test-empty-form.example.com")
# Cert has minimal extensions, since we cleared the form earlier
- self.assertEqual(
- cert.sorted_extensions,
- [
- cert.ca.get_authority_key_identifier_extension(),
- basic_constraints(),
- subject_key_identifier(cert),
- ],
- )
+ assert cert.sorted_extensions == [
+ cert.ca.get_authority_key_identifier_extension(),
+ basic_constraints(),
+ subject_key_identifier(cert),
+ ]
@override_tmpcadir(
CA_PROFILES={
@@ -1361,20 +1340,17 @@ def test_none_extension_and_subject_alternative_name_extension(self) -> None:
form["csr"] = CSR
form["subject"] = json.dumps([{"oid": NameOID.COMMON_NAME.dotted_string, "value": self.hostname}])
response = form.submit().follow()
- self.assertEqual(response.status_code, 200)
+ assert response.status_code == 200
cert: Certificate = Certificate.objects.get(cn=self.hostname)
- self.assertEqual(
- cert.sorted_extensions,
- [
- cert.ca.sign_authority_information_access,
- cert.ca.get_authority_key_identifier_extension(),
- basic_constraints(),
- cert.ca.sign_crl_distribution_points,
- subject_alternative_name(dns("example.com")),
- subject_key_identifier(cert),
- ],
- )
+ assert cert.sorted_extensions == [
+ cert.ca.sign_authority_information_access,
+ cert.ca.get_authority_key_identifier_extension(),
+ basic_constraints(),
+ cert.ca.sign_crl_distribution_points,
+ subject_alternative_name(dns("example.com")),
+ subject_key_identifier(cert),
+ ]
@override_tmpcadir(CA_PROFILES={"nothing": {}}, CA_DEFAULT_PROFILE="nothing")
def test_only_ca_prefill(self) -> None:
@@ -1516,8 +1492,8 @@ def test_full_profile_prefill(self) -> None:
fields would not show up in the signed certificate.
"""
# Make sure that the CA has sign_* field values set.
- self.assertIsNotNone(self.ca.sign_authority_information_access)
- self.assertIsNotNone(self.ca.sign_crl_distribution_points)
+ assert self.ca.sign_authority_information_access is not None
+ assert self.ca.sign_crl_distribution_points is not None
self.ca.sign_certificate_policies = certificate_policies(
x509.PolicyInformation(
policy_identifier=CertificatePoliciesOID.CPS_QUALIFIER, policy_qualifiers=None
@@ -1536,56 +1512,50 @@ def test_full_profile_prefill(self) -> None:
form["csr"] = CSR
form["subject"] = json.dumps([{"oid": NameOID.COMMON_NAME.dotted_string, "value": self.hostname}])
response = form.submit().follow()
- self.assertEqual(response.status_code, 200)
+ assert response.status_code == 200
# Check that we get all the extensions from the CA
cert = Certificate.objects.get(cn=self.hostname)
- self.assertEqual(cert.profile, "everything")
- self.assertEqual(
- cert.sorted_extensions,
- [
- authority_information_access(
- ca_issuers=[uri("http://profile.issuers.example.com")],
- ocsp=[
- uri("http://profile.ocsp.example.com"),
- uri("http://profile.ocsp-backup.example.com"),
- ],
- critical=False,
- ),
- cert.ca.get_authority_key_identifier_extension(),
- basic_constraints(),
- crl_distribution_points(
- distribution_point(
- full_name=[uri("http://crl.profile.example.com")],
- crl_issuer=[uri("http://crl-issuer.profile.example.com")],
- ),
- critical=True,
- ),
- certificate_policies(
- x509.PolicyInformation(
- policy_identifier=CertificatePoliciesOID.CPS_USER_NOTICE, policy_qualifiers=["text1"]
- ),
- critical=True,
- ),
- extended_key_usage(
- ExtendedKeyUsageOID.CLIENT_AUTH, ExtendedKeyUsageOID.SERVER_AUTH, critical=True
+ assert cert.profile == "everything"
+ assert cert.sorted_extensions == [
+ authority_information_access(
+ ca_issuers=[uri("http://profile.issuers.example.com")],
+ ocsp=[uri("http://profile.ocsp.example.com"), uri("http://profile.ocsp-backup.example.com")],
+ critical=False,
+ ),
+ cert.ca.get_authority_key_identifier_extension(),
+ basic_constraints(),
+ crl_distribution_points(
+ distribution_point(
+ full_name=[uri("http://crl.profile.example.com")],
+ crl_issuer=[uri("http://crl-issuer.profile.example.com")],
),
- freshest_crl(
- distribution_point(
- full_name=[uri("http://freshest-crl.profile.example.com")],
- crl_issuer=[uri("http://freshest-crl-issuer.profile.example.com")],
- ),
- critical=False,
+ critical=True,
+ ),
+ certificate_policies(
+ x509.PolicyInformation(
+ policy_identifier=CertificatePoliciesOID.CPS_USER_NOTICE, policy_qualifiers=["text1"]
),
- issuer_alternative_name(
- uri("http://ian1.example.com"), uri("http://ian2.example.com"), critical=True
+ critical=True,
+ ),
+ extended_key_usage(
+ ExtendedKeyUsageOID.CLIENT_AUTH, ExtendedKeyUsageOID.SERVER_AUTH, critical=True
+ ),
+ freshest_crl(
+ distribution_point(
+ full_name=[uri("http://freshest-crl.profile.example.com")],
+ crl_issuer=[uri("http://freshest-crl-issuer.profile.example.com")],
),
- key_usage(key_agreement=True, key_cert_sign=True),
- ocsp_no_check(critical=True),
- subject_key_identifier(cert),
- tls_feature(x509.TLSFeatureType.status_request, critical=True),
- ],
- )
+ critical=False,
+ ),
+ issuer_alternative_name(
+ uri("http://ian1.example.com"), uri("http://ian2.example.com"), critical=True
+ ),
+ key_usage(key_agreement=True, key_cert_sign=True),
+ ocsp_no_check(critical=True),
+ subject_key_identifier(cert),
+ tls_feature(x509.TLSFeatureType.status_request, critical=True),
+ ]
@override_tmpcadir(
CA_PROFILES={
@@ -1640,13 +1610,10 @@ def test_multiple_distribution_points(self) -> None:
with self.assertLogs("django_ca") as logcm:
response = self.app.get(self.add_url, user=self.user.username)
- self.assertEqual(
- logcm.output,
- [
- "WARNING:django_ca.widgets:Received multiple DistributionPoints, only the first can be "
- "changed in the web interface."
- ],
- )
+ assert logcm.output == [
+ "WARNING:django_ca.widgets:Received multiple DistributionPoints, only the first can be "
+ "changed in the web interface."
+ ]
form = response.forms["certificate_form"]
# default value for form field is on import time, so override settings does not change
# profile field
@@ -1655,41 +1622,38 @@ def test_multiple_distribution_points(self) -> None:
form["subject"] = json.dumps([{"oid": NameOID.COMMON_NAME.dotted_string, "value": cn}])
response = form.submit()
response = response.follow()
- self.assertEqual(response.status_code, 200)
+ assert response.status_code == 200
# Check that we get all the extensions from the CA
cert: Certificate = Certificate.objects.get(cn="test-only-ca.example.com")
- self.assertEqual(cert.profile, "everything")
- self.assertEqual(
- cert.sorted_extensions,
- [
- cert.ca.get_authority_key_identifier_extension(),
- basic_constraints(),
- x509.Extension(
- oid=ExtensionOID.CRL_DISTRIBUTION_POINTS,
- critical=True,
- value=x509.CRLDistributionPoints(
- [
- x509.DistributionPoint(
- full_name=[uri("http://crl.profile.example.com")],
- relative_name=None,
- reasons=None,
- crl_issuer=[uri("http://crl-issuer.profile.example.com")],
- ),
- x509.DistributionPoint(
- full_name=[uri("http://crl2.profile.example.com")],
- relative_name=None,
- reasons=None,
- crl_issuer=[uri("http://crl-issuer2.profile.example.com")],
- ),
- ]
- ),
- ),
- self.freshest_crl(
- [uri("http://freshest-crl.profile.example.com")],
- crl_issuer=[uri("http://freshest-crl-issuer.profile.example.com")],
- critical=False,
+ assert cert.profile == "everything"
+ assert cert.sorted_extensions == [
+ cert.ca.get_authority_key_identifier_extension(),
+ basic_constraints(),
+ x509.Extension(
+ oid=ExtensionOID.CRL_DISTRIBUTION_POINTS,
+ critical=True,
+ value=x509.CRLDistributionPoints(
+ [
+ x509.DistributionPoint(
+ full_name=[uri("http://crl.profile.example.com")],
+ relative_name=None,
+ reasons=None,
+ crl_issuer=[uri("http://crl-issuer.profile.example.com")],
+ ),
+ x509.DistributionPoint(
+ full_name=[uri("http://crl2.profile.example.com")],
+ relative_name=None,
+ reasons=None,
+ crl_issuer=[uri("http://crl-issuer2.profile.example.com")],
+ ),
+ ]
),
- subject_key_identifier(cert),
- ],
- )
+ ),
+ self.freshest_crl(
+ [uri("http://freshest-crl.profile.example.com")],
+ crl_issuer=[uri("http://freshest-crl-issuer.profile.example.com")],
+ critical=False,
+ ),
+ subject_key_identifier(cert),
+ ]
diff --git a/ca/django_ca/tests/admin/test_admin_ca.py b/ca/django_ca/tests/admin/test_admin_ca.py
index 438fc04f4..54871e28e 100644
--- a/ca/django_ca/tests/admin/test_admin_ca.py
+++ b/ca/django_ca/tests/admin/test_admin_ca.py
@@ -54,7 +54,7 @@ def test_complex_sign_certificate_policies(self) -> None:
# This test is only meaningful if the CA does **not** have the Certificate Policies extension in its
# own extensions. We (can) only test for the used template after viewing, and the template would be
# used for that extension.
- self.assertNotIn(ExtensionOID.CERTIFICATE_POLICIES, ca.extensions)
+ assert ExtensionOID.CERTIFICATE_POLICIES not in ca.extensions
ca.sign_certificate_policies = certificate_policies(
x509.PolicyInformation(
@@ -80,7 +80,7 @@ def test_complex_sign_certificate_policies(self) -> None:
response = self.get_change_view(ca)
assert_change_response(response)
templates = [t.name for t in response.templates]
- self.assertIn("django_ca/admin/extensions/2.5.29.32.html", templates)
+ assert "django_ca/admin/extensions/2.5.29.32.html" in templates
class CADownloadBundleTestCase(AdminTestCaseMixin[CertificateAuthority], TestCase):
@@ -111,13 +111,13 @@ def test_child(self) -> None:
def test_invalid_format(self) -> None:
"""Test downloading the bundle in an invalid format."""
response = self.client.get(f"{self.url}?format=INVALID")
- self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST)
- self.assertEqual(response.content, b"")
+ assert response.status_code == HTTPStatus.BAD_REQUEST
+ assert response.content == b""
# DER is not supported for bundles
response = self.client.get(f"{self.url}?format=DER")
- self.assertEqual(response.status_code, 400)
- self.assertEqual(response.content, b"DER/ASN.1 certificates cannot be downloaded as a bundle.")
+ assert response.status_code == 400
+ assert response.content == b"DER/ASN.1 certificates cannot be downloaded as a bundle."
def test_permission_denied(self) -> None:
"""Test downloading without permissions fails."""
@@ -125,7 +125,7 @@ def test_permission_denied(self) -> None:
self.user.save()
response = self.client.get(f"{self.url}?format=PEM")
- self.assertEqual(response.status_code, HTTPStatus.FORBIDDEN)
+ assert response.status_code == HTTPStatus.FORBIDDEN
def test_unauthorized(self) -> None:
"""Test viewing as unauthorized viewer."""
diff --git a/ca/django_ca/tests/admin/test_extra_views.py b/ca/django_ca/tests/admin/test_extra_views.py
index 78d736aca..c339b009e 100644
--- a/ca/django_ca/tests/admin/test_extra_views.py
+++ b/ca/django_ca/tests/admin/test_extra_views.py
@@ -38,7 +38,7 @@
@pytest.mark.parametrize(
- "data,expected",
+ ("data", "expected"),
(
([], ""),
([{"oid": NameOID.COMMON_NAME.dotted_string, "value": "example.com"}], "CN=example.com"),
@@ -107,12 +107,11 @@ def test_basic(self) -> None:
response = self.client.post(
self.url, data=json.dumps({"csr": csr}), content_type="application/json"
)
- self.assertEqual(response.status_code, 200, response.json())
+ assert response.status_code == 200, response.json()
csr_subject = cert_data["csr"]["parsed"].subject
- self.assertEqual(
- response.json(),
- {"subject": [{"oid": s.oid.dotted_string, "value": s.value} for s in csr_subject]},
- )
+ assert response.json() == {
+ "subject": [{"oid": s.oid.dotted_string, "value": s.value} for s in csr_subject]
+ }
def test_fields(self) -> None:
"""Test fetching a CSR with all subject fields."""
@@ -129,7 +128,7 @@ def test_fields(self) -> None:
response = self.client.post(
self.url, data=json.dumps({"csr": csr_pem}), content_type="application/json"
)
- self.assertEqual(response.status_code, 200, response.json())
+ assert response.status_code == 200, response.json()
expected = [
{"oid": NameOID.USER_ID.dotted_string, "value": "test-uid"},
{"oid": NameOID.DOMAIN_COMPONENT.dotted_string, "value": "test-domainComponent"},
@@ -169,12 +168,12 @@ def test_fields(self) -> None:
{"oid": NameOID.ORGANIZATION_IDENTIFIER.dotted_string, "value": "test-organizationIdentifier"},
]
- self.assertEqual(json.loads(response.content.decode("utf-8")), {"subject": expected})
+ assert json.loads(response.content.decode("utf-8")) == {"subject": expected}
def test_bad_request(self) -> None:
"""Test posting bogus data."""
response = self.client.post(self.url, data={"csr": "foobar"})
- self.assertEqual(response.status_code, 400)
+ assert response.status_code == 400
def test_anonymous(self) -> None:
"""Try downloading as anonymous user."""
@@ -192,7 +191,7 @@ def test_no_perms(self) -> None:
self.user.is_superuser = False
self.user.save()
response = self.client.post(self.url, data={"csr": self.csr_pem})
- self.assertEqual(response.status_code, 403)
+ assert response.status_code == 403
def test_no_staff(self) -> None:
"""Try downloading as user that has permissions but is not staff."""
@@ -217,22 +216,22 @@ def test_der(self) -> None:
"""Download a certificate in DER format."""
filename = f"{self.cert.serial}.der"
response = self.client.get(self.get_url(self.cert), {"format": "DER"})
- self.assertEqual(response.status_code, HTTPStatus.OK)
- self.assertEqual(response["Content-Type"], "application/pkix-cert")
- self.assertEqual(response["Content-Disposition"], f"attachment; filename={filename}")
- self.assertEqual(response.content, self.cert.pub.der)
+ assert response.status_code == HTTPStatus.OK
+ assert response["Content-Type"] == "application/pkix-cert"
+ assert response["Content-Disposition"] == f"attachment; filename={filename}"
+ assert response.content == self.cert.pub.der
def test_not_found(self) -> None:
"""Try downloading a certificate that does not exist."""
url = reverse("admin:django_ca_certificate_download", kwargs={"pk": "123"})
response = self.client.get(f"{url}?format=DER")
- self.assertEqual(response.status_code, HTTPStatus.NOT_FOUND)
+ assert response.status_code == HTTPStatus.NOT_FOUND
def test_bad_format(self) -> None:
"""Try downloading an unknown format."""
response = self.client.get(self.get_url(self.cert), {"format": "bad"})
- self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST)
- self.assertEqual(response.content, b"")
+ assert response.status_code == HTTPStatus.BAD_REQUEST
+ assert response.content == b""
def test_anonymous(self) -> None:
"""Try an anonymous download."""
@@ -249,7 +248,7 @@ def test_no_perms(self) -> None:
self.user.is_superuser = False
self.user.save()
response = self.client.get(self.get_url(self.cert))
- self.assertEqual(response.status_code, HTTPStatus.FORBIDDEN)
+ assert response.status_code == HTTPStatus.FORBIDDEN
def test_no_staff(self) -> None:
"""Try downloading with right permissions but not as staff user."""
@@ -283,10 +282,10 @@ def test_invalid_format(self) -> None:
"""Try downloading an invalid format."""
url = self.get_url(self.cert)
response = self.client.get(f"{url}?format=INVALID")
- self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST)
- self.assertEqual(response.content, b"")
+ assert response.status_code == HTTPStatus.BAD_REQUEST
+ assert response.content == b""
# DER is not supported for bundles
response = self.client.get(f"{url}?format=DER")
- self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST)
- self.assertEqual(response.content, b"DER/ASN.1 certificates cannot be downloaded as a bundle.")
+ assert response.status_code == HTTPStatus.BAD_REQUEST
+ assert response.content == b"DER/ASN.1 certificates cannot be downloaded as a bundle."
diff --git a/ca/django_ca/tests/base/assertions.py b/ca/django_ca/tests/base/assertions.py
index 37463dda7..4d61363ed 100644
--- a/ca/django_ca/tests/base/assertions.py
+++ b/ca/django_ca/tests/base/assertions.py
@@ -13,13 +13,14 @@
""":py:mod:`django_ca.tests.base.assertions` collects assertions used throughout the entire test suite."""
+import collections
import io
import re
import typing
from collections.abc import Iterable, Iterator
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone as tz
-from typing import AnyStr, Optional, Union
+from typing import Any, AnyStr, Optional, Union
from unittest.mock import Mock
from cryptography import x509
@@ -157,6 +158,12 @@ def assert_command_error(msg: str, returncode: int = 1) -> Iterator[None]:
assert exc_info.value.returncode == returncode
+def assert_count_equal(first: Iterable[Any], second: Iterable[Any]) -> None:
+ """Roughly equivalent version of unittests assertCountEqual()."""
+ first, second = list(first), list(second)
+ assert collections.Counter(first) == collections.Counter(second)
+
+
@contextmanager
def assert_create_ca_signals(pre: bool = True, post: bool = True) -> Iterator[tuple[Mock, Mock]]:
"""Context manager asserting that the `pre_create_ca`/`post_create_ca` signals are (not) called."""
@@ -305,6 +312,34 @@ def assert_e2e_error(
raise NotImplementedError
+def assert_extension_equal(
+ first: Optional[x509.Extension[x509.ExtensionType]], second: Optional[x509.Extension[x509.ExtensionType]]
+) -> None:
+ """Compare two extensions for equality (or if both are None).
+
+ This assertion overrides comparison for iterable extension and should be used only when order of these
+ extension values cannot be guaranteed. For example, two ExtendedKeyUsage extension will pass as equal
+ regardless of order of the extended key usages in the extensions.
+ """
+ # If both are None that's still okay.
+ if first is None and second is None:
+ return
+ if first is None or second is None: # pragma: no cover
+ raise AssertionError("One of the values is None.")
+
+ if second.oid in (
+ ExtensionOID.EXTENDED_KEY_USAGE,
+ ExtensionOID.TLS_FEATURE,
+ ExtensionOID.SUBJECT_ALTERNATIVE_NAME,
+ ExtensionOID.ISSUER_ALTERNATIVE_NAME,
+ ):
+ assert first.oid == second.oid
+ assert first.critical == second.critical
+ assert_count_equal(first.value, second.value) # type: ignore[arg-type]
+ else:
+ assert first == second
+
+
def assert_extensions(
cert: Union[X509CertMixin, x509.Certificate],
extensions: Iterable[x509.Extension[x509.ExtensionType]],
diff --git a/ca/django_ca/tests/base/conftest_helpers.py b/ca/django_ca/tests/base/conftest_helpers.py
index cde721765..340cf6889 100644
--- a/ca/django_ca/tests/base/conftest_helpers.py
+++ b/ca/django_ca/tests/base/conftest_helpers.py
@@ -146,24 +146,24 @@ def setup_pragmas(cov: coverage.Coverage) -> None:
exclude_versions(cov, "cryptography", cg_version, ver, version_str)
-def generate_pub_fixture(name: str) -> typing.Callable[[], Iterator[x509.Certificate]]:
+def generate_pub_fixture(name: str) -> typing.Callable[[], x509.Certificate]:
"""Generate fixture for a loaded public key (root_pub, root_cert_pub, ...)."""
@pytest.fixture(scope="session")
- def fixture() -> Iterator[x509.Certificate]:
+ def fixture() -> x509.Certificate:
return load_pub(name)
return fixture
-def generate_ca_fixture(name: str) -> typing.Callable[["SubRequest", Any], Iterator[CertificateAuthority]]:
+def generate_ca_fixture(name: str) -> typing.Callable[["SubRequest", Any], CertificateAuthority]:
"""Function to generate CA fixtures (root, child, ...)."""
@pytest.fixture
def fixture(
request: "SubRequest",
db: Any, # pylint: disable=unused-argument # usefixtures does not work for fixtures
- ) -> Iterator[CertificateAuthority]:
+ ) -> CertificateAuthority:
data = CERT_DATA[name]
ca_fixture_name = f"{name}_pub"
if data["cat"] == "sphinx-contrib":
@@ -187,27 +187,25 @@ def fixture(
return fixture
-def generate_usable_ca_fixture(
- name: str,
-) -> typing.Callable[["SubRequest", Path], Iterator[CertificateAuthority]]:
+def generate_usable_ca_fixture(name: str) -> typing.Callable[["SubRequest", Path], CertificateAuthority]:
"""Function to generate CA fixtures (root, child, ...)."""
@pytest.fixture
- def fixture(request: "SubRequest", tmpcadir: Path) -> Iterator[CertificateAuthority]:
+ def fixture(request: "SubRequest", tmpcadir: Path) -> CertificateAuthority:
ca = request.getfixturevalue(name) # load the CA into the database
data = CERT_DATA[name]
shutil.copy(os.path.join(FIXTURES_DIR, data["key_filename"]), tmpcadir)
- return ca
+ return ca # type: ignore[no-any-return]
return fixture
-def generate_cert_fixture(name: str) -> typing.Callable[["SubRequest"], Iterator[Certificate]]:
+def generate_cert_fixture(name: str) -> typing.Callable[["SubRequest"], Certificate]:
"""Function to generate cert fixtures (root_cert, all_extensions, no_extensions, ...)."""
@pytest.fixture
- def fixture(request: "SubRequest") -> Iterator[Certificate]:
+ def fixture(request: "SubRequest") -> Certificate:
sanitized_name = name.replace("-", "_")
data = CERT_DATA[name]
diff --git a/ca/django_ca/tests/base/fixtures.py b/ca/django_ca/tests/base/fixtures.py
index fe12d68bb..49e3cb21c 100644
--- a/ca/django_ca/tests/base/fixtures.py
+++ b/ca/django_ca/tests/base/fixtures.py
@@ -58,15 +58,15 @@
@pytest.fixture(params=all_cert_names)
-def any_cert(request: "SubRequest") -> Iterator[Certificate]:
+def any_cert(request: "SubRequest") -> Certificate:
"""Parametrized fixture for absolutely *any* certificate name."""
- return request.param
+ return request.param # type: ignore[no-any-return]
@pytest.fixture
-def ca_name(request: "SubRequest") -> Iterator[str]:
+def ca_name(request: "SubRequest") -> str:
"""Fixture for a name suitable for a CA."""
- return request.node.name
+ return request.node.name # type: ignore[no-any-return]
@pytest.fixture(
@@ -163,7 +163,7 @@ def ca_name(request: "SubRequest") -> Iterator[str]:
],
)
)
-def certificate_policies_value(request: "SubRequest") -> Iterator[x509.CertificatePolicies]:
+def certificate_policies_value(request: "SubRequest") -> x509.CertificatePolicies:
"""Parametrized fixture with different :py:class:`~cg:cryptography.x509.CertificatePolicies` objects."""
return x509.CertificatePolicies(policies=request.param)
@@ -171,7 +171,7 @@ def certificate_policies_value(request: "SubRequest") -> Iterator[x509.Certifica
@pytest.fixture(params=(True, False))
def certificate_policies(
request: "SubRequest", certificate_policies_value: x509.CertificatePolicies
-) -> Iterator[x509.Extension[x509.CertificatePolicies]]:
+) -> x509.Extension[x509.CertificatePolicies]:
"""Parametrized fixture yielding different ``x509.Extension[x509.CertificatePolicies]`` objects."""
return x509.Extension(
critical=request.param, oid=ExtensionOID.CERTIFICATE_POLICIES, value=certificate_policies_value
@@ -186,13 +186,13 @@ def clear_cache() -> Iterator[None]:
@pytest.fixture(params=("ed448", "ed25519"))
-def ed_ca(request: "SubRequest") -> Iterator[CertificateAuthority]:
+def ed_ca(request: "SubRequest") -> CertificateAuthority:
"""Parametrized fixture for CAs with an Edwards-curve algorithm (ed448, ed25519)."""
- return request.getfixturevalue(f"{request.param}")
+ return request.getfixturevalue(f"{request.param}") # type: ignore[no-any-return]
@pytest.fixture
-def hostname(ca_name: str) -> Iterator[str]:
+def hostname(ca_name: str) -> str:
"""Fixture for a hostname.
The value is unique for each test, and it includes the CA name, which includes the test name.
@@ -201,30 +201,30 @@ def hostname(ca_name: str) -> Iterator[str]:
@pytest.fixture(params=interesting_certificate_names)
-def interesting_cert(request: "SubRequest") -> Iterator[Certificate]:
+def interesting_cert(request: "SubRequest") -> Certificate:
"""Parametrized fixture for "interesting" certificates.
A function using this fixture will be called once for each certificate with unusual extensions.
"""
- return request.getfixturevalue(request.param.replace("-", "_"))
+ return request.getfixturevalue(request.param.replace("-", "_")) # type: ignore[no-any-return]
@pytest.fixture
-def key_backend(request: "SubRequest") -> Iterator[StoragesBackend]:
+def key_backend(request: "SubRequest") -> StoragesBackend:
"""Return a :py:class:`~django_ca.key_backends.storages.StoragesBackend` for creating a new CA."""
request.getfixturevalue("tmpcadir")
- return key_backends[model_settings.CA_DEFAULT_KEY_BACKEND] # type: ignore[misc]
+ return key_backends[model_settings.CA_DEFAULT_KEY_BACKEND] # type: ignore[return-value]
@pytest.fixture(params=precertificate_signed_certificate_timestamps_cert_names)
-def precertificate_signed_certificate_timestamps_pub(request: "SubRequest") -> Iterator[x509.Certificate]:
+def precertificate_signed_certificate_timestamps_pub(request: "SubRequest") -> x509.Certificate:
"""Parametrized fixture for certificates that have a PrecertSignedCertificateTimestamps extension."""
name = request.param.replace("-", "_")
- return request.getfixturevalue(f"contrib_{name}_pub")
+ return request.getfixturevalue(f"contrib_{name}_pub") # type: ignore[no-any-return]
@pytest.fixture
-def rfc4514_subject(subject: x509.Name) -> Iterator[str]:
+def rfc4514_subject(subject: x509.Name) -> str:
"""Fixture for an RFC 4514 formatted name to use for a subject.
The common name is based on :py:func:`~django_ca.tests.base.fixtures.hostname` and identical to
@@ -234,7 +234,7 @@ def rfc4514_subject(subject: x509.Name) -> Iterator[str]:
@pytest.fixture
-def root_crl(root: CertificateAuthority) -> Iterator[CertificateRevocationList]:
+def root_crl(root: CertificateAuthority) -> CertificateRevocationList:
"""Fixture for the global CRL object for the Root CA."""
with open(constants.FIXTURES_DIR / "root.crl", "rb") as stream:
crl_data = stream.read()
@@ -248,7 +248,7 @@ def root_crl(root: CertificateAuthority) -> Iterator[CertificateRevocationList]:
@pytest.fixture
-def root_ca_crl(root: CertificateAuthority) -> Iterator[CertificateRevocationList]:
+def root_ca_crl(root: CertificateAuthority) -> CertificateRevocationList:
"""Fixture for the user CRL object for the Root CA."""
with open(constants.FIXTURES_DIR / "root.ca.crl", "rb") as stream:
crl_data = stream.read()
@@ -267,7 +267,7 @@ def root_ca_crl(root: CertificateAuthority) -> Iterator[CertificateRevocationLis
@pytest.fixture
-def root_user_crl(root: CertificateAuthority) -> Iterator[CertificateRevocationList]:
+def root_user_crl(root: CertificateAuthority) -> CertificateRevocationList:
"""Fixture for the user CRL object for the Root CA."""
with open(constants.FIXTURES_DIR / "root.user.crl", "rb") as stream:
crl_data = stream.read()
@@ -286,7 +286,7 @@ def root_user_crl(root: CertificateAuthority) -> Iterator[CertificateRevocationL
@pytest.fixture
-def root_attribute_crl(root: CertificateAuthority) -> Iterator[CertificateRevocationList]:
+def root_attribute_crl(root: CertificateAuthority) -> CertificateRevocationList:
"""Fixture for the attribute CRL object for the Root CA."""
with open(constants.FIXTURES_DIR / "root.attribute.crl", "rb") as stream:
crl_data = stream.read()
@@ -305,30 +305,28 @@ def root_attribute_crl(root: CertificateAuthority) -> Iterator[CertificateRevoca
@pytest.fixture
-def secondary_backend(request: "SubRequest") -> Iterator[StoragesBackend]:
+def secondary_backend(request: "SubRequest") -> StoragesBackend:
"""Return a :py:class:`~django_ca.key_backends.storages.StoragesBackend` for the secondary key backend."""
request.getfixturevalue("tmpcadir")
- return key_backends["secondary"] # type: ignore[misc]
+ return key_backends["secondary"] # type: ignore[return-value]
@pytest.fixture(params=signed_certificate_timestamp_cert_names)
-def signed_certificate_timestamp_pub(request: "SubRequest") -> Iterator[x509.Certificate]:
+def signed_certificate_timestamp_pub(request: "SubRequest") -> x509.Certificate:
"""Parametrized fixture for certificates that have any SCT extension."""
name = request.param.replace("-", "_")
- return request.getfixturevalue(f"contrib_{name}_pub")
+ return request.getfixturevalue(f"contrib_{name}_pub") # type: ignore[no-any-return]
@pytest.fixture(params=signed_certificate_timestamps_cert_names)
-def signed_certificate_timestamps_pub(
- request: "SubRequest",
-) -> Iterator[x509.Certificate]: # pragma: no cover
+def signed_certificate_timestamps_pub(request: "SubRequest") -> x509.Certificate: # pragma: no cover
"""Parametrized fixture for certificates that have a SignedCertificateTimestamps extension.
.. NOTE:: There are no certificates with this extension right now, so this fixture is in fact never run.
"""
name = request.param.replace("-", "_")
- return request.getfixturevalue(f"{name}_pub")
+ return request.getfixturevalue(f"{name}_pub") # type: ignore[no-any-return]
@pytest.fixture
@@ -370,7 +368,7 @@ def softhsm_setup(tmp_path: Path) -> Iterator[Path]: # pragma: hsm
def softhsm_token( # pragma: hsm
request: "SubRequest",
settings: SettingsWrapper,
-) -> Iterator[str]:
+) -> str:
"""Get a unique token for the current test."""
request.getfixturevalue("softhsm_setup")
token = settings.PKCS11_TOKEN_LABEL
@@ -390,7 +388,7 @@ def softhsm_token( # pragma: hsm
if lib := SessionPool._lib_pool.get(settings.PKCS11_PATH): # pylint: disable=protected-access
lib.reinitialize()
- return token
+ return token # type: ignore[no-any-return]
@pytest.fixture
@@ -404,7 +402,7 @@ def hsm_backend(request: "SubRequest") -> Iterator[HSMBackend]: # pragma: hsm
@pytest.fixture(params=HSMBackend.supported_key_types)
def usable_hsm_ca( # pragma: hsm
request: "SubRequest", ca_name: str, subject: x509.Name, hsm_backend: HSMBackend
-) -> Iterator[CertificateAuthority]:
+) -> CertificateAuthority:
"""Parametrized fixture yielding a certificate authority for every key type."""
request.getfixturevalue("db")
key_type = request.param
@@ -428,7 +426,7 @@ def usable_hsm_ca( # pragma: hsm
@pytest.fixture
-def subject(hostname: str) -> Iterator[x509.Name]:
+def subject(hostname: str) -> x509.Name:
"""Fixture for a :py:class:`~cg:cryptography.x509.Name` to use for a subject.
The common name is based on :py:func:`~django_ca.tests.base.fixtures.hostname` and identical to
@@ -474,28 +472,28 @@ def tmpcadir(tmp_path: Path, settings: SettingsWrapper) -> Iterator[Path]:
@pytest.fixture(params=all_ca_names)
-def ca(request: "SubRequest") -> Iterator[CertificateAuthority]:
+def ca(request: "SubRequest") -> CertificateAuthority:
"""Parametrized fixture for all certificate authorities known to the test suite."""
fixture_name = request.param
if CERT_DATA[fixture_name]["cat"] in ("contrib", "sphinx-contrib"):
fixture_name = f"contrib_{fixture_name}"
- return request.getfixturevalue(fixture_name)
+ return request.getfixturevalue(fixture_name) # type: ignore[no-any-return]
@pytest.fixture(params=usable_ca_names)
-def usable_ca_name(request: "SubRequest") -> Iterator[CertificateAuthority]:
+def usable_ca_name(request: "SubRequest") -> CertificateAuthority:
"""Parametrized fixture for the name of every usable CA."""
- return request.param
+ return request.param # type: ignore[no-any-return]
@pytest.fixture(params=usable_ca_names)
-def usable_ca(request: "SubRequest") -> Iterator[CertificateAuthority]:
+def usable_ca(request: "SubRequest") -> CertificateAuthority:
"""Parametrized fixture for every usable CA (with usable private key)."""
- return request.getfixturevalue(f"usable_{request.param}")
+ return request.getfixturevalue(f"usable_{request.param}") # type: ignore[no-any-return]
@pytest.fixture
-def usable_cas(request: "SubRequest") -> Iterator[list[CertificateAuthority]]:
+def usable_cas(request: "SubRequest") -> list[CertificateAuthority]:
"""Fixture for all usable CAs as a list."""
cas = []
for name in usable_ca_names:
@@ -504,7 +502,7 @@ def usable_cas(request: "SubRequest") -> Iterator[list[CertificateAuthority]]:
@pytest.fixture(params=usable_cert_names)
-def usable_cert(request: "SubRequest") -> Iterator[Certificate]:
+def usable_cert(request: "SubRequest") -> Certificate:
"""Parametrized fixture for every ``{ca}-cert`` certificate.
The name of the certificate can be retrieved from the non-standard `test_name` property of the
@@ -514,4 +512,4 @@ def usable_cert(request: "SubRequest") -> Iterator[Certificate]:
cert = request.getfixturevalue(name.replace("-", "_"))
cert.test_name = name
request.getfixturevalue(f"usable_{cert.ca.name}")
- return cert
+ return cert # type: ignore[no-any-return]
diff --git a/ca/django_ca/tests/base/mixins.py b/ca/django_ca/tests/base/mixins.py
index 9e11ac154..cae1f42fb 100644
--- a/ca/django_ca/tests/base/mixins.py
+++ b/ca/django_ca/tests/base/mixins.py
@@ -68,13 +68,6 @@ class TestCaseMixin(TestCaseProtocol):
re_false_password = r"^Could not decrypt private key - bad password\?$"
def setUp(self) -> None:
- # Add custom equality functions
- self.addTypeEqualityFunc(x509.AuthorityInformationAccess, self.assertAuthorityInformationAccessEqual)
- self.addTypeEqualityFunc(x509.ExtendedKeyUsage, self.assertExtendedKeyUsageEqual)
- self.addTypeEqualityFunc(x509.Extension, self.assertCryptographyExtensionEqual)
- self.addTypeEqualityFunc(x509.KeyUsage, self.assertKeyUsageEqual)
- self.addTypeEqualityFunc(x509.TLSFeature, self.assertTLSFeatureEqual)
-
super().setUp()
cache.clear()
@@ -152,95 +145,18 @@ def absolute_uri(self, name: str, hostname: Optional[str] = None, **kwargs: Any)
name = f"django_ca{name}"
return f"http://{hostname}{reverse(name, kwargs=kwargs)}"
- def assertAuthorityInformationAccessEqual( # pylint: disable=invalid-name
- self,
- first: x509.AuthorityInformationAccess,
- second: x509.AuthorityInformationAccess,
- msg: Optional[str] = None,
- ) -> None:
- """Type equality function for x509.AuthorityInformationAccess."""
-
- def sorter(ad: x509.AccessDescription) -> tuple[str, str]:
- return ad.access_method.dotted_string, ad.access_location.value
-
- self.assertEqual(sorted(first, key=sorter), sorted(second, key=sorter), msg=msg)
-
- def assertCryptographyExtensionEqual( # pylint: disable=invalid-name
- self,
- first: x509.Extension[x509.ExtensionType],
- second: x509.Extension[x509.ExtensionType],
- msg: Optional[str] = None,
- ) -> None:
- """Type equality function for x509.Extension."""
- # NOTE: Cryptography in name comes from overriding class in AbstractExtensionTestMixin
- # remove once old wrapper classes are removed
- self.assertEqual(first.oid, second.oid, msg=msg)
- self.assertEqual(first.critical, second.critical, msg="critical is unequal.")
- self.assertEqual(first.value, second.value, msg=msg)
-
- def assertExtendedKeyUsageEqual( # pylint: disable=invalid-name
- self, first: x509.ExtendedKeyUsage, second: x509.ExtendedKeyUsage, msg: Optional[str] = None
- ) -> None:
- """Type equality function for x509.ExtendedKeyUsage."""
- self.assertEqual(set(first), set(second), msg=msg)
-
- def assertKeyUsageEqual( # pylint: disable=invalid-name
- self, first: x509.KeyUsage, second: x509.KeyUsage, msg: Optional[str] = None
- ) -> None:
- """Type equality function for x509.KeyUsage."""
- diffs = []
- for usage in [
- "content_commitment",
- "crl_sign",
- "data_encipherment",
- "decipher_only",
- "digital_signature",
- "encipher_only",
- "key_agreement",
- "key_cert_sign",
- "key_encipherment",
- ]:
- try:
- first_val = getattr(first, usage)
- except ValueError:
- first_val = False
- try:
- second_val = getattr(second, usage)
- except ValueError:
- second_val = False
-
- if first_val != second_val: # pragma: no cover # would only be run in case of error
- diffs.append(f" * {usage}: {first_val} -> {second_val}")
-
- if msg is None:
- msg = "KeyUsage extensions differ:"
- if diffs: # pragma: no cover # would only be run in case of error
- raise self.failureException(msg + "\n" + "\n".join(diffs))
-
- def assertTLSFeatureEqual( # pylint: disable=invalid-name
- self, first: x509.TLSFeature, second: x509.TLSFeature, msg: Optional[str] = None
- ) -> None:
- """Type equality function for x509.TLSFeature."""
- self.assertEqual(set(first), set(second), msg=msg)
-
- def assertIssuer( # pylint: disable=invalid-name
- self, issuer: CertificateAuthority, cert: X509CertMixin
- ) -> None:
- """Assert that the issuer for `cert` matches the subject of `issuer`."""
- self.assertEqual(cert.issuer, issuer.subject)
-
def assertMessages( # pylint: disable=invalid-name
self, response: "HttpResponse", expected: list[str]
) -> None:
"""Assert given Django messages for `response`."""
messages = [str(m) for m in list(get_messages(response.wsgi_request))]
- self.assertEqual(messages, expected)
+ assert messages == expected
def assertNotRevoked(self, cert: X509CertMixin) -> None: # pylint: disable=invalid-name
"""Assert that the certificate is not revoked."""
cert.refresh_from_db()
- self.assertFalse(cert.revoked)
- self.assertEqual(cert.revoked_reason, "")
+ assert not cert.revoked
+ assert cert.revoked_reason == ""
def assertPostRevoke(self, post: mock.Mock, cert: Certificate) -> None: # pylint: disable=invalid-name
"""Assert that the post_revoke_cert signal was called."""
@@ -381,13 +297,13 @@ def mute_celery(self, *calls: Any) -> Iterator[mock.MagicMock]:
# Make sure that all invocations are JSON serializable
for invocation in mocked.call_args_list:
# invocation apply_async() has task args as arg[0] and arg[1]
- self.assertIsInstance(json.dumps(invocation.args[0]), str)
- self.assertIsInstance(json.dumps(invocation.args[1]), str)
+ assert isinstance(json.dumps(invocation.args[0]), str)
+ assert isinstance(json.dumps(invocation.args[1]), str)
# Make sure that task was called the right number of times
- self.assertEqual(len(calls), len(mocked.call_args_list))
+ assert len(calls) == len(mocked.call_args_list)
for expected, actual in zip(calls, mocked.call_args_list):
- self.assertEqual(expected, actual, actual)
+ assert expected == actual, actual
@contextmanager
def patch(self, *args: Any, **kwargs: Any) -> Iterator[mock.MagicMock]:
@@ -445,10 +361,10 @@ def assertBundle( # pylint: disable=invalid-name
expected_content = "\n".join([e.pub.pem.strip() for e in expected]) + "\n"
response = self.client.get(url, {"format": "PEM"})
- self.assertEqual(response.status_code, HTTPStatus.OK)
- self.assertEqual(response["Content-Type"], "application/pkix-cert")
- self.assertEqual(response["Content-Disposition"], f"attachment; filename={filename}")
- self.assertEqual(response.content.decode("utf-8"), expected_content)
+ assert response.status_code == HTTPStatus.OK
+ assert response["Content-Type"] == "application/pkix-cert"
+ assert response["Content-Disposition"] == f"attachment; filename={filename}"
+ assert response.content.decode("utf-8") == expected_content
def assertRequiresLogin( # pylint: disable=invalid-name
self, response: "HttpResponse", **kwargs: Any
@@ -511,7 +427,7 @@ def get_changelists(
def test_model_count(self) -> None:
"""Test that the implementing TestCase actually creates some instances."""
- self.assertGreater(self.model._default_manager.all().count(), 0)
+ assert self.model._default_manager.all().count() > 0
def test_changelist_view(self) -> None:
"""Test that the changelist view works."""
diff --git a/ca/django_ca/tests/commands/test_dump_ca.py b/ca/django_ca/tests/commands/test_dump_ca.py
index da410fbfc..4baff39e6 100644
--- a/ca/django_ca/tests/commands/test_dump_ca.py
+++ b/ca/django_ca/tests/commands/test_dump_ca.py
@@ -41,7 +41,7 @@ def test_basic(root: CertificateAuthority) -> None:
assert stdout.decode() == root.pub.pem
-@pytest.mark.parametrize("encoding", [Encoding.PEM, Encoding.DER])
+@pytest.mark.parametrize("encoding", (Encoding.PEM, Encoding.DER))
def test_format(root: CertificateAuthority, encoding: Encoding) -> None:
"""Test encoding parameter."""
stdout = dump_ca(root.serial, format=encoding)
diff --git a/ca/django_ca/tests/commands/test_dump_crl.py b/ca/django_ca/tests/commands/test_dump_crl.py
index 26e5852a2..63dae914b 100644
--- a/ca/django_ca/tests/commands/test_dump_crl.py
+++ b/ca/django_ca/tests/commands/test_dump_crl.py
@@ -144,7 +144,7 @@ def test_disabled(usable_root: CertificateAuthority) -> None:
assert_crl(stdout, signer=usable_root, algorithm=usable_root.algorithm)
-@pytest.mark.parametrize("reason", [x509.ReasonFlags.unspecified, x509.ReasonFlags.key_compromise])
+@pytest.mark.parametrize("reason", (x509.ReasonFlags.unspecified, x509.ReasonFlags.key_compromise))
def test_revoked_with_reason(
usable_root: CertificateAuthority, root_cert: Certificate, reason: x509.ReasonFlags
) -> None:
diff --git a/ca/django_ca/tests/commands/test_init_ca.py b/ca/django_ca/tests/commands/test_init_ca.py
index 262ed2ab9..20ae06e9f 100644
--- a/ca/django_ca/tests/commands/test_init_ca.py
+++ b/ca/django_ca/tests/commands/test_init_ca.py
@@ -867,7 +867,7 @@ def test_password(ca_name: str, key_backend: StoragesBackend) -> None:
key_backend.get_key(parent, use_options)
# Wrong password doesn't work either
- with pytest.raises(ValueError):
+ with pytest.raises(ValueError): # noqa: PT011 # cryptography controls the error message
# NOTE: cryptography is notoriously unstable when it comes to the error message here, so we only
# check the exception class.
key_backend.get_key(parent, StoragesUsePrivateKeyOptions(password=b"wrong"))
@@ -1483,7 +1483,7 @@ def test_key_size_with_unsupported_key_type(ca_name: str, key_type: str) -> None
@pytest.mark.skipif(CRYPTOGRAPHY_VERSION < (43,), reason="cryptography check was added in version 43")
@pytest.mark.parametrize(
- "value,msg",
+ ("value", "msg"),
(
("", r"Attribute's length must be >= 1 and <= 64, but it was 0"),
("X" * 65, r"Attribute's length must be >= 1 and <= 64, but it was 65"),
diff --git a/ca/django_ca/tests/commands/test_list_cas.py b/ca/django_ca/tests/commands/test_list_cas.py
index e8bc999a1..4d1da9f36 100644
--- a/ca/django_ca/tests/commands/test_list_cas.py
+++ b/ca/django_ca/tests/commands/test_list_cas.py
@@ -48,7 +48,7 @@ def assertOutput( # pylint: disable=invalid-name
context.update(CERT_DATA)
for ca_name in self.cas:
context.setdefault(f"{ca_name}_state", "")
- self.assertEqual(output, expected.format(**context))
+ assert output == expected.format(**context)
def test_all_cas(self) -> None:
"""Test list with all CAs."""
@@ -56,54 +56,58 @@ def test_all_cas(self) -> None:
self.load_ca(name)
stdout, stderr = cmd("list_cas")
- self.assertEqual(
- stdout,
- f"""{CERT_DATA['letsencrypt_x1']['serial_colons']} - {CERT_DATA['letsencrypt_x1']['name']}
-{CERT_DATA['letsencrypt_x3']['serial_colons']} - {CERT_DATA['letsencrypt_x3']['name']}
-{CERT_DATA['dst_root_x3']['serial_colons']} - {CERT_DATA['dst_root_x3']['name']}
-{CERT_DATA['google_g3']['serial_colons']} - {CERT_DATA['google_g3']['name']}
-{CERT_DATA['globalsign_r2_root']['serial_colons']} - {CERT_DATA['globalsign_r2_root']['name']}
-{CERT_DATA['trustid_server_a52']['serial_colons']} - {CERT_DATA['trustid_server_a52']['name']}
-{CERT_DATA['rapidssl_g3']['serial_colons']} - {CERT_DATA['rapidssl_g3']['name']}
-{CERT_DATA['geotrust']['serial_colons']} - {CERT_DATA['geotrust']['name']}
-{CERT_DATA['startssl_class2']['serial_colons']} - {CERT_DATA['startssl_class2']['name']}
-{CERT_DATA['digicert_sha2']['serial_colons']} - {CERT_DATA['digicert_sha2']['name']}
-{CERT_DATA['globalsign_dv']['serial_colons']} - {CERT_DATA['globalsign_dv']['name']}
-{CERT_DATA['dsa']['serial_colons']} - {CERT_DATA['dsa']['name']}
-{CERT_DATA['ec']['serial_colons']} - {CERT_DATA['ec']['name']}
-{CERT_DATA['ed25519']['serial_colons']} - {CERT_DATA['ed25519']['name']}
-{CERT_DATA['ed448']['serial_colons']} - {CERT_DATA['ed448']['name']}
-{CERT_DATA['pwd']['serial_colons']} - {CERT_DATA['pwd']['name']}
-{CERT_DATA['root']['serial_colons']} - {CERT_DATA['root']['name']}
-{CERT_DATA['child']['serial_colons']} - {CERT_DATA['child']['name']}
-{CERT_DATA['comodo_ev']['serial_colons']} - {CERT_DATA['comodo_ev']['name']}
-{CERT_DATA['globalsign']['serial_colons']} - {CERT_DATA['globalsign']['name']}
-{CERT_DATA['digicert_ha_intermediate']['serial_colons']} - {CERT_DATA['digicert_ha_intermediate']['name']}
-{CERT_DATA['comodo_dv']['serial_colons']} - {CERT_DATA['comodo_dv']['name']}
-{CERT_DATA['startssl_class3']['serial_colons']} - {CERT_DATA['startssl_class3']['name']}
-{CERT_DATA['godaddy_g2_intermediate']['serial_colons']} - {CERT_DATA['godaddy_g2_intermediate']['name']}
-{CERT_DATA['digicert_ev_root']['serial_colons']} - {CERT_DATA['digicert_ev_root']['name']}
-{CERT_DATA['digicert_global_root']['serial_colons']} - {CERT_DATA['digicert_global_root']['name']}
-{CERT_DATA['identrust_root_1']['serial_colons']} - {CERT_DATA['identrust_root_1']['name']}
-{CERT_DATA['startssl_root']['serial_colons']} - {CERT_DATA['startssl_root']['name']}
-{CERT_DATA['godaddy_g2_root']['serial_colons']} - {CERT_DATA['godaddy_g2_root']['name']}
-{CERT_DATA['comodo']['serial_colons']} - {CERT_DATA['comodo']['name']}
-""",
+ assert (
+ stdout
+ == f"{CERT_DATA['letsencrypt_x1']['serial_colons']} - {CERT_DATA['letsencrypt_x1']['name']}\n"
+ f"{CERT_DATA['letsencrypt_x3']['serial_colons']} - {CERT_DATA['letsencrypt_x3']['name']}\n"
+ f"{CERT_DATA['dst_root_x3']['serial_colons']} - {CERT_DATA['dst_root_x3']['name']}\n"
+ f"{CERT_DATA['google_g3']['serial_colons']} - {CERT_DATA['google_g3']['name']}\n"
+ f"{CERT_DATA['globalsign_r2_root']['serial_colons']}"
+ f" - {CERT_DATA['globalsign_r2_root']['name']}\n"
+ f"{CERT_DATA['trustid_server_a52']['serial_colons']}"
+ f" - {CERT_DATA['trustid_server_a52']['name']}\n"
+ f"{CERT_DATA['rapidssl_g3']['serial_colons']} - {CERT_DATA['rapidssl_g3']['name']}\n"
+ f"{CERT_DATA['geotrust']['serial_colons']} - {CERT_DATA['geotrust']['name']}\n"
+ f"{CERT_DATA['startssl_class2']['serial_colons']} - {CERT_DATA['startssl_class2']['name']}\n"
+ f"{CERT_DATA['digicert_sha2']['serial_colons']} - {CERT_DATA['digicert_sha2']['name']}\n"
+ f"{CERT_DATA['globalsign_dv']['serial_colons']} - {CERT_DATA['globalsign_dv']['name']}\n"
+ f"{CERT_DATA['dsa']['serial_colons']} - {CERT_DATA['dsa']['name']}\n"
+ f"{CERT_DATA['ec']['serial_colons']} - {CERT_DATA['ec']['name']}\n"
+ f"{CERT_DATA['ed25519']['serial_colons']} - {CERT_DATA['ed25519']['name']}\n"
+ f"{CERT_DATA['ed448']['serial_colons']} - {CERT_DATA['ed448']['name']}\n"
+ f"{CERT_DATA['pwd']['serial_colons']} - {CERT_DATA['pwd']['name']}\n"
+ f"{CERT_DATA['root']['serial_colons']} - {CERT_DATA['root']['name']}\n"
+ f"{CERT_DATA['child']['serial_colons']} - {CERT_DATA['child']['name']}\n"
+ f"{CERT_DATA['comodo_ev']['serial_colons']} - {CERT_DATA['comodo_ev']['name']}\n"
+ f"{CERT_DATA['globalsign']['serial_colons']} - {CERT_DATA['globalsign']['name']}\n"
+ f"{CERT_DATA['digicert_ha_intermediate']['serial_colons']}"
+ f" - {CERT_DATA['digicert_ha_intermediate']['name']}\n"
+ f"{CERT_DATA['comodo_dv']['serial_colons']} - {CERT_DATA['comodo_dv']['name']}\n"
+ f"{CERT_DATA['startssl_class3']['serial_colons']} - {CERT_DATA['startssl_class3']['name']}\n"
+ f"{CERT_DATA['godaddy_g2_intermediate']['serial_colons']}"
+ f" - {CERT_DATA['godaddy_g2_intermediate']['name']}\n"
+ f"{CERT_DATA['digicert_ev_root']['serial_colons']} - {CERT_DATA['digicert_ev_root']['name']}\n"
+ f"{CERT_DATA['digicert_global_root']['serial_colons']}"
+ f" - {CERT_DATA['digicert_global_root']['name']}\n"
+ f"{CERT_DATA['identrust_root_1']['serial_colons']} - {CERT_DATA['identrust_root_1']['name']}\n"
+ f"{CERT_DATA['startssl_root']['serial_colons']} - {CERT_DATA['startssl_root']['name']}\n"
+ f"{CERT_DATA['godaddy_g2_root']['serial_colons']} - {CERT_DATA['godaddy_g2_root']['name']}\n"
+ f"{CERT_DATA['comodo']['serial_colons']} - {CERT_DATA['comodo']['name']}\n"
)
- self.assertEqual(stderr, "")
+ assert stderr == ""
def test_no_cas(self) -> None:
"""Test the command if no CAs are defined."""
CertificateAuthority.objects.all().delete()
stdout, stderr = cmd("list_cas")
- self.assertEqual(stdout, "")
- self.assertEqual(stderr, "")
+ assert stdout == ""
+ assert stderr == ""
def test_basic(self) -> None:
"""Basic test of the command."""
stdout, stderr = cmd("list_cas")
self.assertOutput(stdout, EXPECTED)
- self.assertEqual(stderr, "")
+ assert stderr == ""
def test_disabled(self) -> None:
"""Test the command if some CA is disabled."""
@@ -112,7 +116,7 @@ def test_disabled(self) -> None:
stdout, stderr = cmd("list_cas")
self.assertOutput(stdout, EXPECTED, child_state=" (disabled)")
- self.assertEqual(stderr, "")
+ assert stderr == ""
@freeze_time(TIMESTAMPS["everything_valid"])
def test_tree(self) -> None:
@@ -121,18 +125,16 @@ def test_tree(self) -> None:
NOTE: freeze_time b/c we create some fake CA objects and order in the tree depends on validity.
"""
stdout, stderr = cmd("list_cas", tree=True)
- self.assertEqual(
- stdout,
- f"""{CERT_DATA['dsa']['serial_colons']} - {CERT_DATA['dsa']['name']}
-{CERT_DATA['ec']['serial_colons']} - {CERT_DATA['ec']['name']}
-{CERT_DATA['ed25519']['serial_colons']} - {CERT_DATA['ed25519']['name']}
-{CERT_DATA['ed448']['serial_colons']} - {CERT_DATA['ed448']['name']}
-{CERT_DATA['pwd']['serial_colons']} - {CERT_DATA['pwd']['name']}
-{CERT_DATA['root']['serial_colons']} - {CERT_DATA['root']['name']}
-└───{CERT_DATA['child']['serial_colons']} - {CERT_DATA['child']['name']}
-""",
+ assert (
+ stdout == f"{CERT_DATA['dsa']['serial_colons']} - {CERT_DATA['dsa']['name']}\n"
+ f"{CERT_DATA['ec']['serial_colons']} - {CERT_DATA['ec']['name']}\n"
+ f"{CERT_DATA['ed25519']['serial_colons']} - {CERT_DATA['ed25519']['name']}\n{
+ CERT_DATA['ed448']['serial_colons']} - {CERT_DATA['ed448']['name']}\n"
+ f"{CERT_DATA['pwd']['serial_colons']} - {CERT_DATA['pwd']['name']}\n"
+ f"{CERT_DATA['root']['serial_colons']} - {CERT_DATA['root']['name']}\n"
+ f"└───{CERT_DATA['child']['serial_colons']} - {CERT_DATA['child']['name']}\n"
)
- self.assertEqual(stderr, "")
+ assert stderr == ""
# manually create Certificate objects
not_after = timezone.now() + timedelta(days=3)
@@ -155,17 +157,15 @@ def test_tree(self) -> None:
)
stdout, stderr = cmd("list_cas", tree=True)
- self.assertEqual(
- stdout,
- f"""{CERT_DATA['dsa']['serial_colons']} - {CERT_DATA['dsa']['name']}
-{CERT_DATA['ec']['serial_colons']} - {CERT_DATA['ec']['name']}
-{CERT_DATA['ed25519']['serial_colons']} - {CERT_DATA['ed25519']['name']}
-{CERT_DATA['ed448']['serial_colons']} - {CERT_DATA['ed448']['name']}
-{CERT_DATA['pwd']['serial_colons']} - {CERT_DATA['pwd']['name']}
-{CERT_DATA['root']['serial_colons']} - {CERT_DATA['root']['name']}
-│───ch:il:d3 - child3
-│ └───ch:il:d3:.1 - child3.1
-│───ch:il:d4 - child4
-└───{CERT_DATA['child']['serial_colons']} - {CERT_DATA['child']['name']}
-""",
+ assert (
+ stdout == f"{CERT_DATA['dsa']['serial_colons']} - {CERT_DATA['dsa']['name']}\n"
+ f"{CERT_DATA['ec']['serial_colons']} - {CERT_DATA['ec']['name']}\n"
+ f"{CERT_DATA['ed25519']['serial_colons']} - {CERT_DATA['ed25519']['name']}\n"
+ f"{CERT_DATA['ed448']['serial_colons']} - {CERT_DATA['ed448']['name']}\n"
+ f"{CERT_DATA['pwd']['serial_colons']} - {CERT_DATA['pwd']['name']}\n"
+ f"{CERT_DATA['root']['serial_colons']} - {CERT_DATA['root']['name']}\n"
+ f"│───ch:il:d3 - child3\n"
+ f"│ └───ch:il:d3:.1 - child3.1\n"
+ f"│───ch:il:d4 - child4\n"
+ f"└───{CERT_DATA['child']['serial_colons']} - {CERT_DATA['child']['name']}\n"
)
diff --git a/ca/django_ca/tests/commands/test_list_certs.py b/ca/django_ca/tests/commands/test_list_certs.py
index 2e9613721..758af9ca1 100644
--- a/ca/django_ca/tests/commands/test_list_certs.py
+++ b/ca/django_ca/tests/commands/test_list_certs.py
@@ -49,8 +49,8 @@ def assertCerts(self, *certs: Certificate, **kwargs: Any) -> None: # pylint: di
"""Assert that command outputs the given certs."""
stdout, stderr = cmd("list_certs", **kwargs)
sorted_certs = sorted(certs, key=lambda c: (c.not_after, c.cn, c.serial))
- self.assertEqual(stdout, "".join([f"{self._line(c)}\n" for c in sorted_certs]))
- self.assertEqual(stderr, "")
+ assert stdout == "".join([f"{self._line(c)}\n" for c in sorted_certs])
+ assert stderr == ""
@freeze_time(TIMESTAMPS["everything_valid"])
def test_basic(self) -> None:
diff --git a/ca/django_ca/tests/commands/test_notify.py b/ca/django_ca/tests/commands/test_notify.py
index aca3e7512..c8786c890 100644
--- a/ca/django_ca/tests/commands/test_notify.py
+++ b/ca/django_ca/tests/commands/test_notify.py
@@ -37,18 +37,18 @@ class NotifyExpiringCertsTestCase(TestCaseMixin, TestCase):
def test_no_certs(self) -> None:
"""Try notify command when all certs are still valid."""
stdout, stderr = cmd("notify_expiring_certs")
- self.assertEqual(stdout, "")
- self.assertEqual(stderr, "")
- self.assertEqual(len(mail.outbox), 0)
+ assert stdout == ""
+ assert stderr == ""
+ assert len(mail.outbox) == 0
@freeze_time(TIMESTAMPS["ca_certs_expiring"])
def test_no_watchers(self) -> None:
"""Try expiring certs, but with no watchers."""
# certs have no watchers by default, so we get no mails
stdout, stderr = cmd("notify_expiring_certs")
- self.assertEqual(stdout, "")
- self.assertEqual(stderr, "")
- self.assertEqual(len(mail.outbox), 0)
+ assert stdout == ""
+ assert stderr == ""
+ assert len(mail.outbox) == 0
@freeze_time(TIMESTAMPS["ca_certs_expiring"])
def test_one_watcher(self) -> None:
@@ -59,11 +59,11 @@ def test_one_watcher(self) -> None:
timestamp = self.cert.not_after.strftime("%Y-%m-%d")
stdout, stderr = cmd("notify_expiring_certs")
- self.assertEqual(stdout, "")
- self.assertEqual(stderr, "")
- self.assertEqual(len(mail.outbox), 1)
- self.assertEqual(mail.outbox[0].subject, f"Certificate expiration for {self.cert.cn} on {timestamp}")
- self.assertEqual(mail.outbox[0].to, [email])
+ assert stdout == ""
+ assert stderr == ""
+ assert len(mail.outbox) == 1
+ assert mail.outbox[0].subject == f"Certificate expiration for {self.cert.cn} on {timestamp}"
+ assert mail.outbox[0].to == [email]
def test_notification_days(self) -> None:
"""Test that user gets multiple notifications of expiring certs."""
@@ -74,8 +74,8 @@ def test_notification_days(self) -> None:
with freeze_time(self.cert.not_after - timedelta(days=20)) as frozen_time:
for _i in reversed(range(0, 20)):
stdout, stderr = cmd("notify_expiring_certs", days=14)
- self.assertEqual(stdout, "")
- self.assertEqual(stderr, "")
+ assert stdout == ""
+ assert stderr == ""
frozen_time.tick(timedelta(days=1))
- self.assertEqual(len(mail.outbox), 4)
+ assert len(mail.outbox) == 4
diff --git a/ca/django_ca/tests/commands/test_regenerate_ocsp_keys.py b/ca/django_ca/tests/commands/test_regenerate_ocsp_keys.py
index e99675b84..a7a7c0093 100644
--- a/ca/django_ca/tests/commands/test_regenerate_ocsp_keys.py
+++ b/ca/django_ca/tests/commands/test_regenerate_ocsp_keys.py
@@ -75,14 +75,14 @@ def assertKey( # pylint: disable=invalid-name
priv = typing.cast(
CertificateIssuerPrivateKeyTypes, load_der_private_key(read_file(priv_path), password)
)
- self.assertIsInstance(priv, key_type)
+ assert isinstance(priv, key_type)
if isinstance(priv, (dsa.DSAPrivateKey, rsa.RSAPrivateKey)):
- self.assertEqual(priv.key_size, key_size)
+ assert priv.key_size == key_size
if isinstance(priv, ec.EllipticCurvePrivateKey):
- self.assertIsInstance(priv.curve, elliptic_curve)
+ assert isinstance(priv.curve, elliptic_curve)
cert = x509.load_pem_x509_certificate(read_file(cert_path))
- self.assertIsInstance(cert, x509.Certificate)
+ assert isinstance(cert, x509.Certificate)
cert_qs = Certificate.objects.filter(ca=ca).exclude(pk__in=self.existing_certs)
@@ -105,7 +105,7 @@ def assertKey( # pylint: disable=invalid-name
if ad.access_method == AuthorityInformationAccessOID.CA_ISSUERS
),
)
- self.assertEqual(aia, expected_aia)
+ assert aia == expected_aia
return priv, cert
@@ -232,8 +232,8 @@ def test_overwrite(self) -> None:
new_priv, new_cert = self.assertKey(self.cas["root"], excludes=excludes)
# Key/Cert should now be different
- self.assertNotEqual(priv, new_priv)
- self.assertNotEqual(cert, new_cert)
+ assert priv != new_priv
+ assert cert != new_cert
@override_tmpcadir()
def test_wrong_serial(self) -> None:
diff --git a/ca/django_ca/tests/commands/test_resign_cert.py b/ca/django_ca/tests/commands/test_resign_cert.py
index 228261550..ca22fe220 100644
--- a/ca/django_ca/tests/commands/test_resign_cert.py
+++ b/ca/django_ca/tests/commands/test_resign_cert.py
@@ -124,7 +124,7 @@ def test_basic(self) -> None:
new = Certificate.objects.get(pub=stdout)
assert_resigned(self.cert, new)
assert_equal_ext(self.cert, new)
- self.assertIsInstance(new.algorithm, type(self.cert.algorithm))
+ assert isinstance(new.algorithm, type(self.cert.algorithm))
@override_tmpcadir()
def test_dsa_ca_resign(self) -> None:
@@ -136,7 +136,7 @@ def test_dsa_ca_resign(self) -> None:
new = Certificate.objects.get(pub=stdout)
assert_resigned(self.certs["dsa-cert"], new)
assert_equal_ext(self.certs["dsa-cert"], new)
- self.assertIsInstance(new.algorithm, hashes.SHA256)
+ assert isinstance(new.algorithm, hashes.SHA256)
@override_tmpcadir()
def test_all_extensions_certificate(self) -> None:
@@ -148,20 +148,19 @@ def test_all_extensions_certificate(self) -> None:
new = Certificate.objects.get(pub=stdout)
assert_resigned(orig, new)
- self.assertIsInstance(new.algorithm, hashes.SHA256)
+ assert isinstance(new.algorithm, hashes.SHA256)
expected = orig.extensions
actual = new.extensions
- self.assertEqual(
- sorted(expected.values(), key=lambda e: e.oid.dotted_string),
- sorted(actual.values(), key=lambda e: e.oid.dotted_string),
+ assert sorted(expected.values(), key=lambda e: e.oid.dotted_string) == sorted(
+ actual.values(), key=lambda e: e.oid.dotted_string
)
@override_tmpcadir()
def test_test_all_extensions_cert_with_overrides(self) -> None:
"""Test resigning a certificate with adding new extensions."""
- self.assertIsNotNone(self.ca.sign_authority_information_access)
- self.assertIsNotNone(self.ca.sign_crl_distribution_points)
+ assert self.ca.sign_authority_information_access is not None
+ assert self.ca.sign_crl_distribution_points is not None
self.ca.sign_certificate_policies = certificate_policies(
x509.PolicyInformation(
policy_identifier=CertificatePoliciesOID.CPS_QUALIFIER, policy_qualifiers=None
@@ -216,104 +215,91 @@ def test_test_all_extensions_cert_with_overrides(self) -> None:
new = Certificate.objects.get(pub=stdout)
assert_resigned(orig, new)
- self.assertIsInstance(new.algorithm, hashes.SHA256)
+ assert isinstance(new.algorithm, hashes.SHA256)
extensions = new.extensions
# Test Authority Information Access extension
- self.assertEqual(
- extensions[ExtensionOID.AUTHORITY_INFORMATION_ACCESS],
- x509.Extension(
- oid=ExtensionOID.AUTHORITY_INFORMATION_ACCESS,
- critical=False,
- value=x509.AuthorityInformationAccess(
- [
- x509.AccessDescription(
- access_method=AuthorityInformationAccessOID.OCSP,
- access_location=uri("http://ocsp.example.com/1"),
- ),
- x509.AccessDescription(
- access_method=AuthorityInformationAccessOID.OCSP,
- access_location=uri("http://ocsp.example.com/2"),
- ),
- x509.AccessDescription(
- access_method=AuthorityInformationAccessOID.CA_ISSUERS,
- access_location=uri("http://issuer.example.com/1"),
- ),
- x509.AccessDescription(
- access_method=AuthorityInformationAccessOID.CA_ISSUERS,
- access_location=uri("http://issuer.example.com/2"),
- ),
- ]
- ),
+ assert extensions[ExtensionOID.AUTHORITY_INFORMATION_ACCESS] == x509.Extension(
+ oid=ExtensionOID.AUTHORITY_INFORMATION_ACCESS,
+ critical=False,
+ value=x509.AuthorityInformationAccess(
+ [
+ x509.AccessDescription(
+ access_method=AuthorityInformationAccessOID.OCSP,
+ access_location=uri("http://ocsp.example.com/1"),
+ ),
+ x509.AccessDescription(
+ access_method=AuthorityInformationAccessOID.OCSP,
+ access_location=uri("http://ocsp.example.com/2"),
+ ),
+ x509.AccessDescription(
+ access_method=AuthorityInformationAccessOID.CA_ISSUERS,
+ access_location=uri("http://issuer.example.com/1"),
+ ),
+ x509.AccessDescription(
+ access_method=AuthorityInformationAccessOID.CA_ISSUERS,
+ access_location=uri("http://issuer.example.com/2"),
+ ),
+ ]
),
)
# Test Certificate Policies extension
- self.assertEqual(
- extensions[ExtensionOID.CERTIFICATE_POLICIES],
- x509.Extension(
- oid=ExtensionOID.CERTIFICATE_POLICIES,
- critical=False,
- value=x509.CertificatePolicies(
- policies=[
- x509.PolicyInformation(
- policy_identifier=x509.ObjectIdentifier("1.2.3"),
- policy_qualifiers=[
- "https://example.com/overwritten/",
- x509.UserNotice(
- notice_reference=None, explicit_text="overwritten user notice text"
- ),
- ],
- )
- ]
- ),
+ assert extensions[ExtensionOID.CERTIFICATE_POLICIES] == x509.Extension(
+ oid=ExtensionOID.CERTIFICATE_POLICIES,
+ critical=False,
+ value=x509.CertificatePolicies(
+ policies=[
+ x509.PolicyInformation(
+ policy_identifier=x509.ObjectIdentifier("1.2.3"),
+ policy_qualifiers=[
+ "https://example.com/overwritten/",
+ x509.UserNotice(
+ notice_reference=None, explicit_text="overwritten user notice text"
+ ),
+ ],
+ )
+ ]
),
)
# Test CRL Distribution Points extension
- self.assertEqual(
- extensions[ExtensionOID.CRL_DISTRIBUTION_POINTS],
- self.crl_distribution_points([uri("http://crl.example.com"), uri("http://crl.example.net")]),
+ assert extensions[ExtensionOID.CRL_DISTRIBUTION_POINTS] == self.crl_distribution_points(
+ [uri("http://crl.example.com"), uri("http://crl.example.net")]
)
# Test Extended Key Usage extension
- self.assertEqual(
- extensions[ExtensionOID.EXTENDED_KEY_USAGE],
- extended_key_usage(ExtendedKeyUsageOID.CLIENT_AUTH, ExtendedKeyUsageOID.SERVER_AUTH),
+ assert extensions[ExtensionOID.EXTENDED_KEY_USAGE] == extended_key_usage(
+ ExtendedKeyUsageOID.CLIENT_AUTH, ExtendedKeyUsageOID.SERVER_AUTH
)
# Test Issuer Alternative Name extension
- self.assertEqual(
- extensions[ExtensionOID.ISSUER_ALTERNATIVE_NAME],
- issuer_alternative_name(dns("ian-override.example.com"), uri("http://ian-override.example.com")),
+ assert extensions[ExtensionOID.ISSUER_ALTERNATIVE_NAME] == issuer_alternative_name(
+ dns("ian-override.example.com"), uri("http://ian-override.example.com")
)
# Test KeyUsage extension
- self.assertEqual(
- extensions[ExtensionOID.KEY_USAGE],
- key_usage(key_agreement=True, key_encipherment=True, critical=False),
+ assert extensions[ExtensionOID.KEY_USAGE] == key_usage(
+ key_agreement=True, key_encipherment=True, critical=False
)
# Test OCSP No Check extension
- self.assertEqual(extensions[ExtensionOID.OCSP_NO_CHECK], ocsp_no_check(critical=True))
+ assert extensions[ExtensionOID.OCSP_NO_CHECK] == ocsp_no_check(critical=True)
# Test Subject Alternative Name extension
- self.assertEqual(
- extensions[x509.SubjectAlternativeName.oid],
- subject_alternative_name(dns("override.example.net")),
+ assert extensions[x509.SubjectAlternativeName.oid] == subject_alternative_name(
+ dns("override.example.net")
)
# Test TLSFeature extension
- self.assertEqual(
- extensions[ExtensionOID.TLS_FEATURE], tls_feature(x509.TLSFeatureType.status_request)
- )
+ assert extensions[ExtensionOID.TLS_FEATURE] == tls_feature(x509.TLSFeatureType.status_request)
@override_tmpcadir()
def test_test_no_extensions_cert_with_overrides(self) -> None:
"""Test resigning a certificate with adding new extensions."""
- self.assertIsNotNone(self.ca.sign_authority_information_access)
- self.assertIsNotNone(self.ca.sign_crl_distribution_points)
+ assert self.ca.sign_authority_information_access is not None
+ assert self.ca.sign_crl_distribution_points is not None
self.ca.sign_certificate_policies = certificate_policies(
x509.PolicyInformation(
policy_identifier=CertificatePoliciesOID.CPS_QUALIFIER, policy_qualifiers=None
@@ -361,68 +347,55 @@ def test_test_no_extensions_cert_with_overrides(self) -> None:
new = Certificate.objects.get(pub=stdout)
assert_resigned(orig, new)
- self.assertIsInstance(new.algorithm, hashes.SHA256)
+ assert isinstance(new.algorithm, hashes.SHA256)
extensions = new.extensions
# Test Certificate Policies extension
- self.assertEqual(
- extensions[ExtensionOID.CERTIFICATE_POLICIES],
- certificate_policies(
- x509.PolicyInformation(
- policy_identifier=x509.ObjectIdentifier("1.2.3"),
- policy_qualifiers=[
- "https://example.com/overwritten/",
- x509.UserNotice(notice_reference=None, explicit_text="overwritten user notice text"),
- ],
- )
- ),
+ assert extensions[ExtensionOID.CERTIFICATE_POLICIES] == certificate_policies(
+ x509.PolicyInformation(
+ policy_identifier=x509.ObjectIdentifier("1.2.3"),
+ policy_qualifiers=[
+ "https://example.com/overwritten/",
+ x509.UserNotice(notice_reference=None, explicit_text="overwritten user notice text"),
+ ],
+ )
)
# Test CRL Distribution Points extension
- self.assertEqual(
- extensions[ExtensionOID.CRL_DISTRIBUTION_POINTS],
- crl_distribution_points(
- distribution_point([uri("http://crl.example.com"), uri("http://crl.example.net")])
- ),
+ assert extensions[ExtensionOID.CRL_DISTRIBUTION_POINTS] == crl_distribution_points(
+ distribution_point([uri("http://crl.example.com"), uri("http://crl.example.net")])
)
# Test Extended Key Usage extension
- self.assertEqual(
- extensions[ExtensionOID.EXTENDED_KEY_USAGE],
- extended_key_usage(ExtendedKeyUsageOID.CLIENT_AUTH, ExtendedKeyUsageOID.SERVER_AUTH),
+ assert extensions[ExtensionOID.EXTENDED_KEY_USAGE] == extended_key_usage(
+ ExtendedKeyUsageOID.CLIENT_AUTH, ExtendedKeyUsageOID.SERVER_AUTH
)
# Test Issuer Alternative Name extension
- self.assertEqual(
- extensions[ExtensionOID.ISSUER_ALTERNATIVE_NAME],
- issuer_alternative_name(dns("ian-override.example.com"), uri("http://ian-override.example.com")),
+ assert extensions[ExtensionOID.ISSUER_ALTERNATIVE_NAME] == issuer_alternative_name(
+ dns("ian-override.example.com"), uri("http://ian-override.example.com")
)
# Test Key Usage extension
- self.assertEqual(
- extensions[ExtensionOID.KEY_USAGE], key_usage(key_agreement=True, key_encipherment=True)
- )
+ assert extensions[ExtensionOID.KEY_USAGE] == key_usage(key_agreement=True, key_encipherment=True)
# Test OCSP No Check extension
- self.assertEqual(extensions[ExtensionOID.OCSP_NO_CHECK], ocsp_no_check())
+ assert extensions[ExtensionOID.OCSP_NO_CHECK] == ocsp_no_check()
# Test Subject Alternative Name extension
- self.assertEqual(
- extensions[x509.SubjectAlternativeName.oid],
- subject_alternative_name(dns("override.example.net")),
+ assert extensions[x509.SubjectAlternativeName.oid] == subject_alternative_name(
+ dns("override.example.net")
)
# Test TLSFeature extension
- self.assertEqual(
- extensions[ExtensionOID.TLS_FEATURE], tls_feature(x509.TLSFeatureType.status_request)
- )
+ assert extensions[ExtensionOID.TLS_FEATURE] == tls_feature(x509.TLSFeatureType.status_request)
@override_tmpcadir()
def test_test_no_extensions_cert_with_overrides_with_non_default_critical(self) -> None:
"""Test resigning a certificate with adding new extensions with non-default critical values."""
- self.assertIsNotNone(self.ca.sign_authority_information_access)
- self.assertIsNotNone(self.ca.sign_crl_distribution_points)
+ assert self.ca.sign_authority_information_access is not None
+ assert self.ca.sign_crl_distribution_points is not None
self.ca.sign_certificate_policies = certificate_policies(
x509.PolicyInformation(
policy_identifier=CertificatePoliciesOID.CPS_QUALIFIER, policy_qualifiers=None
@@ -469,67 +442,55 @@ def test_test_no_extensions_cert_with_overrides_with_non_default_critical(self)
new = Certificate.objects.get(pub=stdout)
assert_resigned(orig, new)
- self.assertIsInstance(new.algorithm, hashes.SHA256)
+ assert isinstance(new.algorithm, hashes.SHA256)
extensions = new.extensions
# Test Certificate Policies extension
- self.assertEqual(
- extensions[ExtensionOID.CERTIFICATE_POLICIES],
- x509.Extension(
- oid=ExtensionOID.CERTIFICATE_POLICIES,
- critical=True,
- value=x509.CertificatePolicies(
- policies=[
- x509.PolicyInformation(
- policy_identifier=x509.ObjectIdentifier("1.2.3"),
- policy_qualifiers=[
- "https://example.com/overwritten/",
- x509.UserNotice(
- notice_reference=None, explicit_text="overwritten user notice text"
- ),
- ],
- )
- ]
- ),
+ assert extensions[ExtensionOID.CERTIFICATE_POLICIES] == x509.Extension(
+ oid=ExtensionOID.CERTIFICATE_POLICIES,
+ critical=True,
+ value=x509.CertificatePolicies(
+ policies=[
+ x509.PolicyInformation(
+ policy_identifier=x509.ObjectIdentifier("1.2.3"),
+ policy_qualifiers=[
+ "https://example.com/overwritten/",
+ x509.UserNotice(
+ notice_reference=None, explicit_text="overwritten user notice text"
+ ),
+ ],
+ )
+ ]
),
)
# Test CRL Distribution Points extension
- self.assertEqual(
- extensions[ExtensionOID.CRL_DISTRIBUTION_POINTS],
- self.crl_distribution_points(
- [uri("http://crl.example.com"), uri("http://crl.example.net")], critical=True
- ),
+ assert extensions[ExtensionOID.CRL_DISTRIBUTION_POINTS] == self.crl_distribution_points(
+ [uri("http://crl.example.com"), uri("http://crl.example.net")], critical=True
)
# Test Extended Key Usage extension
- self.assertEqual(
- extensions[ExtensionOID.EXTENDED_KEY_USAGE],
- extended_key_usage(
- ExtendedKeyUsageOID.CLIENT_AUTH, ExtendedKeyUsageOID.SERVER_AUTH, critical=True
- ),
+ assert extensions[ExtensionOID.EXTENDED_KEY_USAGE] == extended_key_usage(
+ ExtendedKeyUsageOID.CLIENT_AUTH, ExtendedKeyUsageOID.SERVER_AUTH, critical=True
)
# Test Key Usage extension
- self.assertEqual(
- extensions[ExtensionOID.KEY_USAGE],
- key_usage(key_agreement=True, key_encipherment=True, critical=False),
+ assert extensions[ExtensionOID.KEY_USAGE] == key_usage(
+ key_agreement=True, key_encipherment=True, critical=False
)
# Test OCSP No Check extension
- self.assertEqual(extensions[ExtensionOID.OCSP_NO_CHECK], ocsp_no_check(True))
+ assert extensions[ExtensionOID.OCSP_NO_CHECK] == ocsp_no_check(True)
# Test Subject Alternative Name extension
- self.assertEqual(
- extensions[x509.SubjectAlternativeName.oid],
- subject_alternative_name(dns("override.example.net"), critical=True),
+ assert extensions[x509.SubjectAlternativeName.oid] == subject_alternative_name(
+ dns("override.example.net"), critical=True
)
# Test TLSFeature extension
- self.assertEqual(
- extensions[ExtensionOID.TLS_FEATURE],
- tls_feature(x509.TLSFeatureType.status_request, critical=True),
+ assert extensions[ExtensionOID.TLS_FEATURE] == tls_feature(
+ x509.TLSFeatureType.status_request, critical=True
)
@override_tmpcadir()
@@ -542,7 +503,7 @@ def test_custom_algorithm(self) -> None:
new = Certificate.objects.get(pub=stdout)
assert_resigned(self.cert, new)
assert_equal_ext(self.cert, new)
- self.assertIsInstance(new.algorithm, hashes.SHA512)
+ assert isinstance(new.algorithm, hashes.SHA512)
@override_tmpcadir()
def test_different_ca(self) -> None:
@@ -589,31 +550,28 @@ def test_overwrite(self) -> None:
new = Certificate.objects.get(pub=stdout)
assert_resigned(self.cert, new)
- self.assertEqual(new.subject, x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, cname)]))
- self.assertEqual(list(new.watchers.all()), [Watcher.objects.get(mail=watcher)])
+ assert new.subject == x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, cname)])
+ assert list(new.watchers.all()) == [Watcher.objects.get(mail=watcher)]
# assert overwritten extensions
extensions = new.extensions
# Test Extended Key Usage extension
- self.assertEqual(
- extensions[ExtensionOID.EXTENDED_KEY_USAGE],
- extended_key_usage(ExtendedKeyUsageOID.EMAIL_PROTECTION, critical=True),
+ assert extensions[ExtensionOID.EXTENDED_KEY_USAGE] == extended_key_usage(
+ ExtendedKeyUsageOID.EMAIL_PROTECTION, critical=True
)
# Test Key Usage extension
- self.assertEqual(extensions[ExtensionOID.KEY_USAGE], key_usage(crl_sign=True, critical=False))
+ assert extensions[ExtensionOID.KEY_USAGE] == key_usage(crl_sign=True, critical=False)
# Test Subject Alternative Name extension
- self.assertEqual(
- extensions[ExtensionOID.SUBJECT_ALTERNATIVE_NAME],
- subject_alternative_name(dns("subject-alternative-name.example.com")),
+ assert extensions[ExtensionOID.SUBJECT_ALTERNATIVE_NAME] == subject_alternative_name(
+ dns("subject-alternative-name.example.com")
)
# Test TLSFeature extension
- self.assertEqual(
- extensions[ExtensionOID.TLS_FEATURE],
- tls_feature(x509.TLSFeatureType.status_request_v2, critical=True),
+ assert extensions[ExtensionOID.TLS_FEATURE] == tls_feature(
+ x509.TLSFeatureType.status_request_v2, critical=True
)
@override_tmpcadir(
@@ -627,7 +585,7 @@ def test_set_profile(self) -> None:
assert stderr == ""
new = Certificate.objects.get(pub=stdout)
- self.assertEqual(new.not_after.date(), timezone.now().date() + timedelta(days=200))
+ assert new.not_after.date() == timezone.now().date() + timedelta(days=200)
assert_resigned(self.cert, new)
assert_equal_ext(self.cert, new)
@@ -645,7 +603,7 @@ def test_cert_profile(self) -> None:
assert stderr == ""
new = Certificate.objects.get(pub=stdout)
- self.assertEqual(new.not_after.date(), timezone.now().date() + timedelta(days=200))
+ assert new.not_after.date() == timezone.now().date() + timedelta(days=200)
assert_resigned(self.cert, new)
assert_equal_ext(self.cert, new)
diff --git a/ca/django_ca/tests/commands/test_revoke_cert.py b/ca/django_ca/tests/commands/test_revoke_cert.py
index 3b9c4c3a4..a4239a36e 100644
--- a/ca/django_ca/tests/commands/test_revoke_cert.py
+++ b/ca/django_ca/tests/commands/test_revoke_cert.py
@@ -47,24 +47,24 @@ def revoke(
with mock_signal(pre_revoke_cert) as pre, mock_signal(post_revoke_cert) as post:
stdout, stderr = cmd_e2e(["revoke_cert", cert.serial, *arguments])
- self.assertEqual(stdout, "")
- self.assertEqual(stderr, "")
+ assert stdout == ""
+ assert stderr == ""
cert.refresh_from_db()
- self.assertEqual(pre.call_count, 1)
+ assert pre.call_count == 1
self.assertPostRevoke(post, cert)
- self.assertTrue(cert.revoked)
- self.assertTrue(cert.revoked_date is not None)
- self.assertEqual(cert.revoked_reason, reason)
+ assert cert.revoked
+ assert cert.revoked_date is not None
+ assert cert.revoked_reason == reason
def test_no_arguments(self) -> None:
"""Test revoking without a reason."""
- self.assertFalse(self.cert.revoked)
+ assert not self.cert.revoked
self.revoke(self.cert)
def test_with_reason(self) -> None:
"""Test revoking with a reason."""
- self.assertFalse(self.cert.revoked)
+ assert not self.cert.revoked
for reason in ReasonFlags:
self.revoke(self.cert, ["--reason", reason.name], reason=reason.name)
@@ -79,26 +79,26 @@ def test_with_compromised(self) -> None:
"""Test revoking the certificate with a compromised date."""
now = datetime.now(tz=tz.utc)
self.revoke(self.cert, arguments=["--compromised", now.isoformat()])
- self.assertEqual(self.cert.compromised, now)
+ assert self.cert.compromised == now
def test_with_compromised_with_use_tz_is_false(self) -> None:
"""Test revoking the certificate with a compromised date with USE_TZ=False."""
with self.settings(USE_TZ=False):
now = datetime.now(tz=tz.utc)
self.revoke(self.cert, arguments=["--compromised", now.isoformat()])
- self.assertEqual(self.cert.compromised, timezone.make_naive(now))
+ assert self.cert.compromised == timezone.make_naive(now)
def test_revoked(self) -> None:
"""Test revoking a cert that is already revoked."""
- self.assertFalse(self.cert.revoked)
+ assert not self.cert.revoked
with mock_signal(pre_revoke_cert) as pre, mock_signal(post_revoke_cert) as post:
cmd("revoke_cert", self.cert.serial)
cert = Certificate.objects.get(serial=self.cert.serial)
- self.assertEqual(pre.call_count, 1)
+ assert pre.call_count == 1
self.assertPostRevoke(post, cert)
- self.assertEqual(cert.revoked_reason, ReasonFlags.unspecified.name)
+ assert cert.revoked_reason == ReasonFlags.unspecified.name
with (
assert_command_error(rf"^{self.cert.serial}: Certificate is already revoked\.$"),
@@ -106,13 +106,13 @@ def test_revoked(self) -> None:
mock_signal(post_revoke_cert) as post,
):
cmd("revoke_cert", self.cert.serial, reason=ReasonFlags.key_compromise)
- self.assertFalse(pre.called)
- self.assertFalse(post.called)
+ assert not pre.called
+ assert not post.called
cert = Certificate.objects.get(serial=self.cert.serial)
- self.assertTrue(cert.revoked)
- self.assertTrue(cert.revoked_date is not None)
- self.assertEqual(cert.revoked_reason, ReasonFlags.unspecified.name)
+ assert cert.revoked
+ assert cert.revoked_date is not None
+ assert cert.revoked_reason == ReasonFlags.unspecified.name
def test_compromised_with_naive_datetime(self) -> None:
"""Test passing a naive datetime (which is an error)."""
diff --git a/ca/django_ca/tests/conftest.py b/ca/django_ca/tests/conftest.py
index e01d77f24..6345ed3db 100644
--- a/ca/django_ca/tests/conftest.py
+++ b/ca/django_ca/tests/conftest.py
@@ -18,7 +18,6 @@
import importlib.metadata
import os
import sys
-from collections.abc import Iterator
from typing import Any
import coverage
@@ -143,7 +142,7 @@ def user(
@pytest.fixture
-def user_client(user: "User", client: Client) -> Iterator[Client]:
+def user_client(user: "User", client: Client) -> Client:
"""A Django test client logged in as a normal user."""
client.force_login(user) # type: ignore[arg-type] # django-stubs 5.1.0 thinks user is AbstractUser
return client
diff --git a/ca/django_ca/tests/extensions/test_admin_html.py b/ca/django_ca/tests/extensions/test_admin_html.py
index e5c769e9e..558710c2d 100644
--- a/ca/django_ca/tests/extensions/test_admin_html.py
+++ b/ca/django_ca/tests/extensions/test_admin_html.py
@@ -690,7 +690,7 @@ def _set_distribution_point_extension(
def assertAdminHTML(self, name: str, cert: X509CertMixin) -> None: # pylint: disable=invalid-name
"""Assert that the actual extension HTML is equivalent to the expected HTML."""
for oid, ext in cert.extensions.items():
- self.assertIn(oid, self.admin_html[name], (name, oid))
+ assert oid in self.admin_html[name], (name, oid)
admin_html = self.admin_html[name][oid]
admin_html = f'\n
{admin_html}
'
actual = extension_as_admin_html(ext)
diff --git a/ca/django_ca/tests/extensions/test_unknown_extension.py b/ca/django_ca/tests/extensions/test_unknown_extension.py
index a1ebb5ce5..ce4210746 100644
--- a/ca/django_ca/tests/extensions/test_unknown_extension.py
+++ b/ca/django_ca/tests/extensions/test_unknown_extension.py
@@ -17,6 +17,8 @@
from django.test import TestCase
+import pytest
+
from django_ca.extensions import extension_as_text, parse_extension
@@ -39,17 +41,17 @@ def public_bytes(self) -> bytes:
def test_parse_unknown_key(self) -> None:
"""Test exception for parsing an extension with an unsupported key."""
- with self.assertRaisesRegex(ValueError, r"^wrong_key: Unknown extension key\.$"):
+ with pytest.raises(ValueError, match=r"^wrong_key: Unknown extension key\.$"):
parse_extension("wrong_key", {})
def test_no_extension_as_text(self) -> None:
"""Test textualizing an extension that is not an extension type."""
- with self.assertRaisesRegex(TypeError, r"^bytes: Not a cryptography\.x509\.ExtensionType\.$"):
+ with pytest.raises(TypeError, match=r"^bytes: Not a cryptography\.x509\.ExtensionType\.$"):
extension_as_text(b"foo") # type: ignore[arg-type]
def test_unknown_extension_type_as_text(self) -> None:
"""Test textualizing an extension of unknown type."""
- with self.assertRaisesRegex(
- TypeError, r"^UnknownExtensionType \(oid: 1\.2\.3\): Unknown extension type\.$"
+ with pytest.raises(
+ TypeError, match=r"^UnknownExtensionType \(oid: 1\.2\.3\): Unknown extension type\.$"
):
extension_as_text(self.ext_type)
diff --git a/ca/django_ca/tests/extensions/test_utils.py b/ca/django_ca/tests/extensions/test_utils.py
index 979fb6e86..06804b452 100644
--- a/ca/django_ca/tests/extensions/test_utils.py
+++ b/ca/django_ca/tests/extensions/test_utils.py
@@ -27,11 +27,11 @@ class CertificatePoliciesIsSimpleTestCase(TestCase):
def assertIsSimple(self, *policies: x509.PolicyInformation) -> None: # pylint: disable=invalid-name
"""Assert that a Certificate Policies extension with the given policies is simple."""
- self.assertTrue(certificate_policies_is_simple(self.certificate_policy(*policies)))
+ assert certificate_policies_is_simple(self.certificate_policy(*policies))
def assertIsNotSimple(self, *policies: x509.PolicyInformation) -> None: # pylint: disable=invalid-name
"""Assert that a Certificate Policies extension with the given policies is *not* simple."""
- self.assertFalse(certificate_policies_is_simple(self.certificate_policy(*policies)))
+ assert not certificate_policies_is_simple(self.certificate_policy(*policies))
def certificate_policy(self, *policies: x509.PolicyInformation) -> x509.CertificatePolicies:
"""Create a Certificate Policy object from the given policies."""
diff --git a/ca/django_ca/tests/key_backends/hsm/test_backend.py b/ca/django_ca/tests/key_backends/hsm/test_backend.py
index 496ab6ec0..889cca1f9 100644
--- a/ca/django_ca/tests/key_backends/hsm/test_backend.py
+++ b/ca/django_ca/tests/key_backends/hsm/test_backend.py
@@ -35,7 +35,7 @@
def test_session_with_session_read_only_exception(hsm_backend: HSMBackend) -> None:
"""Test exception message when SessionReadOnly() is raised."""
- with pytest.raises(pkcs11.PKCS11Error, match=r"^Attempting to write to a read-only session\.$"):
+ with pytest.raises(pkcs11.PKCS11Error, match=r"^Attempting to write to a read-only session\.$"): # noqa: PT012
with hsm_backend.session(so_pin=None, user_pin=settings.PKCS11_USER_PIN) as session:
with patch.object(session, "get_key", side_effect=pkcs11.SessionReadOnly()):
session.get_key()
@@ -43,7 +43,7 @@ def test_session_with_session_read_only_exception(hsm_backend: HSMBackend) -> No
def test_session_with_unknown_pkcs11_exception(hsm_backend: HSMBackend) -> None:
"""Test exception message when a generic PKCS11 error is raised."""
- with pytest.raises(pkcs11.PKCS11Error, match=r"^Unknown pkcs11 error \(SessionCount\)\.$"):
+ with pytest.raises(pkcs11.PKCS11Error, match=r"^Unknown pkcs11 error \(SessionCount\)\.$"): # noqa: PT012
with hsm_backend.session(so_pin=None, user_pin=settings.PKCS11_USER_PIN) as session:
with patch.object(session, "get_key", side_effect=pkcs11.SessionCount()):
session.get_key()
diff --git a/ca/django_ca/tests/key_backends/hsm/test_models.py b/ca/django_ca/tests/key_backends/hsm/test_models.py
index bd746cbba..e5e947234 100644
--- a/ca/django_ca/tests/key_backends/hsm/test_models.py
+++ b/ca/django_ca/tests/key_backends/hsm/test_models.py
@@ -25,7 +25,7 @@
)
-@pytest.mark.parametrize("so_pin,user_pin", (("so-pin-value", None), (None, "user-pin-value")))
+@pytest.mark.parametrize(("so_pin", "user_pin"), (("so-pin-value", None), (None, "user-pin-value")))
def test_pins(so_pin: Optional[str], user_pin: Optional[str]) -> None:
"""Test valid pin configurations."""
model = HSMUsePrivateKeyOptions(so_pin=so_pin, user_pin=user_pin)
@@ -34,7 +34,7 @@ def test_pins(so_pin: Optional[str], user_pin: Optional[str]) -> None:
@pytest.mark.parametrize(
- "so_pin,user_pin,error",
+ ("so_pin", "user_pin", "error"),
(
(None, None, r"Provide one of so_pin or user_pin\."),
("so-pin-value", "user-pin-value", r"Provide either so_pin or user_pin\."),
@@ -79,6 +79,6 @@ def test_with_no_context(caplog: LogCaptureFixture) -> None:
def test_with_no_backend_in_context(caplog: LogCaptureFixture) -> None:
"""Test creating a Model with loading the pins from the context."""
- with pytest.raises(ValueError):
+ with pytest.raises(ValueError): # noqa: PT011 # pydantic controls the message
HSMUsePrivateKeyOptions.model_validate({}, context={"foo": "bar"})
assert "Did not receive backend in context." in caplog.text
diff --git a/ca/django_ca/tests/key_backends/hsm/test_session.py b/ca/django_ca/tests/key_backends/hsm/test_session.py
index c28bb0295..b48a16069 100644
--- a/ca/django_ca/tests/key_backends/hsm/test_session.py
+++ b/ca/django_ca/tests/key_backends/hsm/test_session.py
@@ -34,7 +34,7 @@
@pytest.fixture
-def pool_key(softhsm_token: str) -> Iterator[PoolKeyType]:
+def pool_key(softhsm_token: str) -> PoolKeyType:
"""Minor fixture to return the pool key for the default settings."""
return settings.PKCS11_PATH, softhsm_token, None, settings.PKCS11_USER_PIN
diff --git a/ca/django_ca/tests/key_backends/test_storages.py b/ca/django_ca/tests/key_backends/test_storages.py
index 238adbb66..d8086cce0 100644
--- a/ca/django_ca/tests/key_backends/test_storages.py
+++ b/ca/django_ca/tests/key_backends/test_storages.py
@@ -40,7 +40,7 @@ def test_private_key_options_key_size(key_size: int) -> None:
@pytest.mark.parametrize("key_size", (-2048, -1, 0, 1, 1023, 1025, 2047, 2049, 8191, 8193, 1000, 2000, 3000))
def test_private_key_options_with_invalid_key_size(key_size: int) -> None:
"""Test invalid key sizes for private key options."""
- with pytest.raises(ValueError):
+ with pytest.raises(ValueError): # noqa: PT011 # pydantic controls the message
StoragesCreatePrivateKeyOptions(
key_type="RSA", password=None, path=Path("/does/not/exist"), key_size=key_size
)
diff --git a/ca/django_ca/tests/models/test_certificate.py b/ca/django_ca/tests/models/test_certificate.py
index 2c11547cd..a3e14d0e4 100644
--- a/ca/django_ca/tests/models/test_certificate.py
+++ b/ca/django_ca/tests/models/test_certificate.py
@@ -46,7 +46,7 @@ def test_revocation() -> None:
# Never really happens in real life, but should still be checked
cert = Certificate(revoked=False)
- with pytest.raises(ValueError):
+ with pytest.raises(ValueError, match=r"^Certificate is not revoked\.$"):
cert.get_revocation()
@@ -129,7 +129,7 @@ def test_validate_past(root_cert: Certificate) -> None:
root_cert.full_clean()
-@pytest.mark.parametrize("name,algorithm", (("sha256", hashes.SHA256()), ("sha512", hashes.SHA512())))
+@pytest.mark.parametrize(("name", "algorithm"), (("sha256", hashes.SHA256()), ("sha512", hashes.SHA512())))
def test_get_fingerprint(name: str, algorithm: hashes.HashAlgorithm, usable_cert: Certificate) -> None:
"""Test getting the fingerprint value."""
cert_name = usable_cert.test_name # type: ignore[attr-defined]
diff --git a/ca/django_ca/tests/models/test_certificate_authority.py b/ca/django_ca/tests/models/test_certificate_authority.py
index a2748e6a1..f5e9261be 100644
--- a/ca/django_ca/tests/models/test_certificate_authority.py
+++ b/ca/django_ca/tests/models/test_certificate_authority.py
@@ -450,7 +450,7 @@ def test_serial(usable_ca: CertificateAuthority) -> None:
assert usable_ca.serial == CERT_DATA[usable_ca.name].get("serial")
-@pytest.mark.parametrize("name,algorithm", (("sha256", hashes.SHA256()), ("sha512", hashes.SHA512())))
+@pytest.mark.parametrize(("name", "algorithm"), (("sha256", hashes.SHA256()), ("sha512", hashes.SHA512())))
def test_get_fingerprint(name: str, algorithm: hashes.HashAlgorithm, usable_ca: CertificateAuthority) -> None:
"""Test getting the fingerprint value."""
assert usable_ca.get_fingerprint(algorithm) == CERT_DATA[usable_ca.name][name]
diff --git a/ca/django_ca/tests/pydantic/base.py b/ca/django_ca/tests/pydantic/base.py
index 21c7e9b1c..5c9ff1a2c 100644
--- a/ca/django_ca/tests/pydantic/base.py
+++ b/ca/django_ca/tests/pydantic/base.py
@@ -61,7 +61,7 @@ def assert_validation_errors(
expected_errors: ExpectedErrors,
) -> None:
"""Assertion method to test validation errors."""
- with pytest.raises(ValidationError) as ex_info:
+ with pytest.raises(ValidationError) as ex_info: # noqa: PT012
if isinstance(parameters, list):
model_class(parameters) # type: ignore[call-arg] # ruled out with overload
else:
diff --git a/ca/django_ca/tests/pydantic/test_extensions.py b/ca/django_ca/tests/pydantic/test_extensions.py
index eca8ac096..75b263862 100644
--- a/ca/django_ca/tests/pydantic/test_extensions.py
+++ b/ca/django_ca/tests/pydantic/test_extensions.py
@@ -183,7 +183,7 @@ def test_critical_validation() -> None:
@pytest.mark.parametrize(
- "parameters,expected",
+ ("parameters", "expected"),
(
(
{
@@ -212,7 +212,7 @@ def test_access_description_model(parameters: dict[str, Any], expected: x509.Acc
@pytest.mark.parametrize(
- "parameters,expected",
+ ("parameters", "expected"),
(
(
{"full_name": [GENERAL_NAME]},
@@ -252,7 +252,7 @@ def test_distribution_point(parameters: dict[str, Any], expected: x509.Distribut
@pytest.mark.parametrize(
- "parameters,expected_errors",
+ ("parameters", "expected_errors"),
(
(
{
@@ -279,7 +279,7 @@ def test_distribution_point_errors(parameters: dict[str, Any], expected_errors:
@pytest.mark.parametrize(
- "parameters,expected",
+ ("parameters", "expected"),
(
(
{"full_name": [GENERAL_NAME]},
@@ -359,9 +359,9 @@ def test_signed_certificate_timestamp(signed_certificate_timestamp_pub: x509.Cer
@pytest.mark.parametrize("critical", (True, False, None))
-@pytest.mark.parametrize("general_names,parsed_general_names", (([GENERAL_NAME], [dns("example.com")]),))
+@pytest.mark.parametrize(("general_names", "parsed_general_names"), (([GENERAL_NAME], [dns("example.com")]),))
@pytest.mark.parametrize(
- "model,extension_type",
+ ("model", "extension_type"),
(
(SubjectAlternativeNameModel, x509.SubjectAlternativeName),
(IssuerAlternativeNameModel, x509.IssuerAlternativeName),
@@ -381,7 +381,7 @@ def test_alternative_name_extensions(
@pytest.mark.parametrize("model", (SubjectAlternativeNameModel, IssuerAlternativeNameModel))
@pytest.mark.parametrize(
- "parameters,expected_errors",
+ ("parameters", "expected_errors"),
(
([], [("value_error", ("value",), "Value error, value must not be empty")]),
([GENERAL_NAME] * 2, [("value_error", ("value",), re.compile("value must be unique$"))]),
@@ -398,7 +398,7 @@ def test_alternative_name_extensions_errors(
@pytest.mark.parametrize("critical", (False, None))
@pytest.mark.parametrize(
- "parameters,descriptions",
+ ("parameters", "descriptions"),
(
(
[{"access_method": "ocsp", "access_location": GENERAL_NAME}],
@@ -425,7 +425,7 @@ def test_authority_information_access(
@pytest.mark.parametrize(
- "parameters,expected_errors",
+ ("parameters", "expected_errors"),
(
({"value": []}, [("value_error", ("value",), "Value error, value must not be empty")]),
(
@@ -461,7 +461,7 @@ def test_authority_information_access_errors(
@pytest.mark.parametrize("critical", (False, None))
@pytest.mark.parametrize(
- "parameters,extension",
+ ("parameters", "extension"),
(
(
{"key_identifier": b"MTIz"},
@@ -503,7 +503,7 @@ def test_authority_key_identifier(
@pytest.mark.parametrize(
- "parameters,expected_errors",
+ ("parameters", "expected_errors"),
(
(
{
@@ -566,7 +566,7 @@ def test_authority_key_identifier_errors(parameters: dict[str, Any], expected_er
@pytest.mark.parametrize("critical", (True, False, None))
@pytest.mark.parametrize(
- "parameters,extension",
+ ("parameters", "extension"),
(
({"ca": False, "path_length": None}, x509.BasicConstraints(ca=False, path_length=None)),
({"ca": True, "path_length": None}, x509.BasicConstraints(ca=True, path_length=None)),
@@ -587,7 +587,7 @@ def test_basic_constraints(
@pytest.mark.parametrize(
- "parameters,expected_errors",
+ ("parameters", "expected_errors"),
(
(
{},
@@ -620,7 +620,7 @@ def test_basic_constraints_errors(parameters: dict[str, Any], expected_errors: E
@pytest.mark.parametrize("critical", (True, False, None))
@pytest.mark.parametrize(
- "parameters,policies",
+ ("parameters", "policies"),
(
(
(
@@ -745,7 +745,7 @@ def test_certificate_policies(
@pytest.mark.parametrize(
- "parameters,expected_errors",
+ ("parameters", "expected_errors"),
(
(
[],
@@ -807,7 +807,7 @@ def test_certificate_policies_errors(parameters: dict[str, Any], expected_errors
@pytest.mark.parametrize("critical", (True, False, None))
@pytest.mark.parametrize(
- "parameters,distribution_points",
+ ("parameters", "distribution_points"),
DISTRIBUTION_POINTS_PARAMETERS,
)
def test_crl_distribution_points(
@@ -823,7 +823,7 @@ def test_crl_distribution_points(
@pytest.mark.parametrize("model", (CRLDistributionPointsModel, FreshestCRLModel))
@pytest.mark.parametrize(
- "parameters,expected_errors",
+ ("parameters", "expected_errors"),
(
(
[],
@@ -851,14 +851,14 @@ def test_distribution_point_extension_errors(
@pytest.mark.parametrize("critical", (False, None))
-@pytest.mark.parametrize("crl_number", [0, 1])
+@pytest.mark.parametrize("crl_number", (0, 1))
def test_crl_number(critical: Optional[bool], crl_number: int) -> None:
"""Test the CRLNumberModel."""
assert_extension_model(CRLNumberModel, crl_number, x509.CRLNumber(crl_number), critical)
@pytest.mark.parametrize(
- "parameters,expected_errors",
+ ("parameters", "expected_errors"),
(
({"value": -1}, [("greater_than_equal", ("value",), "Input should be greater than or equal to 0")]),
({"value": 0, "critical": True}, [MUST_BE_NON_CRITICAL_ERROR]),
@@ -870,14 +870,14 @@ def test_crl_number_errors(parameters: dict[str, Any], expected_errors: Expected
@pytest.mark.parametrize("critical", (True, None))
-@pytest.mark.parametrize("crl_number", [0, 1, 2])
+@pytest.mark.parametrize("crl_number", (0, 1, 2))
def test_delta_crl_indicator(critical: Optional[bool], crl_number: int) -> None:
"""Test the DeltaCRLModel."""
assert_extension_model(DeltaCRLIndicatorModel, crl_number, x509.DeltaCRLIndicator(crl_number), critical)
@pytest.mark.parametrize(
- "parameters,expected_errors",
+ ("parameters", "expected_errors"),
(
({"value": -1}, [("greater_than_equal", ("value",), "Input should be greater than or equal to 0")]),
({"value": 0, "critical": False}, [MUST_BE_CRITICAL_ERROR]),
@@ -893,7 +893,7 @@ def test_delta_crl_indicator_errors(
@pytest.mark.parametrize("critical", (True, False, None))
@pytest.mark.parametrize(
- "usages,extension",
+ ("usages", "extension"),
(
(
[ExtendedKeyUsageOID.CLIENT_AUTH.dotted_string, ExtendedKeyUsageOID.SERVER_AUTH.dotted_string],
@@ -916,7 +916,7 @@ def test_extended_key_usage(
@pytest.mark.parametrize(
- "parameters,expected_errors",
+ ("parameters", "expected_errors"),
(
(
[],
@@ -954,7 +954,7 @@ def test_extended_key_usage_errors(parameters: dict[str, Any], expected_errors:
@pytest.mark.parametrize("critical", (False, None))
-@pytest.mark.parametrize("parameters,distribution_points", DISTRIBUTION_POINTS_PARAMETERS)
+@pytest.mark.parametrize(("parameters", "distribution_points"), DISTRIBUTION_POINTS_PARAMETERS)
def test_freshest_crl(
critical: Optional[bool],
parameters: list[dict[str, Any]],
@@ -977,14 +977,14 @@ def test_freshest_crl_critical_error() -> None:
@pytest.mark.parametrize("critical", (True, None))
-@pytest.mark.parametrize("skip_certs", [0, 1])
+@pytest.mark.parametrize("skip_certs", (0, 1))
def test_inhibit_any_policy(critical: Optional[bool], skip_certs: int) -> None:
"""Test the InhibitAnyPolicyModel."""
assert_extension_model(InhibitAnyPolicyModel, skip_certs, x509.InhibitAnyPolicy(skip_certs), critical)
@pytest.mark.parametrize(
- "parameters,expected_errors",
+ ("parameters", "expected_errors"),
(
({"value": -1}, [("greater_than_equal", ("value",), "Input should be greater than or equal to 0")]),
({"value": 0, "critical": False}, [MUST_BE_CRITICAL_ERROR]),
@@ -997,7 +997,7 @@ def test_inhibit_any_policy_errors(parameters: dict[str, Any], expected_errors:
@pytest.mark.parametrize("critical", (True, None))
@pytest.mark.parametrize(
- "parameters,issuing_distribution_point",
+ ("parameters", "issuing_distribution_point"),
(
(
{"full_name": [GENERAL_NAME]},
@@ -1023,7 +1023,7 @@ def test_issuing_distribution_point(
@pytest.mark.parametrize(
- "parameters,expected_errors",
+ ("parameters", "expected_errors"),
(
({"value": {}}, [("value_error", ("value",), "Value error, cannot create empty extension")]),
(
@@ -1068,7 +1068,7 @@ def test_issuing_distribution_point_errors(
@pytest.mark.parametrize("critical", (True, False, None))
@pytest.mark.parametrize(
- "parameters,extension",
+ ("parameters", "extension"),
(
(["crl_sign"], key_usage(crl_sign=True).value),
(
@@ -1083,7 +1083,7 @@ def test_key_usage(critical: Optional[bool], parameters: dict[str, bool], extens
@pytest.mark.parametrize(
- "parameters,expected_errors",
+ ("parameters", "expected_errors"),
(
(
[],
@@ -1128,7 +1128,7 @@ def test_key_usage_errors(parameters: dict[str, bool], expected_errors: Expected
@pytest.mark.parametrize("critical", (True, False))
@pytest.mark.parametrize(
- "parameters,extension",
+ ("parameters", "extension"),
(
(
{"template_id": NameOID.COMMON_NAME.dotted_string},
@@ -1157,7 +1157,7 @@ def test_ms_certificate_template(
@pytest.mark.parametrize("critical", (True, None))
@pytest.mark.parametrize(
- "parameters,extension",
+ ("parameters", "extension"),
(
(
{"permitted_subtrees": [GENERAL_NAME]},
@@ -1186,7 +1186,7 @@ def test_name_constraints(
@pytest.mark.parametrize(
- "parameters,expected_errors",
+ ("parameters", "expected_errors"),
(
(
{"value": {}},
@@ -1238,7 +1238,8 @@ def test_name_constraints_errors(parameters: dict[str, bool], expected_errors: E
@pytest.mark.parametrize("critical", (True, None))
@pytest.mark.parametrize(
- "require_explicit_policy,inhibit_policy_mapping", ((0, 0), (1, 1), (0, 5), (5, 0), (None, 0), (0, None))
+ ("require_explicit_policy", "inhibit_policy_mapping"),
+ ((0, 0), (1, 1), (0, 5), (5, 0), (None, 0), (0, None)),
)
def test_policy_constraints(
critical: Optional[bool], require_explicit_policy: int, inhibit_policy_mapping: int
@@ -1255,7 +1256,7 @@ def test_policy_constraints(
@pytest.mark.parametrize(
- "parameters,expected_errors",
+ ("parameters", "expected_errors"),
(
(
{"value": {"require_explicit_policy": None, "inhibit_policy_mapping": None}},
@@ -1333,7 +1334,7 @@ def test_signed_certificate_timestamps(signed_certificate_timestamps_pub: x509.C
@pytest.mark.parametrize(
- "parameters,extension",
+ ("parameters", "extension"),
(
(
[{"access_method": "ca_repository", "access_location": GENERAL_NAME}],
@@ -1352,7 +1353,7 @@ def test_subject_information_access(
@pytest.mark.parametrize(
- "parameters,expected_errors",
+ ("parameters", "expected_errors"),
(
({"value": []}, [("value_error", ("value",), "Value error, value must not be empty")]),
(
@@ -1392,7 +1393,7 @@ def test_subject_information_access_errors(
@pytest.mark.parametrize(
- "digest,extension",
+ ("digest", "extension"),
(
# (b"123", x509.SubjectKeyIdentifier(b"123")),
(b"kA==", x509.SubjectKeyIdentifier(b"\x90")),
@@ -1416,7 +1417,7 @@ def test_subject_key_identifier_errors() -> None:
@pytest.mark.parametrize(
- "parameters,features",
+ ("parameters", "features"),
(
(["status_request"], [x509.TLSFeatureType.status_request]),
(["OCSPMustStaple"], [x509.TLSFeatureType.status_request]),
@@ -1449,7 +1450,7 @@ def test_tls_feature(
@pytest.mark.parametrize(
- "parameters,expected_errors",
+ ("parameters", "expected_errors"),
(
(
{"value": []},
@@ -1471,7 +1472,7 @@ def test_tls_feature_errors(parameters: dict[str, bool], expected_errors: Expect
@pytest.mark.parametrize(
- "parameters,extension_type",
+ ("parameters", "extension_type"),
(
(
{"value": b"MTIz", "oid": "1.2.3"},
diff --git a/ca/django_ca/tests/pydantic/test_general_name.py b/ca/django_ca/tests/pydantic/test_general_name.py
index 0611e3aba..c285de8f5 100644
--- a/ca/django_ca/tests/pydantic/test_general_name.py
+++ b/ca/django_ca/tests/pydantic/test_general_name.py
@@ -41,7 +41,7 @@ def test_doctests() -> None:
@pytest.mark.parametrize(
- "typ,value,encoded",
+ ("typ", "value", "encoded"),
(
("UTF8", "example", b"\x0c\x07example"),
("UTF8String", "example", b"\x0c\x07example"),
@@ -126,7 +126,7 @@ def test_other_name_octetstring_type_errors() -> None:
@pytest.mark.parametrize(
- "value,match",
+ ("value", "match"),
(
(b"123", r"Value error, could not parse asn1 data: .*"),
(b"\x03\x02\x04P", "3: Unknown otherName type found."),
@@ -145,7 +145,7 @@ def test_othername_general_errors(value: bytes, match: str) -> None:
@pytest.mark.parametrize(
- "parameters,name,discriminated",
+ ("parameters", "name", "discriminated"),
(
({"type": "DNS", "value": "example.com"}, dns("example.com"), str), # 0
({"type": "DNS", "value": "xn--exmple-cua.com"}, dns("xn--exmple-cua.com"), str), # 1
@@ -186,7 +186,7 @@ def test_general_name(parameters: dict[str, Any], name: x509.GeneralName, discri
@pytest.mark.parametrize(
- "typ,value,errors",
+ ("typ", "value", "errors"),
(
("URI", 123, [("string_type", ("value", "str"), "Input should be a valid string")]),
("email", 123, [("string_type", ("value", "str"), "Input should be a valid string")]),
diff --git a/ca/django_ca/tests/pydantic/test_name.py b/ca/django_ca/tests/pydantic/test_name.py
index caa700e39..c50db0102 100644
--- a/ca/django_ca/tests/pydantic/test_name.py
+++ b/ca/django_ca/tests/pydantic/test_name.py
@@ -33,7 +33,7 @@ def test_doctests() -> None:
@pytest.mark.parametrize(
- "parameters,name_attr",
+ ("parameters", "name_attr"),
(
(
{"oid": NameOID.COMMON_NAME.dotted_string, "value": "example.com"},
@@ -81,7 +81,7 @@ def test_name_attribute(parameters: dict[str, Any], name_attr: x509.NameAttribut
@pytest.mark.parametrize(
- "parameters,errors",
+ ("parameters", "errors"),
(
(
{"oid": "foo", "value": "example.com"},
@@ -123,7 +123,7 @@ def test_name_attribute_empty_common_name(oid: Any) -> None:
@pytest.mark.parametrize(
- "serialized,expected",
+ ("serialized", "expected"),
(
([], x509.Name([])),
(
@@ -150,7 +150,7 @@ def test_name(serialized: list[dict[str, Any]], expected: list[x509.NameAttribut
@pytest.mark.parametrize(
- "value,errors",
+ ("value", "errors"),
(
(
[
diff --git a/ca/django_ca/tests/pydantic/test_type_aliases.py b/ca/django_ca/tests/pydantic/test_type_aliases.py
index 3a5813a9e..f733ac21c 100644
--- a/ca/django_ca/tests/pydantic/test_type_aliases.py
+++ b/ca/django_ca/tests/pydantic/test_type_aliases.py
@@ -53,7 +53,7 @@ class SerialModel(BaseModel):
value: Serial
-@pytest.mark.parametrize("name,curve_cls", constants.ELLIPTIC_CURVE_TYPES.items())
+@pytest.mark.parametrize(("name", "curve_cls"), constants.ELLIPTIC_CURVE_TYPES.items())
def test_elliptic_curve(name: str, curve_cls: type[ec.EllipticCurve]) -> None:
"""Test EllipticCurveTypeAliasModel."""
model = EllipticCurveTypeAliasModel(value=name)
@@ -79,11 +79,11 @@ def test_elliptic_curve(name: str, curve_cls: type[ec.EllipticCurve]) -> None:
@pytest.mark.parametrize("value", ("", "wrong", True, 42, ec.SECP224R1))
def test_elliptic_curve_errors(value: str) -> None:
"""Test invalid values for EllipticCurveTypeAliasModel."""
- with pytest.raises(ValueError):
+ with pytest.raises(ValueError): # noqa: PT011 # pydantic controls the message
EllipticCurveTypeAliasModel(value=value)
-@pytest.mark.parametrize("name,hash_cls", constants.HASH_ALGORITHM_TYPES.items())
+@pytest.mark.parametrize(("name", "hash_cls"), constants.HASH_ALGORITHM_TYPES.items())
def test_hash_algorithm(name: str, hash_cls: type[hashes.HashAlgorithm]) -> None:
"""Test EllipticCurveTypeAliasModel."""
model = HashAlgorithmTypeAliasModel(value=name)
@@ -110,12 +110,12 @@ def test_hash_algorithm(name: str, hash_cls: type[hashes.HashAlgorithm]) -> None
@pytest.mark.parametrize("hash_obj", (hashes.SM3(), hashes.BLAKE2b(64), hashes.BLAKE2s(32)))
def test_hash_algorithm_unsupported_types(hash_obj: hashes.HashAlgorithm) -> None:
"""Test that unsupported hash algorithm instances throw an error."""
- with pytest.raises(ValueError):
+ with pytest.raises(ValueError): # noqa: PT011 # pydantic controls the message
HashAlgorithmTypeAliasModel(value=hash_obj)
@pytest.mark.parametrize(
- "value,encoded",
+ ("value", "encoded"),
(
(b"\xb5\xee\x0e\x01\x10U", "te4OARBV"),
(b"\xb5\xee\x0e\x01\x10U\xaa", "te4OARBVqg=="),
@@ -140,7 +140,7 @@ def test_json_serializable_bytes(value: bytes, encoded: str) -> None:
@pytest.mark.parametrize(
- "value,validated",
+ ("value", "validated"),
(
("a", "A"),
("abc", "ABC"),
@@ -176,5 +176,5 @@ def test_serial(value: str, validated: str) -> None:
)
def test_serial_errors(value: str) -> None:
"""Test invalid values for the Serial type alias."""
- with pytest.raises(ValueError):
+ with pytest.raises(ValueError): # noqa: PT011 # pydantic controls the message
SerialModel(value=value)
diff --git a/ca/django_ca/tests/pydantic/test_validators.py b/ca/django_ca/tests/pydantic/test_validators.py
index 96ed2335c..653d1ee74 100644
--- a/ca/django_ca/tests/pydantic/test_validators.py
+++ b/ca/django_ca/tests/pydantic/test_validators.py
@@ -26,8 +26,8 @@ def test_doctests() -> None:
@pytest.mark.parametrize(
- "name,validated",
- [
+ ("name", "validated"),
+ (
("example.com", "example.com"),
("er.tl", "er.tl"),
("exämple.com", "xn--exmple-cua.com"),
@@ -38,7 +38,7 @@ def test_doctests() -> None:
# Examples from Wikipedia:
("ουτοπία.δπθ.gr", "xn--kxae4bafwg.xn--pxaix.gr"),
("bücher.example", "xn--bcher-kva.example"),
- ],
+ ),
)
def test_dns_validator(name: str, validated: str) -> None:
"""Test :py:func:`django_ca.pydantic.validators.dns_validator`."""
@@ -46,8 +46,8 @@ def test_dns_validator(name: str, validated: str) -> None:
@pytest.mark.parametrize(
- "name,error",
- [("example com", "^Invalid domain: example com:"), ("@example.com", r"^Invalid domain: @example.com:")],
+ ("name", "error"),
+ (("example com", "^Invalid domain: example com:"), ("@example.com", r"^Invalid domain: @example.com:")),
)
def test_dns_validator_errors(name: str, error: str) -> None:
"""Test errors for :py:func:`django_ca.pydantic.validators.dns_validator`."""
@@ -56,8 +56,8 @@ def test_dns_validator_errors(name: str, error: str) -> None:
@pytest.mark.parametrize(
- "email,validated",
- [("user@example.com", "user@example.com"), ("user@exämple.com", "user@xn--exmple-cua.com")],
+ ("email", "validated"),
+ (("user@example.com", "user@example.com"), ("user@exämple.com", "user@xn--exmple-cua.com")),
)
def test_email_validator(email: str, validated: str) -> None:
"""Test :py:func:`django_ca.pydantic.validators.email_validator`."""
@@ -65,13 +65,13 @@ def test_email_validator(email: str, validated: str) -> None:
@pytest.mark.parametrize(
- "email,error",
- [
+ ("email", "error"),
+ (
("user@example com", "^Invalid domain: example com"),
("user", "^Invalid email address: user$"),
("example.com", r"^Invalid email address: example\.com$"),
("@example.com", r"^@example.com: node part is empty$"),
- ],
+ ),
)
def test_email_validator_errors(email: str, error: str) -> None:
"""Test errors for :py:func:`django_ca.pydantic.validators.email_validator`."""
@@ -80,8 +80,8 @@ def test_email_validator_errors(email: str, error: str) -> None:
@pytest.mark.parametrize(
- "url,validated",
- [
+ ("url", "validated"),
+ (
("http://example.com", "http://example.com"),
("http://exämple.com", "http://xn--exmple-cua.com"),
("https://www.example.net", "https://www.example.net"),
@@ -91,7 +91,7 @@ def test_email_validator_errors(email: str, error: str) -> None:
("https://www.exämple.net:443", "https://www.xn--exmple-cua.net:443"),
("https://www.example.net:443/", "https://www.example.net:443/"),
("https://www.example.net:443/test", "https://www.example.net:443/test"),
- ],
+ ),
)
def test_url_validator(url: str, validated: str) -> None:
"""Test py:func:`django_ca.pydantic.validators.url_validator`."""
@@ -99,15 +99,15 @@ def test_url_validator(url: str, validated: str) -> None:
@pytest.mark.parametrize(
- "url,error",
- [
+ ("url", "error"),
+ (
("https://example com", "^Invalid domain: example com: "),
("https://example com:80", "^Invalid domain: example com: "),
("example.com", r"^URL requires scheme and network location: example\.com$"),
("https://[abc", r"^Could not parse URL: https://\[abc: "), # urlsplit() raises an error for this
("https://example.com:abc", r"^Invalid port: https://example\.com:abc: "), # reading port...
("https://example.com:-1", r"^Invalid port: https://example\.com:-1: "),
- ],
+ ),
)
def test_url_validator_errors(url: str, error: str) -> None:
"""Test errors for :py:func:`django_ca.pydantic.validators.url_validator`."""
diff --git a/ca/django_ca/tests/test_acme.py b/ca/django_ca/tests/test_acme.py
index 3181198c9..a342c1da8 100644
--- a/ca/django_ca/tests/test_acme.py
+++ b/ca/django_ca/tests/test_acme.py
@@ -35,6 +35,7 @@
from django_ca.acme import validation
from django_ca.acme.constants import IdentifierType, Status
from django_ca.models import AcmeAccount, AcmeAuthorization, AcmeChallenge, AcmeOrder
+from django_ca.tests.base.assertions import assert_count_equal
from django_ca.tests.base.mixins import TestCaseMixin
urlpatterns = [
@@ -61,12 +62,12 @@ class TestConstantsTestCase(TestCase):
def test_status_enum(self) -> None:
"""Test that the Status Enum is equivalent to the main ACME library."""
expected = [*acme.messages.Status.POSSIBLE_NAMES, "expired"]
- self.assertCountEqual(expected, [s.value for s in Status])
+ assert_count_equal(expected, [s.value for s in Status])
def test_identifier_enum(self) -> None:
"""Test that the IdentifierType Enum is equivalent to the main ACME library."""
actual = list(acme.messages.IdentifierType.POSSIBLE_NAMES)
- self.assertCountEqual(actual, [s.value for s in IdentifierType])
+ assert_count_equal(actual, [s.value for s in IdentifierType])
class Dns01ValidationTestCase(TestCaseMixin, TestCase):
@@ -98,7 +99,7 @@ def assertLogMessages( # pylint: disable=invalid-name # unittest standard
if challenge is None:
challenge = self.chall
- self.assertEqual(logcm.output, [self.get_log_message(challenge), *messages])
+ assert logcm.output == [self.get_log_message(challenge), *messages]
def get_log_message(self, chall: AcmeChallenge) -> str:
"""Get the default log message for DNS-01 validation."""
@@ -121,7 +122,7 @@ def mock_response(self, domain: str, *responses: Iterable[bytes]) -> Iterator[mo
# Note: Only assert the first two parameters, as otherwise we'd test dnspython internals
resolve_mock.assert_called_once()
expected = (f"_acme_challenge.{domain}", "TXT")
- self.assertEqual(resolve_mock.call_args_list[0].args[:2], expected)
+ assert resolve_mock.call_args_list[0].args[:2] == expected
@contextmanager
def resolve(self, side_effect: Any) -> Iterator[mock.Mock]:
@@ -136,16 +137,16 @@ def to_txt_record(self, values: Iterable[bytes]) -> TXTBase:
def test_validation(self) -> None:
"""Test successful DNS-01 validation."""
with self.mock_response(self.domain, [self.chall.expected]), self.assertLogMessages():
- self.assertTrue(validation.validate_dns_01(self.chall))
+ assert validation.validate_dns_01(self.chall)
with self.mock_response(self.domain, [self.chall.expected, b"foo"]), self.assertLogMessages():
- self.assertTrue(validation.validate_dns_01(self.chall))
+ assert validation.validate_dns_01(self.chall)
with self.mock_response(self.domain, [b"data"], [self.chall.expected]), self.assertLogMessages():
- self.assertTrue(validation.validate_dns_01(self.chall))
+ assert validation.validate_dns_01(self.chall)
with (
self.mock_response(self.domain, [b"data"], [b"multiple", self.chall.expected]),
self.assertLogMessages(),
):
- self.assertTrue(validation.validate_dns_01(self.chall))
+ assert validation.validate_dns_01(self.chall)
def test_precomputed(self) -> None:
"""Runa test with pre-computed values to test basic behavior."""
@@ -161,33 +162,33 @@ def test_precomputed(self) -> None:
expected = chall.expected
with self.mock_response(self.domain, [chall.expected]), self.assertLogMessages(challenge=chall):
- self.assertTrue(validation.validate_dns_01(chall))
+ assert validation.validate_dns_01(chall)
with self.mock_response(self.domain, [expected, b"foo"]), self.assertLogMessages(challenge=chall):
- self.assertTrue(validation.validate_dns_01(chall))
+ assert validation.validate_dns_01(chall)
with self.mock_response(self.domain, [b"data"], [expected]), self.assertLogMessages(challenge=chall):
- self.assertTrue(validation.validate_dns_01(chall))
+ assert validation.validate_dns_01(chall)
with (
self.mock_response(self.domain, [b"data"], [b"foo", expected]),
self.assertLogMessages(challenge=chall),
):
- self.assertTrue(validation.validate_dns_01(chall))
+ assert validation.validate_dns_01(chall)
def test_wrong_txt_response(self) -> None:
"""Test failing a challenge via the wrong DNS response."""
with self.mock_response(self.domain, [b"foo"]), self.assertLogMessages():
- self.assertFalse(validation.validate_dns_01(self.chall))
+ assert not validation.validate_dns_01(self.chall)
with self.mock_response(self.domain, [b"foo"], [b"bar"]), self.assertLogMessages():
- self.assertFalse(validation.validate_dns_01(self.chall))
+ assert not validation.validate_dns_01(self.chall)
with self.mock_response(self.domain, [b"foo", b"bar"], [b"bar"]), self.assertLogMessages():
- self.assertFalse(validation.validate_dns_01(self.chall))
+ assert not validation.validate_dns_01(self.chall)
def test_dns_exception(self) -> None:
"""Mock resolver throwing a DNS exception."""
with self.resolve(side_effect=dns.exception.DNSException) as resolve, self.assertLogs() as logcm:
- self.assertFalse(validation.validate_dns_01(self.chall))
+ assert not validation.validate_dns_01(self.chall)
resolve.assert_called_once_with(f"_acme_challenge.{self.domain}", "TXT", lifetime=1, search=False)
- self.assertEqual(len(logcm.output), 2)
- self.assertIn("dns.exception.DNSException", logcm.output[1])
+ assert len(logcm.output) == 2
+ assert "dns.exception.DNSException" in logcm.output[1]
def test_nxdomain(self) -> None:
"""Test validating a domain where the record simply does not exist."""
@@ -197,12 +198,12 @@ def test_nxdomain(self) -> None:
f"DEBUG:django_ca.acme.validation:TXT _acme_challenge.{self.domain}: record does not exist."
),
):
- self.assertFalse(validation.validate_dns_01(self.chall))
+ assert not validation.validate_dns_01(self.chall)
resolve.assert_called_once_with(f"_acme_challenge.{self.domain}", "TXT", lifetime=1, search=False)
def test_wrong_acme_challenge(self) -> None:
"""Test passing an ACME challenge of the wrong type."""
- with self.assertRaisesRegex(ValueError, r"^This function can only validate DNS-01 challenges$"):
+ with pytest.raises(ValueError, match=r"^This function can only validate DNS-01 challenges$"):
validation.validate_dns_01(AcmeChallenge(type=AcmeChallenge.TYPE_HTTP_01))
- with self.assertRaisesRegex(ValueError, r"^This function can only validate DNS-01 challenges$"):
+ with pytest.raises(ValueError, match=r"^This function can only validate DNS-01 challenges$"):
validation.validate_dns_01(AcmeChallenge(type=AcmeChallenge.TYPE_TLS_ALPN_01))
diff --git a/ca/django_ca/tests/test_base.py b/ca/django_ca/tests/test_base.py
index a2a31145d..63f1292d8 100644
--- a/ca/django_ca/tests/test_base.py
+++ b/ca/django_ca/tests/test_base.py
@@ -23,6 +23,8 @@
from django.conf import settings
from django.test import TestCase
+import pytest
+
from django_ca.tests.base.assertions import assert_extensions
from django_ca.tests.base.mixins import TestCaseMixin
from django_ca.tests.base.utils import cmd, cmd_e2e, override_tmpcadir
@@ -39,7 +41,7 @@ def test_pragmas(self) -> None:
@override_tmpcadir()
def test_override_tmpcadir(self) -> None:
"""Test override_tmpcadir as decorator."""
- self.assertTrue(settings.CA_DIR.startswith(tempfile.gettempdir()))
+ assert settings.CA_DIR.startswith(tempfile.gettempdir())
@override_tmpcadir()
def test_assert_extensions(self) -> None:
@@ -92,25 +94,25 @@ class OverrideCaDirForFuncTestCase(TestCaseMixin, TestCase):
@override_tmpcadir()
def test_a(self) -> None:
# add three tests to make sure that every test case sees a different dir
- self.assertTrue(settings.CA_DIR.startswith(tempfile.gettempdir()))
- self.assertNotIn(settings.CA_DIR, self.seen_dirs)
+ assert settings.CA_DIR.startswith(tempfile.gettempdir())
+ assert settings.CA_DIR not in self.seen_dirs
self.seen_dirs.add(settings.CA_DIR)
@override_tmpcadir()
def test_b(self) -> None:
- self.assertTrue(settings.CA_DIR.startswith(tempfile.gettempdir()))
- self.assertNotIn(settings.CA_DIR, self.seen_dirs)
+ assert settings.CA_DIR.startswith(tempfile.gettempdir())
+ assert settings.CA_DIR not in self.seen_dirs
self.seen_dirs.add(settings.CA_DIR)
@override_tmpcadir()
def test_c(self) -> None:
- self.assertTrue(settings.CA_DIR.startswith(tempfile.gettempdir()))
- self.assertNotIn(settings.CA_DIR, self.seen_dirs)
+ assert settings.CA_DIR.startswith(tempfile.gettempdir())
+ assert settings.CA_DIR not in self.seen_dirs
self.seen_dirs.add(settings.CA_DIR)
def test_no_classes(self) -> None:
msg = r"^Only functions can use override_tmpcadir\(\)$"
- with self.assertRaisesRegex(ValueError, msg):
+ with pytest.raises(ValueError, match=msg):
@override_tmpcadir()
class Foo: # pylint: disable=missing-class-docstring,unused-variable
@@ -126,8 +128,8 @@ def test_basic(self) -> None:
"""Trivial basic test."""
stdout, stderr = cmd_e2e(["list_cas"])
serial = add_colons(self.ca.serial)
- self.assertEqual(stdout, f"{serial} - {self.ca.name}\n")
- self.assertEqual(stderr, "")
+ assert stdout == f"{serial} - {self.ca.name}\n"
+ assert stderr == ""
class TypingTestCase(TestCaseMixin): # never executed as it's not actually a subclass of TestCase
diff --git a/ca/django_ca/tests/test_checks.py b/ca/django_ca/tests/test_checks.py
index 5d22df5c8..d9b7e222a 100644
--- a/ca/django_ca/tests/test_checks.py
+++ b/ca/django_ca/tests/test_checks.py
@@ -34,11 +34,11 @@ def test_no_cache(self) -> None:
)
with self.settings(CACHES={}):
errors = check_cache([app_config])
- self.assertEqual(errors, [expected])
+ assert errors == [expected]
with self.settings(CACHES={}):
errors = check_cache(None)
- self.assertEqual(errors, [expected])
+ assert errors == [expected]
def test_loc_mem_cache(self) -> None:
"""Test what happens if LocMemCache is used."""
@@ -55,16 +55,16 @@ def test_loc_mem_cache(self) -> None:
}
with self.settings(CACHES=setting):
errors = check_cache([app_config])
- self.assertEqual(errors, [expected])
+ assert errors == [expected]
with self.settings(CACHES=setting):
errors = check_cache(None)
- self.assertEqual(errors, [expected])
+ assert errors == [expected]
def test_django_ca_not_checked(self) -> None:
"""Test that no checks are run if django_ca is not checked."""
app_config = apps.get_app_config("auth")
errors = check_cache([app_config])
- self.assertEqual(errors, [])
+ assert not errors
def test_redis_cache(self) -> None:
"""Test if redis cache backend is used."""
@@ -76,4 +76,4 @@ def test_redis_cache(self) -> None:
}
with self.settings(CACHES=setting):
errors = check_cache([app_config])
- self.assertEqual(errors, [])
+ assert not errors
diff --git a/ca/django_ca/tests/test_fields.py b/ca/django_ca/tests/test_fields.py
index f095d63d3..c0bfbe917 100644
--- a/ca/django_ca/tests/test_fields.py
+++ b/ca/django_ca/tests/test_fields.py
@@ -11,11 +11,11 @@
# You should have received a copy of the GNU General Public License along with django-ca. If not, see
# .
-# TYPEHINT NOTE: mypy-django typehints assertFieldOutput complete wrong.
-# type: ignore
-
"""Test custom Django form fields."""
+# TYPEHINT NOTE: mypy-django typehints assertFieldOutput completely wrong.
+# mypy: ignore-errors
+
import html
import json
from typing import Any
@@ -69,21 +69,21 @@ def assertRequiredError(self, value) -> None: # pylint: disable=invalid-name
field = self.field_class(required=True)
error_required = [field.error_messages["required"]]
- with self.assertRaises(ValidationError) as context_manager:
+ with pytest.raises(ValidationError) as context_manager:
field.clean(value)
- self.assertEqual(context_manager.exception.messages, error_required)
+ assert context_manager.exception.messages == error_required
@pytest.mark.parametrize("critical", (True, False))
@pytest.mark.parametrize("required", (True, False))
@pytest.mark.parametrize(
- "field_class,extension_type",
+ ("field_class", "extension_type"),
(
(fields.IssuerAlternativeNameField, x509.IssuerAlternativeName),
(fields.SubjectAlternativeNameField, x509.SubjectAlternativeName),
),
)
-@pytest.mark.parametrize("value,general_names", (([SER_D1], [DNS1]), ([SER_D1, SER_D2], [DNS1, DNS2])))
+@pytest.mark.parametrize(("value", "general_names"), (([SER_D1], [DNS1]), ([SER_D1, SER_D2], [DNS1, DNS2])))
def test_alternative_name_fields(
critical: bool,
required: bool,
@@ -101,14 +101,14 @@ def test_alternative_name_fields(
@pytest.mark.parametrize("critical", (True, False))
@pytest.mark.parametrize("required", (True, False))
@pytest.mark.parametrize(
- "field_class,extension_type",
+ ("field_class", "extension_type"),
(
(fields.CRLDistributionPointField, x509.CRLDistributionPoints),
(fields.FreshestCRLField, x509.FreshestCRL),
),
)
@pytest.mark.parametrize(
- "value,dpoint",
+ ("value", "dpoint"),
(
(([SER_D1], "", [], ()), distribution_point([DNS1])),
(([SER_D1, SER_D2], "", [], ()), (distribution_point([DNS1, DNS2]))),
@@ -169,7 +169,7 @@ def test_distribution_point_fields(
@pytest.mark.parametrize("critical", (True, False))
@pytest.mark.parametrize("required", (True, False))
@pytest.mark.parametrize(
- "invalid,error",
+ ("invalid", "error"),
(
(([SER_D1], f"CN={D1}", [], ()), r"You cannot provide both full_name and relative_name\."),
(
@@ -210,7 +210,7 @@ def test_crl_distribution_points_field_with_empty_input(
# Test how the field is rendered
name = "field-name"
- raw_html = field.widget.render(name, None)
+ raw_html = field.widget.render(name=name, value=None)
assertInHTML(
f'', raw_html
)
@@ -228,8 +228,8 @@ def test_crl_distribution_points_field_rendering() -> None:
field = fields.CRLDistributionPointField()
reasons = frozenset([x509.ReasonFlags.key_compromise, x509.ReasonFlags.certificate_hold])
raw_html = field.widget.render(
- name,
- crl_distribution_points(distribution_point([DNS1], crl_issuer=[DNS2], reasons=reasons)),
+ name=name,
+ value=crl_distribution_points(distribution_point([DNS1], crl_issuer=[DNS2], reasons=reasons)),
)
full_name_value = html.escape(json.dumps([SER_D1]))
@@ -298,7 +298,7 @@ def test_crl_distribution_points_field_rendering_with_multiple_dps() -> None:
@pytest.mark.parametrize("critical", (True, False))
@pytest.mark.parametrize("required", (True, False))
@pytest.mark.parametrize(
- "ser_ca_issuers,ser_ocsp,ca_issuers,ocsp",
+ ("ser_ca_issuers", "ser_ocsp", "ca_issuers", "ocsp"),
(
((SER_D1,), (), (DNS1,), ()),
((), (SER_D2,), (), (DNS2,)),
@@ -324,7 +324,7 @@ def test_authority_information_access_field(
@pytest.mark.parametrize("critical", (True, False)) # make sure that critical flag has no effect
@pytest.mark.parametrize("required", (True, False))
@pytest.mark.parametrize(
- "ser_ca_issuers,ser_ocsp",
+ ("ser_ca_issuers", "ser_ocsp"),
(("", ""), ("[]", "[]"), (None, None)),
)
def test_authority_information_access_field_with_empty_value(
@@ -336,7 +336,7 @@ def test_authority_information_access_field_with_empty_value(
@pytest.mark.parametrize(
- "ser_ca_issuers,ser_ocsp,error",
+ ("ser_ca_issuers", "ser_ocsp", "error"),
(
(({"type": "DNS", "value": "http://example.com"},), (), ""),
(({"type": "IP", "value": "example.com"},), (), "example.com: Could not parse IP address"),
@@ -397,7 +397,7 @@ def test_rendering(self) -> None:
name = "field-name"
field = self.field_class()
- raw_html = field.widget.render(name, None)
+ raw_html = field.widget.render(name=name, value=None)
for choice, text in self.field_class.choices:
self.assertInHTML(f'', raw_html)
@@ -412,7 +412,7 @@ def test_rendering_profiles(self) -> None:
choices = [key_usage_choices[choice] for choice in choices]
ext = key_usage(**{choice: True for choice in choices})
- raw_html = field.widget.render("unused", ext)
+ raw_html = field.widget.render(name="unused", value=ext)
for choice, text in self.field_class.choices:
if choice in choices:
diff --git a/ca/django_ca/tests/test_management_actions.py b/ca/django_ca/tests/test_management_actions.py
index 68942c19c..130f4616e 100644
--- a/ca/django_ca/tests/test_management_actions.py
+++ b/ca/django_ca/tests/test_management_actions.py
@@ -108,12 +108,12 @@ def setUp(self) -> None:
def assertValue(self, namespace: argparse.Namespace, value: Any) -> None: # pylint: disable=invalid-name
"""Assert a given extension value."""
extension = x509.Extension(oid=x509.SubjectAlternativeName.oid, critical=False, value=value)
- self.assertEqual(namespace.alt, extension)
+ assert namespace.alt == extension
def test_basic(self) -> None:
"""Test basic functionality."""
namespace = self.parser.parse_args([])
- self.assertEqual(namespace.alt, None)
+ assert namespace.alt is None
namespace = self.parser.parse_args(["--alt", "example.com"])
self.assertValue(namespace, x509.SubjectAlternativeName([dns("example.com")]))
@@ -138,15 +138,10 @@ def test_add_cps(self) -> None:
oid = "1.2.3"
cps = "http://example.com/cps"
namespace = self.parser.parse_args(["--pi", oid, "--cps", cps])
- self.assertEqual(
- namespace.pi,
- x509.CertificatePolicies(
- policies=[
- x509.PolicyInformation(
- policy_identifier=x509.ObjectIdentifier(oid), policy_qualifiers=[cps]
- )
- ]
- ),
+ assert namespace.pi == x509.CertificatePolicies(
+ policies=[
+ x509.PolicyInformation(policy_identifier=x509.ObjectIdentifier(oid), policy_qualifiers=[cps])
+ ]
)
def test_add_multiple_cps(self) -> None:
@@ -155,15 +150,12 @@ def test_add_multiple_cps(self) -> None:
cps1 = "http://example.com/cps1"
cps2 = "http://example.com/cps2"
namespace = self.parser.parse_args(["--pi", oid, "--cps", cps1, "--cps", cps2])
- self.assertEqual(
- namespace.pi,
- x509.CertificatePolicies(
- policies=[
- x509.PolicyInformation(
- policy_identifier=x509.ObjectIdentifier(oid), policy_qualifiers=[cps1, cps2]
- )
- ]
- ),
+ assert namespace.pi == x509.CertificatePolicies(
+ policies=[
+ x509.PolicyInformation(
+ policy_identifier=x509.ObjectIdentifier(oid), policy_qualifiers=[cps1, cps2]
+ )
+ ]
)
def test_add_multiple_cps_to_different_policy_identifiers(self) -> None:
@@ -173,18 +165,15 @@ def test_add_multiple_cps_to_different_policy_identifiers(self) -> None:
cps1 = "http://example.com/cps1"
cps2 = "http://example.com/cps2"
namespace = self.parser.parse_args(["--pi", oid1, "--cps", cps1, "--pi", oid2, "--cps", cps2])
- self.assertEqual(
- namespace.pi,
- x509.CertificatePolicies(
- policies=[
- x509.PolicyInformation(
- policy_identifier=x509.ObjectIdentifier(oid1), policy_qualifiers=[cps1]
- ),
- x509.PolicyInformation(
- policy_identifier=x509.ObjectIdentifier(oid2), policy_qualifiers=[cps2]
- ),
- ]
- ),
+ assert namespace.pi == x509.CertificatePolicies(
+ policies=[
+ x509.PolicyInformation(
+ policy_identifier=x509.ObjectIdentifier(oid1), policy_qualifiers=[cps1]
+ ),
+ x509.PolicyInformation(
+ policy_identifier=x509.ObjectIdentifier(oid2), policy_qualifiers=[cps2]
+ ),
+ ]
)
def test_missing_policy_identifier(self) -> None:
@@ -217,21 +206,21 @@ def setUp(self) -> None:
def test_basic(self) -> None:
"""Test basic functionality of action."""
namespace = self.parser.parse_args([])
- self.assertIsNone(namespace.eku)
+ assert namespace.eku is None
namespace = self.parser.parse_args(["--eku", "clientAuth"])
- self.assertEqual(x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH]), namespace.eku)
+ assert x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH]) == namespace.eku
namespace = self.parser.parse_args(["--eku", "clientAuth", "serverAuth"])
- self.assertEqual(
- x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH, ExtendedKeyUsageOID.SERVER_AUTH]),
- namespace.eku,
+ assert (
+ x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH, ExtendedKeyUsageOID.SERVER_AUTH])
+ == namespace.eku
)
def test_dotted_string_value(self) -> None:
"""Test passing a dotted string."""
namespace = self.parser.parse_args(["--eku", "1.3.6.1.5.5.7.3.2"])
- self.assertEqual(x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH]), namespace.eku)
+ assert x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH]) == namespace.eku
def test_duplicate_values(self) -> None:
"""Test wrong option values."""
@@ -265,13 +254,10 @@ def test_policy_identifier(self) -> None:
"""Basic test for adding a policy identifier."""
oid = "1.2.3"
namespace = self.parser.parse_args(["--pi", oid])
- self.assertEqual(
- namespace.pi,
- x509.CertificatePolicies(
- policies=[
- x509.PolicyInformation(policy_identifier=x509.ObjectIdentifier(oid), policy_qualifiers=[])
- ]
- ),
+ assert namespace.pi == x509.CertificatePolicies(
+ policies=[
+ x509.PolicyInformation(policy_identifier=x509.ObjectIdentifier(oid), policy_qualifiers=[])
+ ]
)
def test_multiple_policy_identifiers(self) -> None:
@@ -279,18 +265,11 @@ def test_multiple_policy_identifiers(self) -> None:
oid1 = "1.2.3"
oid2 = "1.2.4"
namespace = self.parser.parse_args(["--pi", oid1, "--pi", oid2])
- self.assertEqual(
- namespace.pi,
- x509.CertificatePolicies(
- policies=[
- x509.PolicyInformation(
- policy_identifier=x509.ObjectIdentifier(oid1), policy_qualifiers=[]
- ),
- x509.PolicyInformation(
- policy_identifier=x509.ObjectIdentifier(oid2), policy_qualifiers=[]
- ),
- ]
- ),
+ assert namespace.pi == x509.CertificatePolicies(
+ policies=[
+ x509.PolicyInformation(policy_identifier=x509.ObjectIdentifier(oid1), policy_qualifiers=[]),
+ x509.PolicyInformation(policy_identifier=x509.ObjectIdentifier(oid2), policy_qualifiers=[]),
+ ]
)
def test_any_policy_value_disallowed(self) -> None:
@@ -309,15 +288,12 @@ def test_any_policy_value(self) -> None:
oid = "anyPolicy"
namespace = parser.parse_args(["--pi", oid])
- self.assertEqual(
- namespace.pi,
- x509.CertificatePolicies(
- policies=[
- x509.PolicyInformation(
- policy_identifier=x509.ObjectIdentifier("2.5.29.32.0"), policy_qualifiers=[]
- )
- ]
- ),
+ assert namespace.pi == x509.CertificatePolicies(
+ policies=[
+ x509.PolicyInformation(
+ policy_identifier=x509.ObjectIdentifier("2.5.29.32.0"), policy_qualifiers=[]
+ )
+ ]
)
def test_invalid_dotted_string(self) -> None:
@@ -337,16 +313,16 @@ def test_no_min_no_max(self) -> None:
"""Test action with no min/max values."""
parser = argparse.ArgumentParser()
parser.add_argument("--value", action=actions.IntegerRangeAction)
- self.assertEqual(parser.parse_args(["--value=0"]).value, 0)
- self.assertEqual(parser.parse_args(["--value=1"]).value, 1)
- self.assertEqual(parser.parse_args(["--value=-1"]).value, -1)
+ assert parser.parse_args(["--value=0"]).value == 0
+ assert parser.parse_args(["--value=1"]).value == 1
+ assert parser.parse_args(["--value=-1"]).value == -1
def test_min_values(self) -> None:
"""Test the min value for the action."""
self.parser = argparse.ArgumentParser()
self.parser.add_argument("--value", action=actions.IntegerRangeAction, min=0)
- self.assertEqual(self.parser.parse_args(["--value=0"]).value, 0)
- self.assertEqual(self.parser.parse_args(["--value=1"]).value, 1)
+ assert self.parser.parse_args(["--value=0"]).value == 0
+ assert self.parser.parse_args(["--value=1"]).value == 1
assert_parser_error(
self.parser,
["--value=-1"],
@@ -358,8 +334,8 @@ def test_max_values(self) -> None:
"""Test the max value for the action."""
self.parser = argparse.ArgumentParser()
self.parser.add_argument("--value", action=actions.IntegerRangeAction, max=0)
- self.assertEqual(self.parser.parse_args(["--value=0"]).value, 0)
- self.assertEqual(self.parser.parse_args(["--value=-1"]).value, -1)
+ assert self.parser.parse_args(["--value=0"]).value == 0
+ assert self.parser.parse_args(["--value=-1"]).value == -1
assert_parser_error(
self.parser,
["--value=1"],
@@ -379,12 +355,10 @@ def setUp(self) -> None:
def test_basic(self) -> None:
"""Test basic functionality of action."""
namespace = self.parser.parse_args(["--key-usage", "keyCertSign"])
- self.assertEqual(key_usage(key_cert_sign=True, critical=False).value, namespace.key_usage)
+ assert key_usage(key_cert_sign=True, critical=False).value == namespace.key_usage
namespace = self.parser.parse_args(["--key-usage", "keyCertSign", "keyAgreement"])
- self.assertEqual(
- key_usage(key_cert_sign=True, key_agreement=True, critical=False).value, namespace.key_usage
- )
+ assert key_usage(key_cert_sign=True, key_agreement=True, critical=False).value == namespace.key_usage
def test_invalid_values(self) -> None:
"""Test passing invalid values."""
@@ -448,15 +422,15 @@ def setUp(self) -> None:
def test_basic(self) -> None:
"""Test basic functionality of action."""
namespace = self.parser.parse_args(["--tls-feature", "status_request"])
- self.assertEqual(x509.TLSFeature([x509.TLSFeatureType.status_request]), namespace.tls_feature)
+ assert x509.TLSFeature([x509.TLSFeatureType.status_request]) == namespace.tls_feature
namespace = self.parser.parse_args(["--tls-feature", "status_request_v2"])
- self.assertEqual(x509.TLSFeature([x509.TLSFeatureType.status_request_v2]), namespace.tls_feature)
+ assert x509.TLSFeature([x509.TLSFeatureType.status_request_v2]) == namespace.tls_feature
namespace = self.parser.parse_args(["--tls-feature", "status_request", "status_request_v2"])
- self.assertEqual(
- x509.TLSFeature([x509.TLSFeatureType.status_request, x509.TLSFeatureType.status_request_v2]),
- namespace.tls_feature,
+ assert (
+ x509.TLSFeature([x509.TLSFeatureType.status_request, x509.TLSFeatureType.status_request_v2])
+ == namespace.tls_feature
)
def test_error(self) -> None:
@@ -483,16 +457,13 @@ def test_add_notice(self) -> None:
oid = "1.2.3"
notice = "notice text"
namespace = self.parser.parse_args(["--pi", oid, "--notice", notice])
- self.assertEqual(
- namespace.pi,
- x509.CertificatePolicies(
- policies=[
- x509.PolicyInformation(
- policy_identifier=x509.ObjectIdentifier(oid),
- policy_qualifiers=[x509.UserNotice(notice_reference=None, explicit_text=notice)],
- )
- ]
- ),
+ assert namespace.pi == x509.CertificatePolicies(
+ policies=[
+ x509.PolicyInformation(
+ policy_identifier=x509.ObjectIdentifier(oid),
+ policy_qualifiers=[x509.UserNotice(notice_reference=None, explicit_text=notice)],
+ )
+ ]
)
def test_add_multiple_notices(self) -> None:
@@ -501,19 +472,16 @@ def test_add_multiple_notices(self) -> None:
notice1 = "notice text one"
notice2 = "notice text two"
namespace = self.parser.parse_args(["--pi", oid, "--notice", notice1, "--notice", notice2])
- self.assertEqual(
- namespace.pi,
- x509.CertificatePolicies(
- policies=[
- x509.PolicyInformation(
- policy_identifier=x509.ObjectIdentifier(oid),
- policy_qualifiers=[
- x509.UserNotice(notice_reference=None, explicit_text=notice1),
- x509.UserNotice(notice_reference=None, explicit_text=notice2),
- ],
- )
- ]
- ),
+ assert namespace.pi == x509.CertificatePolicies(
+ policies=[
+ x509.PolicyInformation(
+ policy_identifier=x509.ObjectIdentifier(oid),
+ policy_qualifiers=[
+ x509.UserNotice(notice_reference=None, explicit_text=notice1),
+ x509.UserNotice(notice_reference=None, explicit_text=notice2),
+ ],
+ )
+ ]
)
def test_add_multiple_cps_to_different_policy_identifiers(self) -> None:
@@ -525,20 +493,17 @@ def test_add_multiple_cps_to_different_policy_identifiers(self) -> None:
namespace = self.parser.parse_args(
["--pi", oid1, "--notice", notice1, "--pi", oid2, "--notice", notice2]
)
- self.assertEqual(
- namespace.pi,
- x509.CertificatePolicies(
- policies=[
- x509.PolicyInformation(
- policy_identifier=x509.ObjectIdentifier(oid1),
- policy_qualifiers=[x509.UserNotice(notice_reference=None, explicit_text=notice1)],
- ),
- x509.PolicyInformation(
- policy_identifier=x509.ObjectIdentifier(oid2),
- policy_qualifiers=[x509.UserNotice(notice_reference=None, explicit_text=notice2)],
- ),
- ]
- ),
+ assert namespace.pi == x509.CertificatePolicies(
+ policies=[
+ x509.PolicyInformation(
+ policy_identifier=x509.ObjectIdentifier(oid1),
+ policy_qualifiers=[x509.UserNotice(notice_reference=None, explicit_text=notice1)],
+ ),
+ x509.PolicyInformation(
+ policy_identifier=x509.ObjectIdentifier(oid2),
+ policy_qualifiers=[x509.UserNotice(notice_reference=None, explicit_text=notice2)],
+ ),
+ ]
)
def test_missing_policy_identifier(self) -> None:
@@ -571,13 +536,13 @@ def setUp(self) -> None:
def test_basic(self) -> None:
"""Test basic functionality of action."""
args = self.parser.parse_args(["--action=DER"])
- self.assertEqual(args.action, Encoding.DER)
+ assert args.action == Encoding.DER
args = self.parser.parse_args(["--action=ASN1"])
- self.assertEqual(args.action, Encoding.DER)
+ assert args.action == Encoding.DER
args = self.parser.parse_args(["--action=PEM"])
- self.assertEqual(args.action, Encoding.PEM)
+ assert args.action == Encoding.PEM
def test_error(self) -> None:
"""Test false option values."""
@@ -601,13 +566,13 @@ def setUp(self) -> None:
def test_basic(self) -> None:
"""Test basic functionality of action."""
args = self.parser.parse_args(["--curve=sect409k1"])
- self.assertIsInstance(args.curve, ec.SECT409K1)
+ assert isinstance(args.curve, ec.SECT409K1)
args = self.parser.parse_args(["--curve=sect409r1"])
- self.assertIsInstance(args.curve, ec.SECT409R1)
+ assert isinstance(args.curve, ec.SECT409R1)
args = self.parser.parse_args(["--curve=brainpoolP512r1"])
- self.assertIsInstance(args.curve, ec.BrainpoolP512R1)
+ assert isinstance(args.curve, ec.BrainpoolP512R1)
def test_error(self) -> None:
"""Test false option values."""
@@ -633,10 +598,10 @@ def setUp(self) -> None:
def test_basic(self) -> None:
"""Test basic functionality of action."""
args = self.parser.parse_args(["--algo=SHA-256"])
- self.assertIsInstance(args.algo, hashes.SHA256)
+ assert isinstance(args.algo, hashes.SHA256)
args = self.parser.parse_args(["--algo=SHA-512"])
- self.assertIsInstance(args.algo, hashes.SHA512)
+ assert isinstance(args.algo, hashes.SHA512)
def test_error(self) -> None:
"""Test false option values."""
@@ -663,10 +628,10 @@ def setUp(self) -> None:
def test_basic(self) -> None:
"""Test basic functionality of action."""
args = self.parser.parse_args(["--size=2048"])
- self.assertEqual(args.size, 2048)
+ assert args.size == 2048
args = self.parser.parse_args(["--size=4096"])
- self.assertEqual(args.size, 4096)
+ assert args.size == 4096
def test_no_power_two(self) -> None:
"""Test giving values that are not the power of two."""
@@ -708,12 +673,12 @@ def setUp(self) -> None:
def test_none(self) -> None:
"""Test passing no password option at all."""
args = self.parser.parse_args([])
- self.assertIsNone(args.password)
+ assert args.password is None
def test_given(self) -> None:
"""Test giving a password on the command line."""
args = self.parser.parse_args(["--password=foobar"])
- self.assertEqual(args.password, b"foobar")
+ assert args.password == b"foobar"
@mock.patch("getpass.getpass", spec_set=True, return_value="prompted")
def test_output(self, getpass: mock.MagicMock) -> None:
@@ -722,7 +687,7 @@ def test_output(self, getpass: mock.MagicMock) -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--password", nargs="?", action=actions.PasswordAction, prompt=prompt)
args = parser.parse_args(["--password"])
- self.assertEqual(args.password, b"prompted")
+ assert args.password == b"prompted"
getpass.assert_called_once_with(prompt=prompt)
@mock.patch("getpass.getpass", spec_set=True, return_value="prompted")
@@ -731,7 +696,7 @@ def test_prompt(self, getpass: mock.MagicMock) -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--password", nargs="?", action=actions.PasswordAction)
args = parser.parse_args(["--password"])
- self.assertEqual(args.password, b"prompted")
+ assert args.password == b"prompted"
getpass.assert_called_once()
@@ -750,12 +715,12 @@ def test_basic(self) -> None:
"""Test basic functionality of action."""
for name, cert in self.certs.items():
args = self.parser.parse_args([CERT_DATA[name]["serial"]])
- self.assertEqual(args.cert, cert)
+ assert args.cert == cert
def test_abbreviation(self) -> None:
"""Test using an abbreviation."""
args = self.parser.parse_args([CERT_DATA["root-cert"]["serial"][:6]])
- self.assertEqual(args.cert, self.certs["root-cert"])
+ assert args.cert == self.certs["root-cert"]
def test_missing(self) -> None:
"""Test giving an unknown cert."""
@@ -801,13 +766,13 @@ def test_basic(self) -> None:
"""Test basic functionality of action."""
for name, ca in self.usable_cas:
args = self.parser.parse_args([CERT_DATA[name]["serial"]])
- self.assertEqual(args.ca, ca)
+ assert args.ca == ca
@override_tmpcadir()
def test_abbreviation(self) -> None:
"""Test using an abbreviation."""
args = self.parser.parse_args([CERT_DATA["ec"]["serial"][:6]])
- self.assertEqual(args.ca, self.cas["ec"])
+ assert args.ca == self.cas["ec"]
def test_missing(self) -> None:
"""Test giving an unknown CA."""
@@ -850,7 +815,7 @@ def test_disabled(self) -> None:
parser.add_argument("ca", action=actions.CertificateAuthorityAction, allow_disabled=True)
args = parser.parse_args([self.ca.serial])
- self.assertEqual(args.ca, self.ca)
+ assert args.ca == self.ca
# TODO: re-enable with better checks
# def test_private_key_does_not_exists(self) -> None:
@@ -870,7 +835,7 @@ def test_disabled(self) -> None:
def test_password(self) -> None:
"""Test that the action works with a password-encrypted CA."""
args = self.parser.parse_args([CERT_DATA["pwd"]["serial"]])
- self.assertEqual(args.ca, self.cas["pwd"])
+ assert args.ca == self.cas["pwd"]
class URLActionTestCase(ParserTestCaseMixin, TestCase):
@@ -885,7 +850,7 @@ def test_basic(self) -> None:
"""Test basic functionality of action."""
for url in ["http://example.com", "https://www.example.org"]:
args = self.parser.parse_args([f"--url={url}"])
- self.assertEqual(args.url, url)
+ assert args.url == url
def test_error(self) -> None:
"""Test false option values."""
@@ -908,7 +873,7 @@ def test_basic(self) -> None:
"""Test basic functionality of action."""
expires = timedelta(days=30)
args = self.parser.parse_args(["--expires=30"])
- self.assertEqual(args.expires, expires)
+ assert args.expires == expires
def test_default(self) -> None:
"""Test using the default value."""
@@ -916,7 +881,7 @@ def test_default(self) -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--expires", action=actions.ExpiresAction, default=delta)
args = parser.parse_args([])
- self.assertEqual(args.expires, delta)
+ assert args.expires == delta
def test_negative(self) -> None:
"""Test passing a negative value."""
@@ -950,7 +915,7 @@ def setUp(self) -> None:
def test_basic(self) -> None:
"""Test basic functionality of action."""
args = self.parser.parse_args([ReasonFlags.unspecified.name])
- self.assertEqual(args.reason, ReasonFlags.unspecified)
+ assert args.reason == ReasonFlags.unspecified
def test_error(self) -> None:
"""Test false option values."""
@@ -986,17 +951,17 @@ def test_basic(self) -> None:
parser.add_argument("--url", action=actions.MultipleURLAction)
args = parser.parse_args([f"--url={url}"])
- self.assertEqual(args.url, [url])
+ assert args.url == [url]
parser = argparse.ArgumentParser()
parser.add_argument("--url", action=actions.MultipleURLAction)
args = parser.parse_args([f"--url={urls[0]}", f"--url={urls[1]}"])
- self.assertEqual(args.url, urls)
+ assert args.url == urls
def test_none(self) -> None:
"""Test passing no value at all."""
args = self.parser.parse_args([])
- self.assertEqual(args.url, [])
+ assert args.url == []
def test_error(self) -> None:
"""Test false option values."""
diff --git a/ca/django_ca/tests/test_migration_helpers.py b/ca/django_ca/tests/test_migration_helpers.py
index 4bfed2065..7a7adacc7 100644
--- a/ca/django_ca/tests/test_migration_helpers.py
+++ b/ca/django_ca/tests/test_migration_helpers.py
@@ -26,7 +26,7 @@
@pytest.mark.parametrize(
- "crl_url,full_name",
+ ("crl_url", "full_name"),
(
("https://example.com", [uri("https://example.com")]),
(
@@ -53,7 +53,7 @@ def test_0040_crl_url_to_sign_crl_distribution_points(
@pytest.mark.parametrize(
- "issuer_alt_name,general_names",
+ ("issuer_alt_name", "general_names"),
(
("https://example.com", [uri("https://example.com")]),
("URI:https://example.com", [uri("https://example.com")]),
@@ -75,7 +75,7 @@ def test_0040_issuer_alt_name_to_sign_issuer_alternative_name(
@pytest.mark.parametrize(
- "issuer_url,ocsp_url,access_descriptions",
+ ("issuer_url", "ocsp_url", "access_descriptions"),
(
(
"https://issuer.example.com",
@@ -133,7 +133,7 @@ def test_0040_ocsp_url_and_issuer_url_to_sign_authority_information_access(
@pytest.mark.parametrize(
- "distribution_points,crl_url",
+ ("distribution_points", "crl_url"),
(
([distribution_point([uri("https://example.com")])], "https://example.com"),
(
@@ -189,7 +189,7 @@ def test_0040_backwards_sign_crl_distribution_points_to_crl_url(
@pytest.mark.parametrize(
- "issuer_alt_name,general_names",
+ ("issuer_alt_name", "general_names"),
(
("URI:https://example.com", [uri("https://example.com")]),
# issuer_alt_name was a CharField, values where comma-separated.
@@ -210,7 +210,7 @@ def test_0040_backwards_sign_issuer_alternative_name_to_issuer_url(
@pytest.mark.parametrize(
- "issuer_url,ocsp_url,access_descriptions",
+ ("issuer_url", "ocsp_url", "access_descriptions"),
(
(
"https://issuer.example.com",
diff --git a/ca/django_ca/tests/test_models.py b/ca/django_ca/tests/test_models.py
index 318a601d2..32a7f2a9c 100644
--- a/ca/django_ca/tests/test_models.py
+++ b/ca/django_ca/tests/test_models.py
@@ -27,6 +27,7 @@
from django.test import RequestFactory, TestCase, override_settings
from django.utils import timezone
+import pytest
from freezegun import freeze_time
from django_ca.key_backends.storages import StoragesUsePrivateKeyOptions
@@ -57,8 +58,8 @@ def test_from_addr(self) -> None:
name = "Firstname Lastname"
watcher = Watcher.from_addr(f"{name} <{mail}>")
- self.assertEqual(watcher.mail, mail)
- self.assertEqual(watcher.name, name)
+ assert watcher.mail == mail
+ assert watcher.name == name
def test_spaces(self) -> None:
"""Test that ``from_addr() is agnostic to spaces."""
@@ -66,18 +67,18 @@ def test_spaces(self) -> None:
name = "Firstname Lastname"
watcher = Watcher.from_addr(f"{name} <{mail}>")
- self.assertEqual(watcher.mail, mail)
- self.assertEqual(watcher.name, name)
+ assert watcher.mail == mail
+ assert watcher.name == name
watcher = Watcher.from_addr(f"{name}<{mail}>")
- self.assertEqual(watcher.mail, mail)
- self.assertEqual(watcher.name, name)
+ assert watcher.mail == mail
+ assert watcher.name == name
def test_error(self) -> None:
"""Test some validation errors."""
- with self.assertRaises(ValidationError):
+ with pytest.raises(ValidationError):
Watcher.from_addr("foobar ")
- with self.assertRaises(ValidationError):
+ with pytest.raises(ValidationError):
Watcher.from_addr("foobar @")
def test_update(self) -> None:
@@ -88,8 +89,8 @@ def test_update(self) -> None:
Watcher.from_addr(f"{name} <{mail}>")
watcher = Watcher.from_addr(f"{newname} <{mail}>")
- self.assertEqual(watcher.mail, mail)
- self.assertEqual(watcher.name, newname)
+ assert watcher.mail == mail
+ assert watcher.name == newname
def test_str(self) -> None:
"""Test the str function."""
@@ -97,10 +98,10 @@ def test_str(self) -> None:
name = "Firstname Lastname"
watcher = Watcher(mail=mail)
- self.assertEqual(str(watcher), mail)
+ assert str(watcher) == mail
watcher.name = name
- self.assertEqual(str(watcher), f"{name} <{mail}>")
+ assert str(watcher) == f"{name} <{mail}>"
class ModelfieldsTests(TestCaseMixin, TestCase):
@@ -121,14 +122,14 @@ def test_create_pem_bytes(self) -> None:
not_after=timezone.now(),
not_before=timezone.now(),
)
- self.assertEqual(cert.pub, pub)
- self.assertEqual(cert.csr, csr)
+ assert cert.pub == pub
+ assert cert.csr == csr
# Refresh, so that we get lazy values
cert.refresh_from_db()
- self.assertEqual(cert.pub.loaded, self.pub["parsed"])
- self.assertEqual(cert.csr.loaded, self.csr["parsed"])
+ assert cert.pub.loaded == self.pub["parsed"]
+ assert cert.csr.loaded == self.csr["parsed"]
def test_create_bytearray(self) -> None:
"""Test creating with bytes-encoded PEM."""
@@ -141,14 +142,14 @@ def test_create_bytearray(self) -> None:
not_after=timezone.now(),
not_before=timezone.now(),
)
- self.assertEqual(cert.pub, pub)
- self.assertEqual(cert.csr, csr)
+ assert cert.pub == pub
+ assert cert.csr == csr
# Refresh, so that we get lazy values
cert.refresh_from_db()
- self.assertEqual(cert.pub.loaded, self.pub["parsed"])
- self.assertEqual(cert.csr.loaded, self.csr["parsed"])
+ assert cert.pub.loaded == self.pub["parsed"]
+ assert cert.csr.loaded == self.csr["parsed"]
def test_create_memoryview(self) -> None:
"""Test creating with bytes-encoded PEM."""
@@ -161,20 +162,20 @@ def test_create_memoryview(self) -> None:
not_after=timezone.now(),
not_before=timezone.now(),
)
- self.assertEqual(cert.pub, pub)
- self.assertEqual(cert.csr, csr)
+ assert cert.pub == pub
+ assert cert.csr == csr
# Refresh, so that we get lazy values
cert.refresh_from_db()
- self.assertEqual(cert.pub.loaded, self.pub["parsed"])
- self.assertEqual(cert.csr.loaded, self.csr["parsed"])
+ assert cert.pub.loaded == self.pub["parsed"]
+ assert cert.csr.loaded == self.csr["parsed"]
def test_create_from_instance(self) -> None:
"""Test creating a certificate from LazyField instances."""
loaded = self.load_named_cert("root-cert")
- self.assertIsInstance(loaded.pub, LazyCertificate)
- self.assertIsInstance(loaded.csr, LazyCertificateSigningRequest)
+ assert isinstance(loaded.pub, LazyCertificate)
+ assert isinstance(loaded.csr, LazyCertificateSigningRequest)
cert = Certificate.objects.create(
pub=loaded.pub,
csr=loaded.csr,
@@ -182,12 +183,12 @@ def test_create_from_instance(self) -> None:
not_after=timezone.now(),
not_before=timezone.now(),
)
- self.assertEqual(loaded.pub, cert.pub)
- self.assertEqual(loaded.csr, cert.csr)
+ assert loaded.pub == cert.pub
+ assert loaded.csr == cert.csr
reloaded = Certificate.objects.get(pk=cert.pk)
- self.assertEqual(loaded.pub, reloaded.pub)
- self.assertEqual(loaded.csr, reloaded.csr)
+ assert loaded.pub == reloaded.pub
+ assert loaded.csr == reloaded.csr
def test_repr(self) -> None:
"""Test ``repr()`` for custom modelfields."""
@@ -201,8 +202,8 @@ def test_repr(self) -> None:
cert.refresh_from_db()
subject = "CN=root-cert.example.com,OU=Django CA Testsuite,O=Django CA,L=Vienna,ST=Vienna,C=AT"
- self.assertEqual(repr(cert.pub), f"")
- self.assertEqual(repr(cert.csr), "")
+ assert repr(cert.pub) == f""
+ assert repr(cert.csr) == ""
def test_none_value(self) -> None:
"""Test that nullable fields work."""
@@ -213,9 +214,9 @@ def test_none_value(self) -> None:
not_after=timezone.now(),
not_before=timezone.now(),
)
- self.assertIsNone(cert.csr)
+ assert cert.csr is None
cert.refresh_from_db()
- self.assertIsNone(cert.csr)
+ assert cert.csr is None
def test_filter(self) -> None:
"""Test that we can use various representations for filtering."""
@@ -229,8 +230,8 @@ def test_filter(self) -> None:
for prop in ["parsed", "pem", "der"]:
qs = Certificate.objects.filter(pub=self.pub[prop])
- self.assertCountEqual(qs, [cert])
- self.assertEqual(qs[0].pub.der, self.pub["der"])
+ assert list(qs) == [cert]
+ assert qs[0].pub.der == self.pub["der"]
def test_full_clean(self) -> None:
"""Test the full_clean() method, which invokes ``to_python()`` on the field."""
@@ -244,8 +245,8 @@ def test_full_clean(self) -> None:
serial="0",
)
cert.full_clean()
- self.assertEqual(cert.pub.loaded, self.pub["parsed"])
- self.assertEqual(cert.csr.loaded, self.csr["parsed"])
+ assert cert.pub.loaded == self.pub["parsed"]
+ assert cert.csr.loaded == self.csr["parsed"]
cert = Certificate(
pub=cert.pub,
@@ -257,8 +258,8 @@ def test_full_clean(self) -> None:
serial="0",
)
cert.full_clean()
- self.assertEqual(cert.pub.loaded, self.pub["parsed"])
- self.assertEqual(cert.csr.loaded, self.csr["parsed"])
+ assert cert.pub.loaded == self.pub["parsed"]
+ assert cert.csr.loaded == self.csr["parsed"]
def test_empty_csr(self) -> None:
"""Test an empty CSR."""
@@ -272,12 +273,12 @@ def test_empty_csr(self) -> None:
serial="0",
)
cert.full_clean()
- self.assertEqual(cert.pub.loaded, self.pub["parsed"])
- self.assertIsNone(cert.csr)
+ assert cert.pub.loaded == self.pub["parsed"]
+ assert cert.csr is None
def test_invalid_value(self) -> None:
"""Test passing invalid values."""
- with self.assertRaisesRegex(ValueError, r"^True: Could not parse CertificateSigningRequest$"):
+ with pytest.raises(ValueError, match=r"^True: Could not parse CertificateSigningRequest$"):
Certificate.objects.create(
pub=CERT_DATA["child-cert"]["pub"]["parsed"],
csr=True, # type: ignore[misc] # what we test
@@ -286,7 +287,7 @@ def test_invalid_value(self) -> None:
not_before=timezone.now(),
)
- with self.assertRaisesRegex(ValueError, r"^True: Could not parse Certificate$"):
+ with pytest.raises(ValueError, match=r"^True: Could not parse Certificate$"):
Certificate.objects.create(
csr=CERT_DATA["child-cert"]["csr"]["parsed"],
pub=True, # type: ignore[misc] # what we test
@@ -329,56 +330,56 @@ def setUp(self) -> None:
def test_str(self) -> None:
"""Test str() function."""
- self.assertEqual(str(self.account1), "user@example.com")
- self.assertEqual(str(self.account2), "user@example.net")
- self.assertEqual(str(AcmeAccount()), "")
+ assert str(self.account1) == "user@example.com"
+ assert str(self.account2) == "user@example.net"
+ assert str(AcmeAccount()) == ""
def test_serial(self) -> None:
"""Test the ``serial`` property."""
- self.assertEqual(self.account1.serial, self.cas["root"].serial)
- self.assertEqual(self.account2.serial, self.cas["child"].serial)
+ assert self.account1.serial == self.cas["root"].serial
+ assert self.account2.serial == self.cas["child"].serial
# pylint: disable=no-member; false positive: pylint does not detect RelatedObjectDoesNotExist member
- with self.assertRaisesRegex(AcmeAccount.ca.RelatedObjectDoesNotExist, r"^AcmeAccount has no ca\.$"):
+ with pytest.raises(AcmeAccount.ca.RelatedObjectDoesNotExist, match=r"^AcmeAccount has no ca\.$"):
AcmeAccount().serial # noqa: B018
@freeze_time(TIMESTAMPS["everything_valid"])
def test_usable(self) -> None:
"""Test the ``usable`` property."""
- self.assertTrue(self.account1.usable)
- self.assertFalse(self.account2.usable)
+ assert self.account1.usable
+ assert not self.account2.usable
# Try states that make an account **unusable**
self.account1.status = AcmeAccount.STATUS_DEACTIVATED
- self.assertFalse(self.account1.usable)
+ assert not self.account1.usable
self.account1.status = AcmeAccount.STATUS_REVOKED
- self.assertFalse(self.account1.usable)
+ assert not self.account1.usable
# Make the account usable again
self.account1.status = AcmeAccount.STATUS_VALID
- self.assertTrue(self.account1.usable)
+ assert self.account1.usable
# TOS not agreed, but CA does not have any
self.account1.terms_of_service_agreed = False
- self.assertTrue(self.account1.usable)
+ assert self.account1.usable
# TOS not agreed, but CA does have them, so account is now unusable
self.cas["root"].terms_of_service = "http://tos.example.com"
self.cas["root"].save()
- self.assertFalse(self.account1.usable)
+ assert not self.account1.usable
# Make the account usable again
self.account1.terms_of_service_agreed = True
- self.assertTrue(self.account1.usable)
+ assert self.account1.usable
# If the CA is not usable, neither is the account
self.account1.ca.enabled = False
- self.assertFalse(self.account1.usable)
+ assert not self.account1.usable
def test_unique_together(self) -> None:
"""Test that a thumbprint must be unique for the given CA."""
msg = r"^UNIQUE constraint failed: django_ca_acmeaccount\.ca_id, django_ca_acmeaccount\.thumbprint$"
- with transaction.atomic(), self.assertRaisesRegex(IntegrityError, msg):
+ with transaction.atomic(), pytest.raises(IntegrityError, match=msg):
AcmeAccount.objects.create(ca=self.account1.ca, thumbprint=self.account1.thumbprint)
# Works, because CA is different
@@ -390,9 +391,9 @@ def test_set_kid(self) -> None:
hostname = settings.ALLOWED_HOSTS[0]
req = RequestFactory().get("/foobar", HTTP_HOST=hostname)
self.account1.set_kid(req)
- self.assertEqual(
- self.account1.kid,
- f"http://{hostname}/django_ca/acme/{self.account1.serial}/acct/{self.account1.slug}/",
+ assert (
+ self.account1.kid
+ == f"http://{hostname}/django_ca/acme/{self.account1.serial}/acct/{self.account1.slug}/"
)
def test_validate_pem(self) -> None:
@@ -428,35 +429,33 @@ def setUp(self) -> None:
def test_str(self) -> None:
"""Test the str function."""
- self.assertEqual(str(self.order1), f"{self.order1.slug} ({self.account})")
+ assert str(self.order1) == f"{self.order1.slug} ({self.account})"
def test_acme_url(self) -> None:
"""Test the acme url function."""
- self.assertEqual(
- self.order1.acme_url, f"/django_ca/acme/{self.account.ca.serial}/order/{self.order1.slug}/"
- )
+ assert self.order1.acme_url == f"/django_ca/acme/{self.account.ca.serial}/order/{self.order1.slug}/"
def test_acme_finalize_url(self) -> None:
"""Test the acme finalize url function."""
- self.assertEqual(
- self.order1.acme_finalize_url,
- f"/django_ca/acme/{self.account.ca.serial}/order/{self.order1.slug}/finalize/",
+ assert (
+ self.order1.acme_finalize_url
+ == f"/django_ca/acme/{self.account.ca.serial}/order/{self.order1.slug}/finalize/"
)
def test_add_authorizations(self) -> None:
"""Test the add_authorizations method."""
identifier = messages.Identifier(typ=messages.IDENTIFIER_FQDN, value="example.com")
auths = self.order1.add_authorizations([identifier])
- self.assertEqual(auths[0].type, "dns")
- self.assertEqual(auths[0].value, "example.com")
+ assert auths[0].type == "dns"
+ assert auths[0].value == "example.com"
msg = r"^UNIQUE constraint failed: django_ca_acmeauthorization\.order_id, django_ca_acmeauthorization\.type, django_ca_acmeauthorization\.value$" # NOQA: E501
- with transaction.atomic(), self.assertRaisesRegex(IntegrityError, msg):
+ with transaction.atomic(), pytest.raises(IntegrityError, match=msg):
self.order1.add_authorizations([identifier])
def test_serial(self) -> None:
"""Test getting the serial of the associated CA."""
- self.assertEqual(self.order1.serial, self.cas["root"].serial)
+ assert self.order1.serial == self.cas["root"].serial
class AcmeAuthorizationTestCase(TestCaseMixin, AcmeValuesMixin, TestCase):
@@ -484,58 +483,52 @@ def setUp(self) -> None:
def test_str(self) -> None:
"""Test the __str__ method."""
- self.assertEqual(str(self.auth1), "dns: example.com")
- self.assertEqual(str(self.auth2), "dns: example.net")
+ assert str(self.auth1) == "dns: example.com"
+ assert str(self.auth2) == "dns: example.net"
def test_account_property(self) -> None:
"""Test the account property."""
- self.assertEqual(self.auth1.account, self.account)
- self.assertEqual(self.auth2.account, self.account)
+ assert self.auth1.account == self.account
+ assert self.auth2.account == self.account
def test_acme_url(self) -> None:
"""Test acme_url property."""
- self.assertEqual(
- self.auth1.acme_url,
- f"/django_ca/acme/{self.cas['root'].serial}/authz/{self.auth1.slug}/",
- )
- self.assertEqual(
- self.auth2.acme_url,
- f"/django_ca/acme/{self.cas['root'].serial}/authz/{self.auth2.slug}/",
- )
+ assert self.auth1.acme_url == f"/django_ca/acme/{self.cas['root'].serial}/authz/{self.auth1.slug}/"
+ assert self.auth2.acme_url == f"/django_ca/acme/{self.cas['root'].serial}/authz/{self.auth2.slug}/"
def test_expires(self) -> None:
"""Test the expires property."""
- self.assertEqual(self.auth1.expires, self.order.expires)
- self.assertEqual(self.auth2.expires, self.order.expires)
+ assert self.auth1.expires == self.order.expires
+ assert self.auth2.expires == self.order.expires
def test_identifier(self) -> None:
"""Test the identifier property."""
- self.assertEqual(
- self.auth1.identifier, messages.Identifier(typ=messages.IDENTIFIER_FQDN, value=self.auth1.value)
+ assert self.auth1.identifier == messages.Identifier(
+ typ=messages.IDENTIFIER_FQDN, value=self.auth1.value
)
- self.assertEqual(
- self.auth2.identifier, messages.Identifier(typ=messages.IDENTIFIER_FQDN, value=self.auth2.value)
+ assert self.auth2.identifier == messages.Identifier(
+ typ=messages.IDENTIFIER_FQDN, value=self.auth2.value
)
def test_identifier_unknown_type(self) -> None:
"""Test that an identifier with an unknown type raises a ValueError."""
self.auth1.type = "foo"
- with self.assertRaisesRegex(ValueError, r"^Unknown identifier type: foo$"):
+ with pytest.raises(ValueError, match=r"^Unknown identifier type: foo$"):
self.auth1.identifier # noqa: B018
def test_subject_alternative_name(self) -> None:
"""Test the subject_alternative_name property."""
- self.assertEqual(self.auth1.subject_alternative_name, "dns:example.com")
- self.assertEqual(self.auth2.subject_alternative_name, "dns:example.net")
+ assert self.auth1.subject_alternative_name == "dns:example.com"
+ assert self.auth2.subject_alternative_name == "dns:example.net"
def test_get_challenges(self) -> None:
"""Test the get_challenges() method."""
chall_qs = self.auth1.get_challenges()
- self.assertIsInstance(chall_qs[0], AcmeChallenge)
- self.assertIsInstance(chall_qs[1], AcmeChallenge)
+ assert isinstance(chall_qs[0], AcmeChallenge)
+ assert isinstance(chall_qs[1], AcmeChallenge)
- self.assertEqual(self.auth1.get_challenges(), chall_qs)
- self.assertEqual(AcmeChallenge.objects.all().count(), 2)
+ assert self.auth1.get_challenges() == chall_qs
+ assert AcmeChallenge.objects.all().count() == 2
class AcmeChallengeTestCase(TestCaseMixin, AcmeValuesMixin, TestCase):
@@ -563,17 +556,17 @@ def assertChallenge( # pylint: disable=invalid-name
self, challenge: ChallengeTypeVar, typ: str, token: bytes, cls: type[ChallengeTypeVar]
) -> None:
"""Test that the ACME challenge is of the given type."""
- self.assertIsInstance(challenge, cls)
- self.assertEqual(challenge.typ, typ)
- self.assertEqual(challenge.token, token)
+ assert isinstance(challenge, cls)
+ assert challenge.typ == typ
+ assert challenge.token == token
def test_str(self) -> None:
"""Test the __str__ method."""
- self.assertEqual(str(self.chall), f"{self.hostname} ({self.chall.type})")
+ assert str(self.chall) == f"{self.hostname} ({self.chall.type})"
def test_acme_url(self) -> None:
"""Test acme_url property."""
- self.assertEqual(self.chall.acme_url, f"/django_ca/acme/{self.chall.serial}/chall/{self.chall.slug}/")
+ assert self.chall.acme_url == f"/django_ca/acme/{self.chall.serial}/chall/{self.chall.slug}/"
def test_acme_challenge(self) -> None:
"""Test acme_challenge property."""
@@ -590,67 +583,67 @@ def test_acme_challenge(self) -> None:
)
self.chall.type = "foo"
- with self.assertRaisesRegex(ValueError, r"^foo: Unsupported challenge type\.$"):
+ with pytest.raises(ValueError, match=r"^foo: Unsupported challenge type\.$"):
self.chall.acme_challenge # noqa: B018
@freeze_time(TIMESTAMPS["everything_valid"])
def test_acme_validated(self) -> None:
"""Test acme_validated property."""
# preconditions for checks (might change them in setUp without realising it might affect this test)
- self.assertNotEqual(self.chall.status, AcmeChallenge.STATUS_VALID)
- self.assertIsNone(self.chall.validated)
+ assert self.chall.status != AcmeChallenge.STATUS_VALID
+ assert self.chall.validated is None
- self.assertIsNone(self.chall.acme_validated)
+ assert self.chall.acme_validated is None
self.chall.status = AcmeChallenge.STATUS_VALID
- self.assertIsNone(self.chall.acme_validated) # still None (no validated timestamp)
+ assert self.chall.acme_validated is None # still None (no validated timestamp)
self.chall.validated = timezone.now()
- self.assertEqual(self.chall.acme_validated, TIMESTAMPS["everything_valid"])
+ assert self.chall.acme_validated == TIMESTAMPS["everything_valid"]
# We return a UTC timestamp, even if timezone support is disabled.
with self.settings(USE_TZ=False):
self.chall.validated = timezone.now()
- self.assertEqual(self.chall.acme_validated, TIMESTAMPS["everything_valid"])
+ assert self.chall.acme_validated == TIMESTAMPS["everything_valid"]
def test_encoded(self) -> None:
"""Test the encoded property."""
self.chall.token = "ADwFxCAXrnk47rcCnnbbtGYSo_l61MCYXqtBziPt26mk7-QzpYNNKnTsKjbBYPzD"
self.chall.save()
- self.assertEqual(
- self.chall.encoded_token,
- b"QUR3RnhDQVhybms0N3JjQ25uYmJ0R1lTb19sNjFNQ1lYcXRCemlQdDI2bWs3LVF6cFlOTktuVHNLamJCWVB6RA",
+ assert (
+ self.chall.encoded_token
+ == b"QUR3RnhDQVhybms0N3JjQ25uYmJ0R1lTb19sNjFNQ1lYcXRCemlQdDI2bWs3LVF6cFlOTktuVHNLamJCWVB6RA"
)
def test_expected(self) -> None:
"""Test the expected property."""
self.chall.token = "ADwFxCAXrnk47rcCnnbbtGYSo_l61MCYXqtBziPt26mk7-QzpYNNKnTsKjbBYPzD"
self.chall.save()
- self.assertEqual(
- self.chall.expected, self.chall.encoded_token + b"." + self.account.thumbprint.encode("utf-8")
+ assert self.chall.expected == self.chall.encoded_token + b"." + self.account.thumbprint.encode(
+ "utf-8"
)
self.chall.type = AcmeChallenge.TYPE_DNS_01
self.chall.save()
- self.assertEqual(self.chall.expected, b"LoNgngEeuLw4rWDFpplPA0XBp9dd9spzuuqbsRFcKug")
+ assert self.chall.expected == b"LoNgngEeuLw4rWDFpplPA0XBp9dd9spzuuqbsRFcKug"
self.chall.type = AcmeChallenge.TYPE_TLS_ALPN_01
self.chall.save()
- with self.assertRaisesRegex(ValueError, r"^tls-alpn-01: Unsupported challenge type\.$"):
+ with pytest.raises(ValueError, match=r"^tls-alpn-01: Unsupported challenge type\.$"):
self.chall.expected # noqa: B018
def test_get_challenge(self) -> None:
"""Test the get_challenge() function."""
body = self.chall.get_challenge(RequestFactory().get("/"))
- self.assertIsInstance(body, messages.ChallengeBody)
- self.assertEqual(body.chall, self.chall.acme_challenge)
- self.assertEqual(body.status, self.chall.status)
- self.assertEqual(body.validated, self.chall.acme_validated)
- self.assertEqual(body.uri, f"http://testserver{self.chall.acme_url}")
+ assert isinstance(body, messages.ChallengeBody)
+ assert body.chall == self.chall.acme_challenge
+ assert body.status == self.chall.status
+ assert body.validated == self.chall.acme_validated
+ assert body.uri == f"http://testserver{self.chall.acme_url}"
def test_serial(self) -> None:
"""Test the serial property."""
- self.assertEqual(self.chall.serial, self.chall.auth.order.account.ca.serial)
+ assert self.chall.serial == self.chall.auth.order.account.ca.serial
class AcmeCertificateTestCase(TestCaseMixin, AcmeValuesMixin, TestCase):
@@ -673,13 +666,11 @@ def setUp(self) -> None:
def test_acme_url(self) -> None:
"""Test the acme_url property."""
- self.assertEqual(
- self.acme_cert.acme_url, f"/django_ca/acme/{self.order.serial}/cert/{self.acme_cert.slug}/"
- )
+ assert self.acme_cert.acme_url == f"/django_ca/acme/{self.order.serial}/cert/{self.acme_cert.slug}/"
def test_parse_csr(self) -> None:
"""Test the parse_csr property."""
self.acme_cert.csr = (
CERT_DATA["root-cert"]["csr"]["parsed"].public_bytes(Encoding.PEM).decode("utf-8")
)
- self.assertIsInstance(self.acme_cert.parse_csr(), x509.CertificateSigningRequest)
+ assert isinstance(self.acme_cert.parse_csr(), x509.CertificateSigningRequest)
diff --git a/ca/django_ca/tests/test_querysets.py b/ca/django_ca/tests/test_querysets.py
index 9da5c1ce4..276d160ee 100644
--- a/ca/django_ca/tests/test_querysets.py
+++ b/ca/django_ca/tests/test_querysets.py
@@ -31,6 +31,7 @@
Certificate,
CertificateAuthority,
)
+from django_ca.tests.base.assertions import assert_count_equal
from django_ca.tests.base.constants import TIMESTAMPS
from django_ca.tests.base.mixins import AcmeValuesMixin, TestCaseMixin
@@ -42,7 +43,7 @@ def assertQuerySet( # pylint: disable=invalid-name; unittest standard
self, qs: "models.QuerySet[models.Model]", *items: models.Model
) -> None:
"""Minor shortcut to test querysets."""
- self.assertCountEqual(qs, items)
+ assert_count_equal(qs, items)
@contextmanager
def attr(self, obj: models.Model, attr: str, value: Any) -> Iterator[None]:
@@ -67,42 +68,42 @@ def test_enabled_disabled(self) -> None:
"""Test enabled/disabled filter."""
self.load_named_cas("__usable__")
- self.assertCountEqual(CertificateAuthority.objects.enabled(), self.cas.values())
- self.assertCountEqual(CertificateAuthority.objects.disabled(), [])
+ assert_count_equal(CertificateAuthority.objects.enabled(), self.cas.values())
+ assert not CertificateAuthority.objects.disabled()
self.ca.enabled = False
self.ca.save()
- self.assertCountEqual(
+ assert_count_equal(
CertificateAuthority.objects.enabled(),
[c for c in self.cas.values() if c.name != self.ca.name],
)
- self.assertCountEqual(CertificateAuthority.objects.disabled(), [self.ca])
+ assert_count_equal(CertificateAuthority.objects.disabled(), [self.ca])
def test_valid(self) -> None:
"""Test valid/usable/invalid filters."""
self.load_named_cas("__usable__")
with freeze_time(TIMESTAMPS["before_cas"]):
- self.assertCountEqual(CertificateAuthority.objects.valid(), [])
- self.assertCountEqual(CertificateAuthority.objects.usable(), [])
- self.assertCountEqual(CertificateAuthority.objects.invalid(), self.cas.values())
+ assert not CertificateAuthority.objects.valid()
+ assert not CertificateAuthority.objects.usable()
+ assert_count_equal(CertificateAuthority.objects.invalid(), self.cas.values())
with freeze_time(TIMESTAMPS["before_child"]):
valid = [c for c in self.cas.values() if c.name != "child"]
- self.assertCountEqual(CertificateAuthority.objects.valid(), valid)
- self.assertCountEqual(CertificateAuthority.objects.usable(), valid)
- self.assertCountEqual(CertificateAuthority.objects.invalid(), [self.cas["child"]])
+ assert_count_equal(CertificateAuthority.objects.valid(), valid)
+ assert_count_equal(CertificateAuthority.objects.usable(), valid)
+ assert_count_equal(CertificateAuthority.objects.invalid(), [self.cas["child"]])
with freeze_time(TIMESTAMPS["after_child"]):
- self.assertCountEqual(CertificateAuthority.objects.valid(), self.cas.values())
- self.assertCountEqual(CertificateAuthority.objects.usable(), self.cas.values())
- self.assertCountEqual(CertificateAuthority.objects.invalid(), [])
+ assert_count_equal(CertificateAuthority.objects.valid(), self.cas.values())
+ assert_count_equal(CertificateAuthority.objects.usable(), self.cas.values())
+ assert not CertificateAuthority.objects.invalid()
with freeze_time(TIMESTAMPS["cas_expired"]):
- self.assertCountEqual(CertificateAuthority.objects.valid(), [])
- self.assertCountEqual(CertificateAuthority.objects.usable(), [])
- self.assertCountEqual(CertificateAuthority.objects.invalid(), self.cas.values())
+ assert not CertificateAuthority.objects.valid()
+ assert not CertificateAuthority.objects.usable()
+ assert_count_equal(CertificateAuthority.objects.invalid(), self.cas.values())
class CertificateQuerysetTestCase(QuerySetTestCaseMixin, TestCase):
diff --git a/ca/django_ca/tests/test_settings.py b/ca/django_ca/tests/test_settings.py
index c93d900cb..2f18fcee1 100644
--- a/ca/django_ca/tests/test_settings.py
+++ b/ca/django_ca/tests/test_settings.py
@@ -337,7 +337,7 @@ def test_ca_acme_cert_validity_timedelta_settings_as_int(settings: SettingsWrapp
@pytest.mark.parametrize("setting", ("CA_ACME_DEFAULT_CERT_VALIDITY", "CA_ACME_MAX_CERT_VALIDITY"))
@pytest.mark.parametrize(
- "value,message",
+ ("value", "message"),
(
(0.9, "Input should be greater than or equal to 1 day"),
(timedelta(seconds=1), "Input should be greater than or equal to 1 day"),
@@ -356,7 +356,7 @@ def test_ca_acme_cert_validity_limits(
@pytest.mark.parametrize(
- "value,message",
+ ("value", "message"),
(
(timedelta(seconds=59), "Input should be greater than or equal to 1 minute"),
(timedelta(days=2), "Input should be less than or equal to 1 day"),
@@ -414,7 +414,7 @@ def test_ca_crl_profiles_with_deprecated_scope(settings: SettingsWrapper, scope:
@pytest.mark.parametrize(
- "value,parsed",
+ ("value", "parsed"),
(
("0a:bc", "ABC"), # leading zero is stripped
("0", "0"), # single zero is *not* stripped
@@ -481,7 +481,7 @@ def test_ca_default_name_order(settings: SettingsWrapper) -> None:
@pytest.mark.parametrize(
- "value,msg",
+ ("value", "msg"),
(
(True, r"Input should be a valid tuple"),
(("invalid-oid",), "invalid-oid: Invalid object identifier"),
@@ -499,7 +499,7 @@ def test_ca_default_profile_not_defined(settings: SettingsWrapper) -> None:
settings.CA_DEFAULT_PROFILE = "foo"
-@pytest.mark.parametrize("value,expected", (("SHA-224", hashes.SHA224), ("SHA3/384", hashes.SHA3_384)))
+@pytest.mark.parametrize(("value", "expected"), (("SHA-224", hashes.SHA224), ("SHA3/384", hashes.SHA3_384)))
def test_ca_default_signature_hash_algorithm(
settings: SettingsWrapper, value: Any, expected: type[hashes.HashAlgorithm]
) -> None:
@@ -516,7 +516,7 @@ def test_ca_default_signature_hash_algorithm_with_invalid_value(settings: Settin
@pytest.mark.parametrize(
- "value,expected",
+ ("value", "expected"),
(
# Serialized version
(
@@ -541,7 +541,7 @@ def test_ca_default_subject(settings: SettingsWrapper, value: Any, expected: x50
@pytest.mark.parametrize(
- "value,msg",
+ ("value", "msg"),
(
((("CN", ""),), r"Value error, Attribute's length must be >= 1 and <= 64, but it was 0"),
((("CN", "X" * 65),), r"Value error, Attribute's length must be >= 1 and <= 64, but it was 65"),
@@ -554,7 +554,7 @@ def test_ca_default_subject_with_invalid_values(settings: SettingsWrapper, value
@pytest.mark.parametrize(
- "value,expected",
+ ("value", "expected"),
(
([("CN", "example.com")], x509.Name([cn("example.com")])),
((("C", "AT"), ("CN", "example.com")), x509.Name([country("AT"), cn("example.com")])),
@@ -576,7 +576,7 @@ def test_ca_default_subject_with_deprecated_values(
@pytest.mark.parametrize(
- "value,msg",
+ ("value", "msg"),
(
([("invalid", "wrong")], "invalid: Invalid object identifier"),
([["one-element"]], r"Must be lists/tuples with two items, got 1\."),
@@ -646,7 +646,7 @@ def test_ca_profiles_update_description(settings: SettingsWrapper) -> None:
@pytest.mark.parametrize(
- "subject,expected",
+ ("subject", "expected"),
(
(False, False),
([], x509.Name([])),
@@ -669,7 +669,7 @@ def test_ca_profiles_override_subject_with_deprecated_values(settings: SettingsW
@pytest.mark.parametrize(
- "value,msg",
+ ("value", "msg"),
(
("foo", "Input should be a valid dictionary"), # whole setting is invalid
({"client": {"subject": "foo"}}, r"Value error, foo: Must be a list or tuple\."),
@@ -734,7 +734,7 @@ def test_ca_crl_profiles_invalid_scope(settings: SettingsWrapper) -> None:
@pytest.mark.parametrize(
- "base,override",
+ ("base", "override"),
(
("only_contains_ca_certs", "only_contains_user_certs"),
("only_contains_user_certs", "only_contains_ca_certs"),
diff --git a/ca/django_ca/tests/test_sphinx_extensions.py b/ca/django_ca/tests/test_sphinx_extensions.py
index 8c287149f..ab96138ce 100644
--- a/ca/django_ca/tests/test_sphinx_extensions.py
+++ b/ca/django_ca/tests/test_sphinx_extensions.py
@@ -24,13 +24,13 @@ class CommandLineTextWrapperTestCase(TestCase):
def assertWraps(self, command: str, expected: list[str]) -> None: # pylint: disable=invalid-name
"""Assert that the given command wraps to the expected full text."""
wrapper = CommandLineTextWrapper(width=12)
- self.assertEqual(wrapper.wrap(command), expected)
+ assert wrapper.wrap(command) == expected
def assertSplits(self, command: str, expected: list[str]) -> None: # pylint: disable=invalid-name
"""Assert that the given command splits into the expected tokens."""
wrapper = CommandLineTextWrapper()
# PYLINT note: this is the function that we override
- self.assertEqual(wrapper._split(command), expected) # pylint: disable=protected-access
+ assert wrapper._split(command) == expected # pylint: disable=protected-access
def test_split(self) -> None:
"""Test the overwritten split function."""
diff --git a/ca/django_ca/tests/test_tasks.py b/ca/django_ca/tests/test_tasks.py
index 1a5ebc785..4f7a2e54b 100644
--- a/ca/django_ca/tests/test_tasks.py
+++ b/ca/django_ca/tests/test_tasks.py
@@ -55,12 +55,12 @@ class TestBasic(TestCaseMixin, TestCase):
def test_missing_celery(self) -> None:
"""Test that we work even if celery is not installed."""
# negative assertion to make sure that the IsInstance assertion below is actually meaningful
- self.assertNotIsInstance(tasks.cache_crl, types.FunctionType)
+ assert not isinstance(tasks.cache_crl, types.FunctionType)
try:
with mock.patch.dict("sys.modules", celery=None):
importlib.reload(tasks)
- self.assertIsInstance(tasks.cache_crl, types.FunctionType)
+ assert isinstance(tasks.cache_crl, types.FunctionType)
finally:
# Make sure that module is reloaded, or any failed test in the try block will cause *all other
# tests* to fail, because the celery import would be cached to *not* work
@@ -71,7 +71,7 @@ def test_run_task(self) -> None:
# run_task() without celery
with self.settings(CA_USE_CELERY=False), self.patch("django_ca.tasks.cache_crls") as task_mock:
tasks.run_task(tasks.cache_crls)
- self.assertEqual(task_mock.call_count, 1)
+ assert task_mock.call_count == 1
# finally, run_task() with celery
with self.settings(CA_USE_CELERY=True), self.mute_celery((((), {}), {})):
@@ -117,16 +117,16 @@ def refresh_from_db(self) -> None:
def assertInvalid(self) -> None: # pylint: disable=invalid-name; unittest standard
"""Assert that the challenge validation failed."""
self.refresh_from_db()
- self.assertEqual(self.chall.status, AcmeChallenge.STATUS_INVALID)
- self.assertEqual(self.auth.status, AcmeAuthorization.STATUS_INVALID)
- self.assertEqual(self.order.status, AcmeOrder.STATUS_INVALID)
+ assert self.chall.status == AcmeChallenge.STATUS_INVALID
+ assert self.auth.status == AcmeAuthorization.STATUS_INVALID
+ assert self.order.status == AcmeOrder.STATUS_INVALID
def assertValid(self, order_state: str = AcmeOrder.STATUS_READY) -> None: # pylint: disable=invalid-name
"""Assert that the challenge is valid."""
self.refresh_from_db()
- self.assertEqual(self.chall.status, AcmeChallenge.STATUS_VALID)
- self.assertEqual(self.auth.status, AcmeAuthorization.STATUS_VALID)
- self.assertEqual(self.order.status, order_state)
+ assert self.chall.status == AcmeChallenge.STATUS_VALID
+ assert self.auth.status == AcmeAuthorization.STATUS_VALID
+ assert self.order.status == order_state
@contextmanager
def mock_challenge(
@@ -144,7 +144,7 @@ def test_acme_disabled(self) -> None:
"""Test invoking task when ACME support is not enabled."""
with self.settings(CA_ENABLE_ACME=False), self.assertLogs() as logcm:
tasks.acme_validate_challenge(self.chall.pk)
- self.assertEqual(logcm.output, ["ERROR:django_ca.tasks:ACME is not enabled."])
+ assert logcm.output == ["ERROR:django_ca.tasks:ACME is not enabled."]
def test_unknown_challenge(self) -> None:
"""Test invoking task with an unknown challenge."""
@@ -152,7 +152,7 @@ def test_unknown_challenge(self) -> None:
with self.assertLogs() as logcm:
tasks.acme_validate_challenge(self.chall.pk)
- self.assertEqual(logcm.output, [f"ERROR:django_ca.tasks:Challenge with id={self.chall.pk} not found"])
+ assert logcm.output == [f"ERROR:django_ca.tasks:Challenge with id={self.chall.pk} not found"]
def test_status_not_processing(self) -> None:
"""Test invoking task where the status is not "processing"."""
@@ -162,9 +162,9 @@ def test_status_not_processing(self) -> None:
with self.assertLogs() as logcm:
tasks.acme_validate_challenge(self.chall.pk)
- self.assertEqual(
- logcm.output, [f"ERROR:django_ca.tasks:{self.chall}: pending: Invalid state (must be processing)"]
- )
+ assert logcm.output == [
+ f"ERROR:django_ca.tasks:{self.chall}: pending: Invalid state (must be processing)"
+ ]
def test_unusable_auth(self) -> None:
"""Test invoking task with an unusable authentication."""
@@ -174,7 +174,7 @@ def test_unusable_auth(self) -> None:
with self.assertLogs() as logcm:
tasks.acme_validate_challenge(self.chall.pk)
- self.assertEqual(logcm.output, [f"ERROR:django_ca.tasks:{self.chall}: Authentication is not usable"])
+ assert logcm.output == [f"ERROR:django_ca.tasks:{self.chall}: Authentication is not usable"]
def test_response_wrong_content(self) -> None:
"""Test the server returning the wrong content in the response."""
@@ -184,12 +184,9 @@ def test_response_wrong_content(self) -> None:
):
tasks.acme_validate_challenge(self.chall.pk)
self.assertInvalid()
- self.assertEqual(
- logcm.output,
- [
- f"INFO:django_ca.tasks:{self.chall!s} is invalid",
- ],
- )
+ assert logcm.output == [
+ f"INFO:django_ca.tasks:{self.chall!s} is invalid",
+ ]
def test_unsupported_challenge(self) -> None:
"""Test what happens when challenge type is not supported."""
@@ -202,13 +199,10 @@ def test_unsupported_challenge(self) -> None:
):
tasks.acme_validate_challenge(self.chall.pk)
self.assertInvalid()
- self.assertEqual(
- logcm.output,
- [
- f"ERROR:django_ca.tasks:{self.chall!s}: Challenge type is not supported.",
- f"INFO:django_ca.tasks:{self.chall!s} is invalid",
- ],
- )
+ assert logcm.output == [
+ f"ERROR:django_ca.tasks:{self.chall!s}: Challenge type is not supported.",
+ f"INFO:django_ca.tasks:{self.chall!s} is invalid",
+ ]
def test_basic(self) -> None:
"""Test validation actually working."""
@@ -270,7 +264,7 @@ def mock_challenge(
matcher = req_mock.get(url, raw=HTTPResponse(body=content, status=status, preload_content=False))
yield req_mock
- self.assertEqual(matcher.call_count, call_count)
+ assert matcher.call_count == call_count
def test_response_not_ok(self) -> None:
"""Test the server not returning a HTTP status code 200."""
@@ -284,10 +278,10 @@ def test_request_exception(self) -> None:
with self.patch("requests.get", side_effect=Exception(val)) as req_mock, self.assertLogs() as logcm:
tasks.acme_validate_challenge(self.chall.pk)
self.assertInvalid()
- self.assertEqual(req_mock.mock_calls, [((self.url,), {"timeout": 1, "stream": True})])
- self.assertEqual(len(logcm.output), 2)
- self.assertIn(val, logcm.output[0])
- self.assertEqual(logcm.output[1], f"INFO:django_ca.tasks:{self.chall!s} is invalid")
+ assert req_mock.mock_calls == [((self.url,), {"timeout": 1, "stream": True})]
+ assert len(logcm.output) == 2
+ assert val in logcm.output[0]
+ assert logcm.output[1] == f"INFO:django_ca.tasks:{self.chall!s} is invalid"
@freeze_time(TIMESTAMPS["everything_valid"])
@@ -332,7 +326,7 @@ def mock_challenge(
# Note: Only assert the first two parameters, as otherwise we'd test dnspython internals
resolve_cm.assert_called_once()
expected = (f"_acme_challenge.{domain}", "TXT")
- self.assertEqual(resolve_cm.call_args_list[0].args[:2], expected)
+ assert resolve_cm.call_args_list[0].args[:2] == expected
def test_nxdomain(self) -> None:
"""Test a ACME validation where the domain does not exist."""
@@ -348,14 +342,11 @@ def test_nxdomain(self) -> None:
exp = self.chall.expected.decode("ascii")
acme_domain = f"_acme_challenge.{domain}"
logger = "django_ca.acme.validation"
- self.assertEqual(
- logcm.output,
- [
- f"INFO:{logger}:DNS-01 validation of {domain}: Expect {exp} on {acme_domain}",
- f"DEBUG:{logger}:TXT {acme_domain}: record does not exist.",
- f"INFO:django_ca.tasks:{self.chall!s} is invalid",
- ],
- )
+ assert logcm.output == [
+ f"INFO:{logger}:DNS-01 validation of {domain}: Expect {exp} on {acme_domain}",
+ f"DEBUG:{logger}:TXT {acme_domain}: record does not exist.",
+ f"INFO:django_ca.tasks:{self.chall!s} is invalid",
+ ]
@freeze_time(TIMESTAMPS["everything_valid"])
@@ -385,7 +376,7 @@ def test_acme_disabled(self) -> None:
"""Test invoking task when ACME support is not enabled."""
with self.settings(CA_ENABLE_ACME=False), self.assertLogs() as logcm:
tasks.acme_issue_certificate(self.acme_cert.pk)
- self.assertEqual(logcm.output, ["ERROR:django_ca.tasks:ACME is not enabled."])
+ assert logcm.output == ["ERROR:django_ca.tasks:ACME is not enabled."]
def test_unknown_certificate(self) -> None:
"""Test invoking task with an unknown cert."""
@@ -393,9 +384,7 @@ def test_unknown_certificate(self) -> None:
with self.assertLogs() as logcm:
tasks.acme_issue_certificate(self.acme_cert.pk)
- self.assertEqual(
- logcm.output, [f"ERROR:django_ca.tasks:Certificate with id={self.acme_cert.pk} not found"]
- )
+ assert logcm.output == [f"ERROR:django_ca.tasks:Certificate with id={self.acme_cert.pk} not found"]
def test_unusable_cert(self) -> None:
"""Test invoking task where the order is not usable."""
@@ -405,9 +394,9 @@ def test_unusable_cert(self) -> None:
with self.assertLogs() as logcm:
tasks.acme_issue_certificate(self.acme_cert.pk)
- self.assertEqual(
- logcm.output, [f"ERROR:django_ca.tasks:{self.order}: Cannot issue certificate for this order"]
- )
+ assert logcm.output == [
+ f"ERROR:django_ca.tasks:{self.order}: Cannot issue certificate for this order"
+ ]
@override_tmpcadir()
def test_basic(self) -> None:
@@ -415,22 +404,20 @@ def test_basic(self) -> None:
with self.assertLogs() as logcm:
tasks.acme_issue_certificate(self.acme_cert.pk)
- self.assertEqual(
- logcm.output, [f"INFO:django_ca.tasks:{self.order}: Issuing certificate for dns:{self.hostname}"]
- )
+ assert logcm.output == [
+ f"INFO:django_ca.tasks:{self.order}: Issuing certificate for dns:{self.hostname}"
+ ]
+
self.acme_cert.refresh_from_db()
assert self.acme_cert.cert is not None, "Check to make mypy happy"
self.order.refresh_from_db()
- self.assertEqual(self.order.status, AcmeOrder.STATUS_VALID)
- self.assertEqual(
- self.acme_cert.cert.extensions[ExtensionOID.SUBJECT_ALTERNATIVE_NAME],
- subject_alternative_name(x509.DNSName(self.hostname)),
- )
- self.assertEqual(
- self.acme_cert.cert.not_after, timezone.now() + model_settings.CA_ACME_DEFAULT_CERT_VALIDITY
- )
- self.assertEqual(self.acme_cert.cert.cn, self.hostname)
- self.assertEqual(self.acme_cert.cert.profile, model_settings.CA_DEFAULT_PROFILE)
+ assert self.order.status == AcmeOrder.STATUS_VALID
+ assert self.acme_cert.cert.extensions[
+ ExtensionOID.SUBJECT_ALTERNATIVE_NAME
+ ] == subject_alternative_name(x509.DNSName(self.hostname))
+ assert self.acme_cert.cert.not_after == timezone.now() + model_settings.CA_ACME_DEFAULT_CERT_VALIDITY
+ assert self.acme_cert.cert.cn == self.hostname
+ assert self.acme_cert.cert.profile == model_settings.CA_DEFAULT_PROFILE
@override_settings(USE_TZ=False)
def test_basic_without_timezone_support(self) -> None:
@@ -449,15 +436,13 @@ def test_two_hostnames(self) -> None:
self.acme_cert.refresh_from_db()
assert self.acme_cert.cert is not None, "Check to make mypy happy"
self.order.refresh_from_db()
- self.assertEqual(self.order.status, AcmeOrder.STATUS_VALID)
- self.assertEqual(
- self.acme_cert.cert.extensions[ExtensionOID.SUBJECT_ALTERNATIVE_NAME],
- subject_alternative_name(x509.DNSName(self.hostname), x509.DNSName(hostname2)),
- )
- self.assertEqual(
- self.acme_cert.cert.not_after, timezone.now() + model_settings.CA_ACME_DEFAULT_CERT_VALIDITY
- )
- self.assertIn(self.acme_cert.cert.cn, [self.hostname, hostname2])
+ assert self.order.status == AcmeOrder.STATUS_VALID
+ assert self.acme_cert.cert.extensions[
+ ExtensionOID.SUBJECT_ALTERNATIVE_NAME
+ ] == subject_alternative_name(x509.DNSName(self.hostname), x509.DNSName(hostname2))
+
+ assert self.acme_cert.cert.not_after == timezone.now() + model_settings.CA_ACME_DEFAULT_CERT_VALIDITY
+ assert self.acme_cert.cert.cn in [self.hostname, hostname2]
@override_tmpcadir()
def test_not_after(self) -> None:
@@ -469,19 +454,20 @@ def test_not_after(self) -> None:
with self.assertLogs() as logcm:
tasks.acme_issue_certificate(self.acme_cert.pk)
- self.assertEqual(
- logcm.output, [f"INFO:django_ca.tasks:{self.order}: Issuing certificate for dns:{self.hostname}"]
- )
+ assert logcm.output == [
+ f"INFO:django_ca.tasks:{self.order}: Issuing certificate for dns:{self.hostname}"
+ ]
+
self.acme_cert.refresh_from_db()
assert self.acme_cert.cert is not None, "Check to make mypy happy"
self.order.refresh_from_db()
- self.assertEqual(self.order.status, AcmeOrder.STATUS_VALID)
- self.assertEqual(
- self.acme_cert.cert.extensions[ExtensionOID.SUBJECT_ALTERNATIVE_NAME],
- subject_alternative_name(x509.DNSName(self.hostname)),
- )
- self.assertEqual(self.acme_cert.cert.not_after, not_after)
- self.assertEqual(self.acme_cert.cert.cn, self.hostname)
+ assert self.order.status == AcmeOrder.STATUS_VALID
+ assert self.acme_cert.cert.extensions[
+ ExtensionOID.SUBJECT_ALTERNATIVE_NAME
+ ] == subject_alternative_name(x509.DNSName(self.hostname))
+
+ assert self.acme_cert.cert.not_after == not_after
+ assert self.acme_cert.cert.cn == self.hostname
def test_not_after_with_use_tz_is_false(self) -> None:
"""Test not_after with USE_TZ=False."""
@@ -498,22 +484,22 @@ def test_profile(self) -> None:
with self.assertLogs() as logcm:
tasks.acme_issue_certificate(self.acme_cert.pk)
- self.assertEqual(
- logcm.output, [f"INFO:django_ca.tasks:{self.order}: Issuing certificate for dns:{self.hostname}"]
- )
+ assert logcm.output == [
+ f"INFO:django_ca.tasks:{self.order}: Issuing certificate for dns:{self.hostname}"
+ ]
+
self.acme_cert.refresh_from_db()
assert self.acme_cert.cert is not None, "Check to make mypy happy"
self.order.refresh_from_db()
- self.assertEqual(self.order.status, AcmeOrder.STATUS_VALID)
- self.assertEqual(
- self.acme_cert.cert.extensions[ExtensionOID.SUBJECT_ALTERNATIVE_NAME],
- subject_alternative_name(x509.DNSName(self.hostname)),
- )
- self.assertEqual(
- self.acme_cert.cert.not_after, timezone.now() + model_settings.CA_ACME_DEFAULT_CERT_VALIDITY
- )
- self.assertEqual(self.acme_cert.cert.cn, self.hostname)
- self.assertEqual(self.acme_cert.cert.profile, "client")
+ assert self.order.status == AcmeOrder.STATUS_VALID
+ assert self.acme_cert.cert.extensions[
+ ExtensionOID.SUBJECT_ALTERNATIVE_NAME
+ ] == subject_alternative_name(x509.DNSName(self.hostname))
+
+ assert self.acme_cert.cert.not_after == timezone.now() + model_settings.CA_ACME_DEFAULT_CERT_VALIDITY
+
+ assert self.acme_cert.cert.cn == self.hostname
+ assert self.acme_cert.cert.profile == "client"
@freeze_time(TIMESTAMPS["everything_valid"])
@@ -546,27 +532,27 @@ def test_basic(self) -> None:
"""Basic test."""
tasks.acme_cleanup() # does nothing if nothing is expired
- self.assertEqual(self.acme_cert, AcmeCertificate.objects.get(pk=self.acme_cert.pk))
- self.assertEqual(self.order, AcmeOrder.objects.get(pk=self.order.pk))
- self.assertEqual(self.auth, AcmeAuthorization.objects.get(pk=self.auth.pk))
- self.assertEqual(self.account, AcmeAccount.objects.get(pk=self.account.pk))
+ assert self.acme_cert == AcmeCertificate.objects.get(pk=self.acme_cert.pk)
+ assert self.order == AcmeOrder.objects.get(pk=self.order.pk)
+ assert self.auth == AcmeAuthorization.objects.get(pk=self.auth.pk)
+ assert self.account == AcmeAccount.objects.get(pk=self.account.pk)
with self.freeze_time(timezone.now() + timedelta(days=3)):
tasks.acme_cleanup()
- self.assertEqual(AcmeOrder.objects.all().count(), 0)
- self.assertEqual(AcmeAuthorization.objects.all().count(), 0)
- self.assertEqual(AcmeChallenge.objects.all().count(), 0)
- self.assertEqual(AcmeCertificate.objects.all().count(), 0)
+ assert AcmeOrder.objects.all().count() == 0
+ assert AcmeAuthorization.objects.all().count() == 0
+ assert AcmeChallenge.objects.all().count() == 0
+ assert AcmeCertificate.objects.all().count() == 0
def test_acme_disabled(self) -> None:
"""Test task when ACME is disabled."""
with self.settings(CA_ENABLE_ACME=False), self.assertLogs() as logcm:
with self.freeze_time(timezone.now() + timedelta(days=3)):
tasks.acme_cleanup()
- self.assertEqual(logcm.output, ["INFO:django_ca.tasks:ACME is not enabled, not doing anything."])
+ assert logcm.output == ["INFO:django_ca.tasks:ACME is not enabled, not doing anything."]
- self.assertEqual(AcmeOrder.objects.all().count(), 1)
- self.assertEqual(AcmeAuthorization.objects.all().count(), 1)
- self.assertEqual(AcmeChallenge.objects.all().count(), 1)
- self.assertEqual(AcmeCertificate.objects.all().count(), 1)
+ assert AcmeOrder.objects.all().count() == 1
+ assert AcmeAuthorization.objects.all().count() == 1
+ assert AcmeChallenge.objects.all().count() == 1
+ assert AcmeCertificate.objects.all().count() == 1
diff --git a/ca/django_ca/tests/test_typehints.py b/ca/django_ca/tests/test_typehints.py
index 43cc16c88..d84c906e3 100644
--- a/ca/django_ca/tests/test_typehints.py
+++ b/ca/django_ca/tests/test_typehints.py
@@ -37,7 +37,7 @@ def test_end_entity_certificate_extension_keys() -> None:
@pytest.mark.parametrize(
- "extension_types,extensions",
+ ("extension_types", "extensions"),
(
(typehints.ConfigurableExtensionType, typehints.ConfigurableExtension),
(typehints.EndEntityCertificateExtensionType, typehints.EndEntityCertificateExtension),
diff --git a/ca/django_ca/tests/test_utils.py b/ca/django_ca/tests/test_utils.py
index dd944113d..632cb32f7 100644
--- a/ca/django_ca/tests/test_utils.py
+++ b/ca/django_ca/tests/test_utils.py
@@ -78,8 +78,8 @@ def test_read_file(tmpcadir: Path) -> None:
@pytest.mark.parametrize(
- "attributes,expected",
- [
+ ("attributes", "expected"),
+ (
([(NameOID.COMMON_NAME, "example.com")], [cn("example.com")]),
(
[(NameOID.COUNTRY_NAME, "AT"), (NameOID.COMMON_NAME, "example.com")],
@@ -89,7 +89,7 @@ def test_read_file(tmpcadir: Path) -> None:
[(NameOID.X500_UNIQUE_IDENTIFIER, "65:78:61:6D:70:6C:65")],
[x509.NameAttribute(NameOID.X500_UNIQUE_IDENTIFIER, b"example", _type=_ASN1Type.BitString)],
),
- ],
+ ),
)
def test_parse_serialized_name_attributes(
attributes: list[tuple[x509.ObjectIdentifier, str]], expected: list[x509.NameAttribute]
@@ -107,26 +107,26 @@ class GeneratePrivateKeyTestCase(TestCase):
def test_key_types(self) -> None:
"""Test generating various private key types."""
ec_key = generate_private_key(None, "EC", ec.BrainpoolP256R1())
- self.assertIsInstance(ec_key, ec.EllipticCurvePrivateKey)
- self.assertIsInstance(ec_key.curve, ec.BrainpoolP256R1)
+ assert isinstance(ec_key, ec.EllipticCurvePrivateKey)
+ assert isinstance(ec_key.curve, ec.BrainpoolP256R1)
ed448_key = generate_private_key(None, "Ed448", None)
- self.assertIsInstance(ed448_key, ed448.Ed448PrivateKey)
+ assert isinstance(ed448_key, ed448.Ed448PrivateKey)
def test_dsa_default_key_size(self) -> None:
"""Test the default DSA key size."""
key = generate_private_key(None, "DSA", None)
- self.assertIsInstance(key, dsa.DSAPrivateKey)
- self.assertEqual(key.key_size, model_settings.CA_DEFAULT_KEY_SIZE)
+ assert isinstance(key, dsa.DSAPrivateKey)
+ assert key.key_size == model_settings.CA_DEFAULT_KEY_SIZE
def test_invalid_type(self) -> None:
"""Test passing an invalid key type."""
- with self.assertRaisesRegex(ValueError, r"^FOO: Unknown key type\.$"):
+ with pytest.raises(ValueError, match=r"^FOO: Unknown key type\.$"):
generate_private_key(16, "FOO", None) # type: ignore[call-overload]
@pytest.mark.parametrize(
- "general_name,expected",
+ ("general_name", "expected"),
(
(dns("example.com"), "DNS:example.com"),
(x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")), "IP:127.0.0.1"),
@@ -147,14 +147,11 @@ class SerializeName(TestCase):
def test_name(self) -> None:
"""Test passing a standard Name."""
- self.assertEqual(
- serialize_name(x509.Name([cn("example.com")])),
- [{"oid": "2.5.4.3", "value": "example.com"}],
- )
- self.assertEqual(
- serialize_name(x509.Name([country("AT"), cn("example.com")])),
- [{"oid": "2.5.4.6", "value": "AT"}, {"oid": "2.5.4.3", "value": "example.com"}],
- )
+ assert serialize_name(x509.Name([cn("example.com")])) == [{"oid": "2.5.4.3", "value": "example.com"}]
+ assert serialize_name(x509.Name([country("AT"), cn("example.com")])) == [
+ {"oid": "2.5.4.6", "value": "AT"},
+ {"oid": "2.5.4.3", "value": "example.com"},
+ ]
@unittest.skipIf(CRYPTOGRAPHY_VERSION < (37, 0), "cg<36 does not yet have bytes.")
def test_bytes(self) -> None:
@@ -162,19 +159,12 @@ def test_bytes(self) -> None:
name = x509.Name(
[x509.NameAttribute(NameOID.X500_UNIQUE_IDENTIFIER, b"example.com", _type=_ASN1Type.BitString)]
)
- self.assertEqual(
- serialize_name(name), [{"oid": "2.5.4.45", "value": "65:78:61:6D:70:6C:65:2E:63:6F:6D"}]
- )
+ assert serialize_name(name) == [{"oid": "2.5.4.45", "value": "65:78:61:6D:70:6C:65:2E:63:6F:6D"}]
@pytest.mark.parametrize(
- "value,expected",
- (
- ("PEM", Encoding.PEM),
- ("DER", Encoding.DER),
- ("ASN1", Encoding.DER),
- ("OpenSSH", Encoding.OpenSSH),
- ),
+ ("value", "expected"),
+ (("PEM", Encoding.PEM), ("DER", Encoding.DER), ("ASN1", Encoding.DER), ("OpenSSH", Encoding.OpenSSH)),
)
def test_parse_encoding(value: Any, expected: Encoding) -> None:
"""Test :py:func:`django_ca.utils.parse_encoding`."""
@@ -192,36 +182,36 @@ class AddColonsTestCase(TestCase):
def test_basic(self) -> None:
"""Some basic tests."""
- self.assertEqual(utils.add_colons(""), "")
- self.assertEqual(utils.add_colons("a"), "0a")
- self.assertEqual(utils.add_colons("ab"), "ab")
- self.assertEqual(utils.add_colons("abc"), "0a:bc")
- self.assertEqual(utils.add_colons("abcd"), "ab:cd")
- self.assertEqual(utils.add_colons("abcde"), "0a:bc:de")
- self.assertEqual(utils.add_colons("abcdef"), "ab:cd:ef")
- self.assertEqual(utils.add_colons("abcdefg"), "0a:bc:de:fg")
+ assert utils.add_colons("") == ""
+ assert utils.add_colons("a") == "0a"
+ assert utils.add_colons("ab") == "ab"
+ assert utils.add_colons("abc") == "0a:bc"
+ assert utils.add_colons("abcd") == "ab:cd"
+ assert utils.add_colons("abcde") == "0a:bc:de"
+ assert utils.add_colons("abcdef") == "ab:cd:ef"
+ assert utils.add_colons("abcdefg") == "0a:bc:de:fg"
def test_pad(self) -> None:
"""Test padding."""
- self.assertEqual(utils.add_colons("a", pad="z"), "za")
- self.assertEqual(utils.add_colons("ab", pad="z"), "ab")
- self.assertEqual(utils.add_colons("abc", pad="z"), "za:bc")
+ assert utils.add_colons("a", pad="z") == "za"
+ assert utils.add_colons("ab", pad="z") == "ab"
+ assert utils.add_colons("abc", pad="z") == "za:bc"
def test_no_pad(self) -> None:
"""Test disabling padding."""
- self.assertEqual(utils.add_colons("a", pad=""), "a")
- self.assertEqual(utils.add_colons("ab", pad=""), "ab")
- self.assertEqual(utils.add_colons("abc", pad=""), "ab:c")
+ assert utils.add_colons("a", pad="") == "a"
+ assert utils.add_colons("ab", pad="") == "ab"
+ assert utils.add_colons("abc", pad="") == "ab:c"
def test_zero_padding(self) -> None:
"""Test when there is no padding."""
- self.assertEqual(
- utils.add_colons("F570A555BC5000FA301E8C75FFB31684FCF64436"),
- "F5:70:A5:55:BC:50:00:FA:30:1E:8C:75:FF:B3:16:84:FC:F6:44:36",
+ assert (
+ utils.add_colons("F570A555BC5000FA301E8C75FFB31684FCF64436")
+ == "F5:70:A5:55:BC:50:00:FA:30:1E:8C:75:FF:B3:16:84:FC:F6:44:36"
)
- self.assertEqual(
- utils.add_colons("85BDA79A857379A4C9E910DAEA21C896D16394"),
- "85:BD:A7:9A:85:73:79:A4:C9:E9:10:DA:EA:21:C8:96:D1:63:94",
+ assert (
+ utils.add_colons("85BDA79A857379A4C9E910DAEA21C896D16394")
+ == "85:BD:A7:9A:85:73:79:A4:C9:E9:10:DA:EA:21:C8:96:D1:63:94"
)
@@ -230,75 +220,75 @@ class IntToHexTestCase(TestCase):
def test_basic(self) -> None:
"""Test the first view numbers."""
- self.assertEqual(utils.int_to_hex(0), "0")
- self.assertEqual(utils.int_to_hex(1), "1")
- self.assertEqual(utils.int_to_hex(2), "2")
- self.assertEqual(utils.int_to_hex(3), "3")
- self.assertEqual(utils.int_to_hex(4), "4")
- self.assertEqual(utils.int_to_hex(5), "5")
- self.assertEqual(utils.int_to_hex(6), "6")
- self.assertEqual(utils.int_to_hex(7), "7")
- self.assertEqual(utils.int_to_hex(8), "8")
- self.assertEqual(utils.int_to_hex(9), "9")
- self.assertEqual(utils.int_to_hex(10), "A")
- self.assertEqual(utils.int_to_hex(11), "B")
- self.assertEqual(utils.int_to_hex(12), "C")
- self.assertEqual(utils.int_to_hex(13), "D")
- self.assertEqual(utils.int_to_hex(14), "E")
- self.assertEqual(utils.int_to_hex(15), "F")
- self.assertEqual(utils.int_to_hex(16), "10")
- self.assertEqual(utils.int_to_hex(17), "11")
- self.assertEqual(utils.int_to_hex(18), "12")
- self.assertEqual(utils.int_to_hex(19), "13")
- self.assertEqual(utils.int_to_hex(20), "14")
- self.assertEqual(utils.int_to_hex(21), "15")
- self.assertEqual(utils.int_to_hex(22), "16")
- self.assertEqual(utils.int_to_hex(23), "17")
- self.assertEqual(utils.int_to_hex(24), "18")
- self.assertEqual(utils.int_to_hex(25), "19")
- self.assertEqual(utils.int_to_hex(26), "1A")
- self.assertEqual(utils.int_to_hex(27), "1B")
- self.assertEqual(utils.int_to_hex(28), "1C")
- self.assertEqual(utils.int_to_hex(29), "1D")
- self.assertEqual(utils.int_to_hex(30), "1E")
- self.assertEqual(utils.int_to_hex(31), "1F")
- self.assertEqual(utils.int_to_hex(32), "20")
- self.assertEqual(utils.int_to_hex(33), "21")
- self.assertEqual(utils.int_to_hex(34), "22")
- self.assertEqual(utils.int_to_hex(35), "23")
- self.assertEqual(utils.int_to_hex(36), "24")
- self.assertEqual(utils.int_to_hex(37), "25")
- self.assertEqual(utils.int_to_hex(38), "26")
- self.assertEqual(utils.int_to_hex(39), "27")
- self.assertEqual(utils.int_to_hex(40), "28")
- self.assertEqual(utils.int_to_hex(41), "29")
- self.assertEqual(utils.int_to_hex(42), "2A")
- self.assertEqual(utils.int_to_hex(43), "2B")
- self.assertEqual(utils.int_to_hex(44), "2C")
- self.assertEqual(utils.int_to_hex(45), "2D")
- self.assertEqual(utils.int_to_hex(46), "2E")
- self.assertEqual(utils.int_to_hex(47), "2F")
- self.assertEqual(utils.int_to_hex(48), "30")
- self.assertEqual(utils.int_to_hex(49), "31")
+ assert utils.int_to_hex(0) == "0"
+ assert utils.int_to_hex(1) == "1"
+ assert utils.int_to_hex(2) == "2"
+ assert utils.int_to_hex(3) == "3"
+ assert utils.int_to_hex(4) == "4"
+ assert utils.int_to_hex(5) == "5"
+ assert utils.int_to_hex(6) == "6"
+ assert utils.int_to_hex(7) == "7"
+ assert utils.int_to_hex(8) == "8"
+ assert utils.int_to_hex(9) == "9"
+ assert utils.int_to_hex(10) == "A"
+ assert utils.int_to_hex(11) == "B"
+ assert utils.int_to_hex(12) == "C"
+ assert utils.int_to_hex(13) == "D"
+ assert utils.int_to_hex(14) == "E"
+ assert utils.int_to_hex(15) == "F"
+ assert utils.int_to_hex(16) == "10"
+ assert utils.int_to_hex(17) == "11"
+ assert utils.int_to_hex(18) == "12"
+ assert utils.int_to_hex(19) == "13"
+ assert utils.int_to_hex(20) == "14"
+ assert utils.int_to_hex(21) == "15"
+ assert utils.int_to_hex(22) == "16"
+ assert utils.int_to_hex(23) == "17"
+ assert utils.int_to_hex(24) == "18"
+ assert utils.int_to_hex(25) == "19"
+ assert utils.int_to_hex(26) == "1A"
+ assert utils.int_to_hex(27) == "1B"
+ assert utils.int_to_hex(28) == "1C"
+ assert utils.int_to_hex(29) == "1D"
+ assert utils.int_to_hex(30) == "1E"
+ assert utils.int_to_hex(31) == "1F"
+ assert utils.int_to_hex(32) == "20"
+ assert utils.int_to_hex(33) == "21"
+ assert utils.int_to_hex(34) == "22"
+ assert utils.int_to_hex(35) == "23"
+ assert utils.int_to_hex(36) == "24"
+ assert utils.int_to_hex(37) == "25"
+ assert utils.int_to_hex(38) == "26"
+ assert utils.int_to_hex(39) == "27"
+ assert utils.int_to_hex(40) == "28"
+ assert utils.int_to_hex(41) == "29"
+ assert utils.int_to_hex(42) == "2A"
+ assert utils.int_to_hex(43) == "2B"
+ assert utils.int_to_hex(44) == "2C"
+ assert utils.int_to_hex(45) == "2D"
+ assert utils.int_to_hex(46) == "2E"
+ assert utils.int_to_hex(47) == "2F"
+ assert utils.int_to_hex(48) == "30"
+ assert utils.int_to_hex(49) == "31"
def test_high(self) -> None:
"""Test some high numbers."""
- self.assertEqual(utils.int_to_hex(1513282098), "5A32DA32")
- self.assertEqual(utils.int_to_hex(1513282099), "5A32DA33")
- self.assertEqual(utils.int_to_hex(1513282100), "5A32DA34")
- self.assertEqual(utils.int_to_hex(1513282101), "5A32DA35")
- self.assertEqual(utils.int_to_hex(1513282102), "5A32DA36")
- self.assertEqual(utils.int_to_hex(1513282103), "5A32DA37")
- self.assertEqual(utils.int_to_hex(1513282104), "5A32DA38")
- self.assertEqual(utils.int_to_hex(1513282105), "5A32DA39")
- self.assertEqual(utils.int_to_hex(1513282106), "5A32DA3A")
- self.assertEqual(utils.int_to_hex(1513282107), "5A32DA3B")
- self.assertEqual(utils.int_to_hex(1513282108), "5A32DA3C")
- self.assertEqual(utils.int_to_hex(1513282109), "5A32DA3D")
- self.assertEqual(utils.int_to_hex(1513282110), "5A32DA3E")
- self.assertEqual(utils.int_to_hex(1513282111), "5A32DA3F")
- self.assertEqual(utils.int_to_hex(1513282112), "5A32DA40")
- self.assertEqual(utils.int_to_hex(1513282113), "5A32DA41")
+ assert utils.int_to_hex(1513282098) == "5A32DA32"
+ assert utils.int_to_hex(1513282099) == "5A32DA33"
+ assert utils.int_to_hex(1513282100) == "5A32DA34"
+ assert utils.int_to_hex(1513282101) == "5A32DA35"
+ assert utils.int_to_hex(1513282102) == "5A32DA36"
+ assert utils.int_to_hex(1513282103) == "5A32DA37"
+ assert utils.int_to_hex(1513282104) == "5A32DA38"
+ assert utils.int_to_hex(1513282105) == "5A32DA39"
+ assert utils.int_to_hex(1513282106) == "5A32DA3A"
+ assert utils.int_to_hex(1513282107) == "5A32DA3B"
+ assert utils.int_to_hex(1513282108) == "5A32DA3C"
+ assert utils.int_to_hex(1513282109) == "5A32DA3D"
+ assert utils.int_to_hex(1513282110) == "5A32DA3E"
+ assert utils.int_to_hex(1513282111) == "5A32DA3F"
+ assert utils.int_to_hex(1513282112) == "5A32DA40"
+ assert utils.int_to_hex(1513282113) == "5A32DA41"
class BytesToHexTestCase(TestCase):
@@ -306,11 +296,11 @@ class BytesToHexTestCase(TestCase):
def test_basic(self) -> None:
"""Some basic test cases."""
- self.assertEqual(bytes_to_hex(b"test"), "74:65:73:74")
- self.assertEqual(bytes_to_hex(b"foo"), "66:6F:6F")
- self.assertEqual(bytes_to_hex(b"bar"), "62:61:72")
- self.assertEqual(bytes_to_hex(b""), "")
- self.assertEqual(bytes_to_hex(b"a"), "61")
+ assert bytes_to_hex(b"test") == "74:65:73:74"
+ assert bytes_to_hex(b"foo") == "66:6F:6F"
+ assert bytes_to_hex(b"bar") == "62:61:72"
+ assert bytes_to_hex(b"") == ""
+ assert bytes_to_hex(b"a") == "61"
class SanitizeSerialTestCase(TestCase):
@@ -318,23 +308,23 @@ class SanitizeSerialTestCase(TestCase):
def test_already_sanitized(self) -> None:
"""Test some already sanitized input."""
- self.assertEqual(utils.sanitize_serial("A"), "A")
- self.assertEqual(utils.sanitize_serial("5A32DA3B"), "5A32DA3B")
- self.assertEqual(utils.sanitize_serial("1234567890ABCDEF"), "1234567890ABCDEF")
+ assert utils.sanitize_serial("A") == "A"
+ assert utils.sanitize_serial("5A32DA3B") == "5A32DA3B"
+ assert utils.sanitize_serial("1234567890ABCDEF") == "1234567890ABCDEF"
def test_sanitized(self) -> None:
"""Test some input that can be correctly sanitized."""
- self.assertEqual(utils.sanitize_serial("5A:32:DA:3B"), "5A32DA3B")
- self.assertEqual(utils.sanitize_serial("0A:32:DA:3B"), "A32DA3B")
- self.assertEqual(utils.sanitize_serial("0a:32:da:3b"), "A32DA3B")
+ assert utils.sanitize_serial("5A:32:DA:3B") == "5A32DA3B"
+ assert utils.sanitize_serial("0A:32:DA:3B") == "A32DA3B"
+ assert utils.sanitize_serial("0a:32:da:3b") == "A32DA3B"
def test_zero(self) -> None:
"""An imported CA might have a serial of just a ``0``, so it must not be stripped."""
- self.assertEqual(utils.sanitize_serial("0"), "0")
+ assert utils.sanitize_serial("0") == "0"
def test_invalid_input(self) -> None:
"""Test some input that raises an exception."""
- with self.assertRaisesRegex(ValueError, r"^ABCXY: Serial has invalid characters$"):
+ with pytest.raises(ValueError, match=r"^ABCXY: Serial has invalid characters$"):
utils.sanitize_serial("ABCXY")
@@ -364,13 +354,13 @@ def test_str(self) -> None:
("CN", "example.com"),
("emailAddress", "user@example.com"),
]
- self.assertEqual(x509_name(subject), self.name)
+ assert x509_name(subject) == self.name
def test_multiple_other(self) -> None:
"""Test multiple other tokens (only OUs work)."""
- with self.assertRaisesRegex(ValueError, '^Subject contains multiple "countryName" fields$'):
+ with pytest.raises(ValueError, match='^Subject contains multiple "countryName" fields$'):
x509_name([("C", "AT"), ("C", "DE")])
- with self.assertRaisesRegex(ValueError, '^Subject contains multiple "commonName" fields$'):
+ with pytest.raises(ValueError, match='^Subject contains multiple "commonName" fields$'):
x509_name([("CN", "AT"), ("CN", "FOO")])
@@ -400,7 +390,7 @@ def assertMerged( # pylint: disable=invalid-name # unittest standard
base_name = x509.Name(base)
update_name = x509.Name(update)
merged_name = x509.Name(merged)
- self.assertEqual(merge_x509_names(base_name, update_name), merged_name)
+ assert merge_x509_names(base_name, update_name) == merged_name
def test_full_merge(self) -> None:
"""Test a basic merge."""
@@ -442,9 +432,9 @@ def test_unsortable_values(self) -> None:
"""Test merging unsortable values."""
sortable = x509.Name([self.cc1, self.common_name1])
unsortable = x509.Name([self.cc1, x509.NameAttribute(NameOID.INN, "unsortable")])
- with self.assertRaisesRegex(ValueError, r"Unsortable name"):
+ with pytest.raises(ValueError, match=r"Unsortable name"):
merge_x509_names(unsortable, sortable)
- with self.assertRaisesRegex(ValueError, r"Unsortable name"):
+ with pytest.raises(ValueError, match=r"Unsortable name"):
merge_x509_names(sortable, unsortable)
@@ -462,47 +452,43 @@ def test_basic(self) -> None:
# pylint: disable=protected-access; only way to test builder attributes
after = datetime(2020, 10, 23, 11, 21, tzinfo=tz.utc)
builder = get_cert_builder(after)
- self.assertEqual(builder._not_valid_before, datetime(2018, 11, 3, 11, 21))
- self.assertEqual(builder._not_valid_after, datetime(2020, 10, 23, 11, 21))
- self.assertIsInstance(builder._serial_number, int)
+ assert builder._not_valid_before == datetime(2018, 11, 3, 11, 21)
+ assert builder._not_valid_after == datetime(2020, 10, 23, 11, 21)
+ assert isinstance(builder._serial_number, int)
@freeze_time("2021-01-23 14:42:11.1234")
def test_datetime(self) -> None:
"""Basic tests."""
expires = datetime.now(tz.utc) + timedelta(days=10)
- self.assertNotEqual(expires.second, 0)
- self.assertNotEqual(expires.microsecond, 0)
+ assert expires.second != 0
+ assert expires.microsecond != 0
expires_expected = datetime(2021, 2, 2, 14, 42)
builder = get_cert_builder(expires)
- self.assertEqual(builder._not_valid_after, expires_expected) # pylint: disable=protected-access
- self.assertIsInstance(builder._serial_number, int) # pylint: disable=protected-access
+ assert builder._not_valid_after == expires_expected # pylint: disable=protected-access
+ assert isinstance(builder._serial_number, int) # pylint: disable=protected-access
@freeze_time("2021-01-23 14:42:11.1234")
def test_serial(self) -> None:
"""Test manually setting a serial."""
after = datetime(2022, 10, 23, 11, 21, tzinfo=tz.utc)
builder = get_cert_builder(after, serial=123)
- self.assertEqual(builder._serial_number, 123) # pylint: disable=protected-access
- self.assertEqual(
- builder._not_valid_after, # pylint: disable=protected-access
- datetime(2022, 10, 23, 11, 21),
- )
+ assert builder._serial_number == 123 # pylint: disable=protected-access
+ assert builder._not_valid_after == datetime(2022, 10, 23, 11, 21) # pylint: disable=protected-access
@freeze_time("2021-01-23 14:42:11")
def test_negative_datetime(self) -> None:
"""Test passing a datetime in the past."""
- msg = r"^not_after must be in the future$"
- with self.assertRaisesRegex(ValueError, msg):
+ with pytest.raises(ValueError, match=r"^not_after must be in the future$"):
get_cert_builder(datetime.now(tz.utc) - timedelta(seconds=60))
def test_invalid_type(self) -> None:
"""Test passing an invalid type."""
- with self.assertRaises(AttributeError):
+ with pytest.raises(AttributeError):
get_cert_builder("a string") # type: ignore[arg-type]
def test_naive_datetime(self) -> None:
"""Test passing a naive datetime."""
- with self.assertRaisesRegex(ValueError, r"^not_after must not be a naive datetime$"):
+ with pytest.raises(ValueError, match=r"^not_after must not be a naive datetime$"):
get_cert_builder(datetime.now())
@@ -530,40 +516,40 @@ def test_default_parameters(self) -> None:
def test_valid_parameters(self) -> None:
"""Test valid parameters."""
- self.assertEqual((8192, None), validate_private_key_parameters("RSA", 8192, None))
- self.assertEqual((8192, None), validate_private_key_parameters("DSA", 8192, None))
+ assert validate_private_key_parameters("RSA", 8192, None) == (8192, None)
+ assert validate_private_key_parameters("DSA", 8192, None) == (8192, None)
key_size, elliptic_curve = validate_private_key_parameters("EC", None, ec.BrainpoolP384R1())
- self.assertIsNone(key_size)
- self.assertIsInstance(elliptic_curve, ec.BrainpoolP384R1)
+ assert key_size is None
+ assert isinstance(elliptic_curve, ec.BrainpoolP384R1)
def test_wrong_values(self) -> None:
"""Test validating various bogus values."""
key_size = model_settings.CA_DEFAULT_KEY_SIZE
elliptic_curve = model_settings.CA_DEFAULT_ELLIPTIC_CURVE
- with self.assertRaisesRegex(ValueError, "^FOOBAR: Unknown key type$"):
+ with pytest.raises(ValueError, match="^FOOBAR: Unknown key type$"):
validate_private_key_parameters("FOOBAR", 4096, None) # type: ignore[call-overload]
- with self.assertRaisesRegex(ValueError, r"^foo: Key size must be an int\.$"):
+ with pytest.raises(ValueError, match=r"^foo: Key size must be an int\.$"):
validate_private_key_parameters("RSA", "foo", None) # type: ignore[call-overload]
- with self.assertRaisesRegex(ValueError, "^4000: Key size must be a power of two$"):
+ with pytest.raises(ValueError, match="^4000: Key size must be a power of two$"):
validate_private_key_parameters("RSA", 4000, None)
- with self.assertRaisesRegex(ValueError, "^16: Key size must be least 1024 bits$"):
+ with pytest.raises(ValueError, match="^16: Key size must be least 1024 bits$"):
validate_private_key_parameters("RSA", 16, None)
- with self.assertRaisesRegex(ValueError, r"^Key size is not supported for EC keys\.$"):
+ with pytest.raises(ValueError, match=r"^Key size is not supported for EC keys\.$"):
validate_private_key_parameters("EC", key_size, elliptic_curve)
- with self.assertRaisesRegex(ValueError, r"^secp192r1: Must be a subclass of ec\.EllipticCurve$"):
+ with pytest.raises(ValueError, match=r"^secp192r1: Must be a subclass of ec\.EllipticCurve$"):
validate_private_key_parameters("EC", None, "secp192r1") # type: ignore
for key_type in ("Ed448", "Ed25519"):
- with self.assertRaisesRegex(ValueError, rf"^Key size is not supported for {key_type} keys\.$"):
+ with pytest.raises(ValueError, match=rf"^Key size is not supported for {key_type} keys\.$"):
validate_private_key_parameters(key_type, key_size, None) # type: ignore
- with self.assertRaisesRegex(
- ValueError, rf"^Elliptic curves are not supported for {key_type} keys\.$"
+ with pytest.raises(
+ ValueError, match=rf"^Elliptic curves are not supported for {key_type} keys\.$"
):
validate_private_key_parameters(key_type, None, elliptic_curve) # type: ignore
@@ -581,14 +567,14 @@ def test_valid_parameters(self) -> None:
def test_invalid_parameters(self) -> None:
"""Test invalid parameters."""
- with self.assertRaisesRegex(ValueError, "^FOOBAR: Unknown key type$"):
+ with pytest.raises(ValueError, match="^FOOBAR: Unknown key type$"):
validate_public_key_parameters("FOOBAR", None) # type: ignore[arg-type]
for key_type in ("RSA", "DSA", "EC"):
msg = rf"^{key_type}: algorithm must be an instance of hashes.HashAlgorithm\.$"
- with self.assertRaisesRegex(ValueError, msg):
+ with pytest.raises(ValueError, match=msg):
validate_public_key_parameters(key_type, True) # type: ignore[arg-type]
for key_type in ("Ed448", "Ed25519"):
msg = rf"^{key_type} keys do not allow an algorithm for signing\.$"
- with self.assertRaisesRegex(ValueError, msg):
+ with pytest.raises(ValueError, match=msg):
validate_public_key_parameters(key_type, hashes.SHA256()) # type: ignore[arg-type]
diff --git a/ca/django_ca/tests/test_views_ocsp.py b/ca/django_ca/tests/test_views_ocsp.py
index dcf347574..61b8572a3 100644
--- a/ca/django_ca/tests/test_views_ocsp.py
+++ b/ca/django_ca/tests/test_views_ocsp.py
@@ -42,13 +42,7 @@
from django_ca.key_backends.storages import StoragesUsePrivateKeyOptions
from django_ca.modelfields import LazyCertificate
from django_ca.models import Certificate, CertificateAuthority
-from django_ca.tests.base.constants import (
- CERT_DATA,
- CRYPTOGRAPHY_VERSION,
- FIXTURES_DATA,
- FIXTURES_DIR,
- TIMESTAMPS,
-)
+from django_ca.tests.base.constants import CERT_DATA, FIXTURES_DATA, FIXTURES_DIR, TIMESTAMPS
from django_ca.tests.base.mixins import TestCaseMixin
from django_ca.tests.base.typehints import HttpResponse
from django_ca.tests.base.utils import override_tmpcadir
@@ -190,12 +184,14 @@ def assertOCSPSignature( # pylint: disable=invalid-name
if isinstance(public_key, rsa.RSAPublicKey):
hash_algorithm = typing.cast(hashes.HashAlgorithm, hash_algorithm) # to make mypy happy
- self.assertIsNone(
+ assert (
public_key.verify(response.signature, tbs_response, padding.PKCS1v15(), hash_algorithm)
+ is None
)
+
elif isinstance(public_key, ec.EllipticCurvePublicKey):
hash_algorithm = typing.cast(hashes.HashAlgorithm, hash_algorithm) # to make mypy happy
- self.assertIsNone(public_key.verify(response.signature, tbs_response, ec.ECDSA(hash_algorithm)))
+ assert public_key.verify(response.signature, tbs_response, ec.ECDSA(hash_algorithm)) is None
elif isinstance(public_key, dsa.DSAPublicKey):
hash_algorithm = typing.cast(hashes.HashAlgorithm, hash_algorithm) # to make mypy happy
public_key.verify(response.signature, tbs_response, hash_algorithm)
@@ -213,22 +209,13 @@ def assertCertificateStatus( # pylint: disable=invalid-name
) -> None:
"""Check information related to the certificate status."""
if certificate.revoked is False:
- self.assertEqual(response.certificate_status, ocsp.OCSPCertStatus.GOOD)
- if CRYPTOGRAPHY_VERSION < (43,): # pragma: only cg<=42
- self.assertIsNone(response.revocation_time)
- else:
- self.assertIsNone(response.revocation_time_utc)
- self.assertIsNone(response.revocation_reason)
+ assert response.certificate_status == ocsp.OCSPCertStatus.GOOD
+ assert response.revocation_time_utc is None
+ assert response.revocation_reason is None
else:
- self.assertEqual(response.certificate_status, ocsp.OCSPCertStatus.REVOKED)
- self.assertEqual(response.revocation_reason, certificate.get_revocation_reason())
- if CRYPTOGRAPHY_VERSION < (43,): # pragma: only cg<=42
- self.assertEqual(
- response.revocation_time.replace(tzinfo=timezone.utc), # type: ignore[union-attr]
- certificate.get_revocation_time(),
- )
- else:
- self.assertEqual(response.revocation_time_utc, certificate.get_revocation_time())
+ assert response.certificate_status == ocsp.OCSPCertStatus.REVOKED
+ assert response.revocation_reason == certificate.get_revocation_reason()
+ assert response.revocation_time_utc == certificate.get_revocation_time()
def assertOCSPSingleResponse( # pylint: disable=invalid-name
self,
@@ -241,8 +228,8 @@ def assertOCSPSingleResponse( # pylint: disable=invalid-name
Note that `hash_algorithm` cannot be ``None``, as it must match the algorithm of the OCSP request.
"""
self.assertCertificateStatus(certificate, response)
- self.assertEqual(response.serial_number, certificate.pub.loaded.serial_number)
- self.assertIsInstance(response.hash_algorithm, hash_algorithm)
+ assert response.serial_number == certificate.pub.loaded.serial_number
+ assert isinstance(response.hash_algorithm, hash_algorithm)
def assertOCSPResponse( # pylint: disable=invalid-name
self,
@@ -260,47 +247,43 @@ def assertOCSPResponse( # pylint: disable=invalid-name
if responder_certificate is None:
responder_certificate = self.certs["profile-ocsp"]
- self.assertEqual(http_response["Content-Type"], "application/ocsp-response")
+ assert http_response["Content-Type"] == "application/ocsp-response"
response = ocsp.load_der_ocsp_response(http_response.content)
- self.assertEqual(response.response_status, response_status)
+ assert response.response_status == response_status
if signature_hash_algorithm is None:
- self.assertIsNone(response.signature_hash_algorithm)
+ assert response.signature_hash_algorithm is None
else:
- self.assertIsInstance(response.signature_hash_algorithm, signature_hash_algorithm)
- self.assertEqual(response.signature_algorithm_oid, signature_algorithm_oid)
- self.assertEqual(response.certificates, [responder_certificate.pub.loaded]) # responder certificate!
- self.assertIsNone(response.responder_name)
- self.assertIsInstance(response.responder_key_hash, bytes) # TODO: Validate responder id
+ assert isinstance(response.signature_hash_algorithm, signature_hash_algorithm)
+ assert response.signature_algorithm_oid == signature_algorithm_oid
+ assert response.certificates == [responder_certificate.pub.loaded] # responder certificate!
+ assert response.responder_name is None
+ assert isinstance(response.responder_key_hash, bytes) # TODO: Validate responder id
# TODO: validate issuer_key_hash, issuer_name_hash
# Check TIMESTAMPS
# self.assertEqual(response.produced_at, datetime.now())
- if CRYPTOGRAPHY_VERSION < (43,): # pragma: only cg<=42
- self.assertEqual(response.this_update, datetime.now())
- self.assertEqual(response.next_update, datetime.now() + timedelta(seconds=expires))
- else:
- now = datetime.now(tz=timezone.utc)
- self.assertEqual(response.this_update_utc, now)
- self.assertEqual(response.next_update_utc, now + timedelta(seconds=expires))
+ now = datetime.now(tz=timezone.utc)
+ assert response.this_update_utc == now
+ assert response.next_update_utc == now + timedelta(seconds=expires)
# Check nonce if passed
if nonce is None:
- self.assertEqual(len(response.extensions), 0)
+ assert len(response.extensions) == 0
else:
nonce_extension = response.extensions.get_extension_for_oid(OCSPExtensionOID.NONCE)
- self.assertIs(nonce_extension.critical, False)
- self.assertEqual(nonce_extension.value.nonce, nonce) # type: ignore[attr-defined]
+ assert nonce_extension.critical is False
+ assert nonce_extension.value.nonce == nonce # type: ignore[attr-defined]
- self.assertEqual(response.serial_number, requested_certificate.pub.loaded.serial_number)
+ assert response.serial_number == requested_certificate.pub.loaded.serial_number
# Check the certificate status
self.assertCertificateStatus(requested_certificate, response)
# Assert single response
single_responses = list(response.responses) # otherwise it has no len()/index
- self.assertEqual(len(single_responses), 1)
+ assert len(single_responses) == 1
self.assertOCSPSingleResponse(
requested_certificate, single_responses[0], single_response_hash_algorithm
)
@@ -345,7 +328,7 @@ def ocsp_get(
},
)
response = self.client.get(url)
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert response.status_code == HTTPStatus.OK
return response
@@ -363,7 +346,7 @@ def test_get(self) -> None:
"""Basic GET test."""
data = base64.b64encode(req1).decode("utf-8")
response = self.client.get(reverse("get", kwargs={"data": data}))
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert response.status_code == HTTPStatus.OK
self.assertOCSPResponse(
response,
requested_certificate=self.cert,
@@ -375,9 +358,9 @@ def test_get(self) -> None:
def test_bad_query(self) -> None:
"""Test sending a bad query."""
response = self.client.get(reverse("get", kwargs={"data": "XXX"}))
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert response.status_code == HTTPStatus.OK
ocsp_response = ocsp.load_der_ocsp_response(response.content)
- self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.MALFORMED_REQUEST)
+ assert ocsp_response.response_status == ocsp.OCSPResponseStatus.MALFORMED_REQUEST
def test_raises_exception(self) -> None:
"""Generic test if the handling function throws any uncaught exception."""
@@ -389,26 +372,26 @@ def test_raises_exception(self) -> None:
with mock.patch(view_path, side_effect=ex), self.assertLogs() as logcm:
response = self.client.get(reverse("get", kwargs={"data": data}))
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert response.status_code == HTTPStatus.OK
ocsp_response = ocsp.load_der_ocsp_response(response.content)
- self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.INTERNAL_ERROR)
- self.assertEqual(len(logcm.output), 1)
- self.assertIn(exception_str, logcm.output[0])
+ assert ocsp_response.response_status == ocsp.OCSPResponseStatus.INTERNAL_ERROR
+ assert len(logcm.output) == 1
+ assert exception_str in logcm.output[0]
# also do a post request
with mock.patch(view_path, side_effect=ex), self.assertLogs() as logcm:
response = self.client.post(reverse("post"), req1, content_type="application/ocsp-request")
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert response.status_code == HTTPStatus.OK
ocsp_response = ocsp.load_der_ocsp_response(response.content)
- self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.INTERNAL_ERROR)
- self.assertEqual(len(logcm.output), 1)
- self.assertIn(exception_str, logcm.output[0])
+ assert ocsp_response.response_status == ocsp.OCSPResponseStatus.INTERNAL_ERROR
+ assert len(logcm.output) == 1
+ assert exception_str in logcm.output[0]
@override_tmpcadir()
def test_post(self) -> None:
"""Test the post request."""
response = self.client.post(reverse("post"), req1, content_type="application/ocsp-request")
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert response.status_code == HTTPStatus.OK
self.assertOCSPResponse(
response,
requested_certificate=self.cert,
@@ -423,7 +406,7 @@ def test_post(self) -> None:
content_type="application/ocsp-request",
single_response_hash_algorithm=hashes.SHA1,
)
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert response.status_code == HTTPStatus.OK
self.assertOCSPResponse(
response,
requested_certificate=self.cert,
@@ -433,7 +416,7 @@ def test_post(self) -> None:
)
response = self.client.post(reverse("post-full-pem"), req1, content_type="application/ocsp-request")
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert response.status_code == HTTPStatus.OK
self.assertOCSPResponse(
response,
requested_certificate=self.cert,
@@ -448,7 +431,7 @@ def test_loaded_cryptography_cert(self) -> None:
response = self.client.post(
reverse("post-loaded-cryptography"), req1, content_type="application/ocsp-request"
)
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert response.status_code == HTTPStatus.OK
self.assertOCSPResponse(
response,
requested_certificate=self.cert,
@@ -463,7 +446,7 @@ def test_revoked(self) -> None:
self.cert.revoke()
response = self.client.post(reverse("post"), req1, content_type="application/ocsp-request")
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert response.status_code == HTTPStatus.OK
self.assertOCSPResponse(
response,
requested_certificate=self.cert,
@@ -474,7 +457,7 @@ def test_revoked(self) -> None:
self.cert.revoke(ReasonFlags.affiliation_changed)
response = self.client.post(reverse("post"), req1, content_type="application/ocsp-request")
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert response.status_code == HTTPStatus.OK
self.assertOCSPResponse(
response,
requested_certificate=self.cert,
@@ -494,7 +477,7 @@ def test_ca_ocsp(self) -> None:
data = base64.b64encode(req1).decode("utf-8")
response = self.client.get(reverse("get-ca", kwargs={"data": data}))
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert response.status_code == HTTPStatus.OK
self.assertOCSPResponse(
response,
requested_certificate=ca,
@@ -508,31 +491,22 @@ def test_bad_ca(self) -> None:
data = base64.b64encode(req1).decode("utf-8")
with self.assertLogs() as logcm:
response = self.client.get(reverse("unknown", kwargs={"data": data}))
- self.assertEqual(
- logcm.output,
- [
- "ERROR:django_ca.views:unknown: Certificate Authority could not be found.",
- ],
- )
+ assert logcm.output == ["ERROR:django_ca.views:unknown: Certificate Authority could not be found."]
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert response.status_code == HTTPStatus.OK
ocsp_response = ocsp.load_der_ocsp_response(response.content)
- self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.INTERNAL_ERROR)
+ assert ocsp_response.response_status == ocsp.OCSPResponseStatus.INTERNAL_ERROR
def test_unknown(self) -> None:
"""Test fetching data for an unknown certificate."""
data = base64.b64encode(unknown_req).decode("utf-8")
with self.assertLogs() as logcm:
response = self.client.get(reverse("get", kwargs={"data": data}))
- self.assertEqual(
- logcm.output,
- [
- "WARNING:django_ca.views:7B: OCSP request for unknown cert received.",
- ],
- )
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert logcm.output == ["WARNING:django_ca.views:7B: OCSP request for unknown cert received."]
+
+ assert response.status_code == HTTPStatus.OK
ocsp_response = ocsp.load_der_ocsp_response(response.content)
- self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.INTERNAL_ERROR)
+ assert ocsp_response.response_status == ocsp.OCSPResponseStatus.INTERNAL_ERROR
@override_tmpcadir()
def test_unknown_ca(self) -> None:
@@ -541,12 +515,11 @@ def test_unknown_ca(self) -> None:
with self.assertLogs() as logcm:
response = self.client.get(reverse("get-ca", kwargs={"data": data}))
serial = self.certs["child-cert"].serial
- self.assertEqual(
- logcm.output, [f"WARNING:django_ca.views:{serial}: OCSP request for unknown CA received."]
- )
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert logcm.output, [f"WARNING:django_ca.views:{serial}: OCSP request for unknown CA received."]
+
+ assert response.status_code == HTTPStatus.OK
ocsp_response = ocsp.load_der_ocsp_response(response.content)
- self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.INTERNAL_ERROR)
+ assert ocsp_response.response_status == ocsp.OCSPResponseStatus.INTERNAL_ERROR
@override_tmpcadir()
def test_bad_private_key_type(self) -> None:
@@ -563,14 +536,11 @@ def test_bad_private_key_type(self) -> None:
):
response = self.client.get(reverse("get", kwargs={"data": data}))
ocsp_response = ocsp.load_der_ocsp_response(response.content)
- self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.INTERNAL_ERROR)
- self.assertEqual(
- logcm.output,
- [
- "ERROR:django_ca.views:: Unsupported private key type.",
- "ERROR:django_ca.views:Could not read responder key/cert.",
- ],
- )
+ assert ocsp_response.response_status == ocsp.OCSPResponseStatus.INTERNAL_ERROR
+ assert logcm.output == [
+ "ERROR:django_ca.views:: Unsupported private key type.",
+ "ERROR:django_ca.views:Could not read responder key/cert.",
+ ]
def test_bad_responder_cert(self) -> None:
"""Test the error when the private key cannot be read.
@@ -581,32 +551,32 @@ def test_bad_responder_cert(self) -> None:
with self.assertLogs() as logcm:
response = self.client.get(reverse("get", kwargs={"data": data}))
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert response.status_code == HTTPStatus.OK
ocsp_response = ocsp.load_der_ocsp_response(response.content)
- self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.INTERNAL_ERROR)
- self.assertEqual(logcm.output, ["ERROR:django_ca.views:Could not read responder key/cert."])
+ assert ocsp_response.response_status == ocsp.OCSPResponseStatus.INTERNAL_ERROR
+ assert logcm.output == ["ERROR:django_ca.views:Could not read responder key/cert."]
def test_bad_request(self) -> None:
"""Try making a bad request."""
data = base64.b64encode(b"foobar").decode("utf-8")
with self.assertLogs() as logcm:
response = self.client.get(reverse("get", kwargs={"data": data}))
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert response.status_code == HTTPStatus.OK
ocsp_response = ocsp.load_der_ocsp_response(response.content)
- self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.MALFORMED_REQUEST)
- self.assertEqual(len(logcm.output), 1)
- self.assertIn("ValueError: error parsing asn1 value", logcm.output[0], logcm.output[0])
+ assert ocsp_response.response_status == ocsp.OCSPResponseStatus.MALFORMED_REQUEST
+ assert len(logcm.output) == 1
+ assert "ValueError: error parsing asn1 value" in logcm.output[0], logcm.output[0]
def test_multiple(self) -> None:
"""Try making multiple OCSP requests (not currently supported)."""
data = base64.b64encode(multiple_req).decode("utf-8")
with self.assertLogs() as logcm:
response = self.client.get(reverse("get", kwargs={"data": data}))
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert response.status_code == HTTPStatus.OK
ocsp_response = ocsp.load_der_ocsp_response(response.content)
- self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.MALFORMED_REQUEST)
- self.assertEqual(len(logcm.output), 1)
- self.assertIn("OCSP request contains more than one request", logcm.output[0])
+ assert ocsp_response.response_status == ocsp.OCSPResponseStatus.MALFORMED_REQUEST
+ assert len(logcm.output) == 1
+ assert "OCSP request contains more than one request" in logcm.output[0]
@override_tmpcadir()
def test_bad_ca_cert(self) -> None:
@@ -618,11 +588,11 @@ def test_bad_ca_cert(self) -> None:
data = base64.b64encode(req1).decode("utf-8")
with self.assertLogs() as logcm:
response = self.client.get(reverse("get", kwargs={"data": data}))
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert response.status_code == HTTPStatus.OK
ocsp_response = ocsp.load_der_ocsp_response(response.content)
- self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.INTERNAL_ERROR)
- self.assertEqual(len(logcm.output), 1)
- self.assertIn("ValueError: ", logcm.output[0])
+ assert ocsp_response.response_status == ocsp.OCSPResponseStatus.INTERNAL_ERROR
+ assert len(logcm.output) == 1
+ assert "ValueError: " in logcm.output[0]
@override_tmpcadir()
def test_bad_responder_key(self) -> None:
@@ -631,10 +601,10 @@ def test_bad_responder_key(self) -> None:
with self.assertLogs() as logcm:
response = self.client.get(reverse("false-key", kwargs={"data": data}))
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert response.status_code == HTTPStatus.OK
ocsp_response = ocsp.load_der_ocsp_response(response.content)
- self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.INTERNAL_ERROR)
- self.assertEqual(logcm.output, ["ERROR:django_ca.views:Could not read responder key/cert."])
+ assert ocsp_response.response_status == ocsp.OCSPResponseStatus.INTERNAL_ERROR
+ assert logcm.output == ["ERROR:django_ca.views:Could not read responder key/cert."]
@override_tmpcadir()
def test_bad_responder_pem(self) -> None:
@@ -644,16 +614,16 @@ def test_bad_responder_pem(self) -> None:
with self.assertLogs() as logcm:
response = self.client.get(reverse("false-pem-serial", kwargs={"data": data}))
- self.assertEqual(logcm.output, [msg])
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert logcm.output == [msg]
+ assert response.status_code == HTTPStatus.OK
ocsp_response = ocsp.load_der_ocsp_response(response.content)
- self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.INTERNAL_ERROR)
+ assert ocsp_response.response_status == ocsp.OCSPResponseStatus.INTERNAL_ERROR
with self.assertLogs() as logcm:
response = self.client.get(reverse("false-pem-full", kwargs={"data": data}))
- self.assertEqual(logcm.output, [msg])
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert logcm.output == [msg]
+ assert response.status_code == HTTPStatus.OK
ocsp_response = ocsp.load_der_ocsp_response(response.content)
- self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.INTERNAL_ERROR)
+ assert ocsp_response.response_status == ocsp.OCSPResponseStatus.INTERNAL_ERROR
@override_settings(ROOT_URLCONF=__name__)
@@ -810,10 +780,10 @@ def test_invalid_responder_key(self) -> None:
with self.assertLogs() as logcm:
response = self.ocsp_get(self.cert, hash_algorithm=hashes.SHA512)
- self.assertEqual(logcm.output, ["ERROR:django_ca.views:Could not read responder key/cert."])
- self.assertEqual(response.status_code, HTTPStatus.OK)
+ assert logcm.output == ["ERROR:django_ca.views:Could not read responder key/cert."]
+ assert response.status_code == HTTPStatus.OK
ocsp_response = ocsp.load_der_ocsp_response(response.content)
- self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.INTERNAL_ERROR)
+ assert ocsp_response.response_status == ocsp.OCSPResponseStatus.INTERNAL_ERROR
@override_tmpcadir()
def test_ed25519_certificate_authority(self) -> None:
@@ -837,8 +807,8 @@ def test_cert_method_not_allowed(self) -> None:
"""Try HTTP methods that are not allowed."""
url = reverse("django_ca:ocsp-cert-post", kwargs={"serial": "00AA"})
response = self.client.get(url)
- self.assertEqual(response.status_code, 405)
+ assert response.status_code == HTTPStatus.METHOD_NOT_ALLOWED # 405
url = reverse("django_ca:ocsp-cert-get", kwargs={"serial": "00AA", "data": "irrelevant"})
response = self.client.post(url, req1, content_type="application/ocsp-request")
- self.assertEqual(response.status_code, 405)
+ assert response.status_code == HTTPStatus.METHOD_NOT_ALLOWED # 405
diff --git a/ca/django_ca/tests/utils/test_get_crl_cache_key.py b/ca/django_ca/tests/utils/test_get_crl_cache_key.py
index 22b293c1e..e35e361c0 100644
--- a/ca/django_ca/tests/utils/test_get_crl_cache_key.py
+++ b/ca/django_ca/tests/utils/test_get_crl_cache_key.py
@@ -33,7 +33,7 @@
@pytest.mark.parametrize(
- "kwargs,expected",
+ ("kwargs", "expected"),
(
(DEFAULT_KWARGS, "crl_123_DER_False_False_False_None"),
({**DEFAULT_KWARGS, "encoding": Encoding.PEM}, "crl_123_PEM_False_False_False_None"),
diff --git a/ca/django_ca/tests/utils/test_othername.py b/ca/django_ca/tests/utils/test_othername.py
index b0d07db9a..967a2b89d 100644
--- a/ca/django_ca/tests/utils/test_othername.py
+++ b/ca/django_ca/tests/utils/test_othername.py
@@ -23,7 +23,7 @@
@pytest.mark.parametrize(
- "value,expected,normalized",
+ ("value", "expected", "normalized"),
(
("UNIVERSALSTRING:ex", b"\x1c\x08\x00\x00\x00e\x00\x00\x00x", True),
("UNIV:ex", b"\x1c\x08\x00\x00\x00e\x00\x00\x00x", False),
@@ -59,7 +59,8 @@ def test_parse_and_format_othername(value: str, expected: bytes, normalized: boo
@pytest.mark.parametrize("typ", ("UTF8", "UTF8String"))
@pytest.mark.parametrize(
- "value,expected", (("example", b"\x0c\x07example"), ("example;wrong:val", b"\x0c\x11example;wrong:val"))
+ ("value", "expected"),
+ (("example", b"\x0c\x07example"), ("example;wrong:val", b"\x0c\x11example;wrong:val")),
)
def test_othername_with_utf8(typ: str, value: str, expected: bytes) -> None:
"""Test UTF8 values."""
@@ -88,7 +89,7 @@ def test_othername_with_boolean_false(typ: str, value: str) -> None:
@pytest.mark.parametrize("typ", ("INT", "INTEGER"))
@pytest.mark.parametrize(
- "raw_value,expected_bytes,formatted_value",
+ ("raw_value", "expected_bytes", "formatted_value"),
(
("0", b"\x02\x01\x00", "0"),
("1", b"\x02\x01\x01", "1"),
@@ -104,7 +105,7 @@ def test_othername_integer(typ: str, raw_value: str, expected_bytes: bytes, form
@pytest.mark.parametrize(
- "value,expected",
+ ("value", "expected"),
(
(
"2.4.5.3;BOOL:WRONG",
diff --git a/ca/django_ca/tests/utils/test_parse_general_name.py b/ca/django_ca/tests/utils/test_parse_general_name.py
index 5e9651ccd..fd7c61e96 100644
--- a/ca/django_ca/tests/utils/test_parse_general_name.py
+++ b/ca/django_ca/tests/utils/test_parse_general_name.py
@@ -27,7 +27,7 @@
@pytest.mark.parametrize("prefix", ("", "ip:"))
@pytest.mark.parametrize(
- "value,expected",
+ ("value", "expected"),
(
("1.2.3.4", IPv4Address("1.2.3.4")),
("1.2.3.0/24", IPv4Network("1.2.3.0/24")),
@@ -47,7 +47,7 @@ def test_ip(
@pytest.mark.parametrize("prefix", ("", "DNS:"))
@pytest.mark.parametrize(
- "value,expected",
+ ("value", "expected"),
(
("example.com", dns("example.com")),
(".example.com", dns(".example.com")),
diff --git a/ca/django_ca/tests/utils/test_parse_name_rfc4514.py b/ca/django_ca/tests/utils/test_parse_name_rfc4514.py
index 5445fc7e5..b4ccfccf8 100644
--- a/ca/django_ca/tests/utils/test_parse_name_rfc4514.py
+++ b/ca/django_ca/tests/utils/test_parse_name_rfc4514.py
@@ -24,7 +24,7 @@
@pytest.mark.parametrize(
- "value,expected",
+ ("value", "expected"),
(
("CN=example.com", x509.Name([cn("example.com")])),
(f"{NameOID.COMMON_NAME.dotted_string}=example.com", x509.Name([cn("example.com")])),
@@ -37,7 +37,7 @@ def test_parse_name_rfc4514(value: str, expected: x509.Name) -> None:
@pytest.mark.parametrize(
- "value,expected",
+ ("value", "expected"),
(
(
"C=FOO",
@@ -57,7 +57,7 @@ def test_parse_name_rfc4514_with_error(value: str, expected: str) -> None:
@pytest.mark.skipif(CRYPTOGRAPHY_VERSION < (43,), reason="cryptography check was added in version 43")
@pytest.mark.parametrize(
- "value,expected",
+ ("value", "expected"),
(
("CN=", r"^Attribute's length must be >= 1 and <= 64, but it was 0$"),
(f"CN={'x' * 65}", r"^Attribute's length must be >= 1 and <= 64, but it was 65$"),
diff --git a/ca/django_ca/tests/utils/test_parse_name_x509.py b/ca/django_ca/tests/utils/test_parse_name_x509.py
index 03edf5a96..dd6353dc3 100644
--- a/ca/django_ca/tests/utils/test_parse_name_x509.py
+++ b/ca/django_ca/tests/utils/test_parse_name_x509.py
@@ -22,7 +22,7 @@
@pytest.mark.parametrize(
- "value,expected",
+ ("value", "expected"),
(
("/CN=example.com", [(NameOID.COMMON_NAME, "example.com")]),
# leading or trailing spaces are always ok:
@@ -93,16 +93,6 @@
("/O=/OU=", [(NameOID.ORGANIZATION_NAME, ""), (NameOID.ORGANIZATIONAL_UNIT_NAME, "")]),
# no slash at start works:
("CN=example.com", [(NameOID.COMMON_NAME, "example.com")]),
- # test multiple OUs
- (
- "/C=AT/OU=foo/OU=bar/CN=example.com",
- [
- (NameOID.COUNTRY_NAME, "AT"),
- (NameOID.ORGANIZATIONAL_UNIT_NAME, "foo"),
- (NameOID.ORGANIZATIONAL_UNIT_NAME, "bar"),
- (NameOID.COMMON_NAME, "example.com"),
- ],
- ),
(
"/OU=foo/OU=bar",
[(NameOID.ORGANIZATIONAL_UNIT_NAME, "foo"), (NameOID.ORGANIZATIONAL_UNIT_NAME, "bar")],
diff --git a/ca/django_ca/tests/utils/test_split_str.py b/ca/django_ca/tests/utils/test_split_str.py
index b60e7a89a..ed4dcecc0 100644
--- a/ca/django_ca/tests/utils/test_split_str.py
+++ b/ca/django_ca/tests/utils/test_split_str.py
@@ -19,7 +19,7 @@
@pytest.mark.parametrize(
- "value,seperator,expected",
+ ("value", "seperator", "expected"),
(
("foo", "/", ["foo"]),
("foo bar", "/", ["foo bar"]),
@@ -37,10 +37,7 @@
("/foo/bar", "/", ["foo", "bar"]),
("/foo/bar/", "/", ["foo", "bar"]),
("/C=AT/CN=example.com/", "/", ["C=AT", "CN=example.com"]),
- (r"foo/bar", "/", ["foo", "bar"]),
# test quoting
- (r"foo'/'bar", "/", ["foo/bar"]),
- (r'foo"/"bar', "/", ["foo/bar"]),
(r'fo"o/b"ar', "/", ["foo/bar"]),
(r'"foo\"bar"', "/", ['foo"bar']), # escape quotes inside quotes
# Test the escape character
@@ -60,12 +57,6 @@
(r'"foo\\xbar"', "/", [r"foo\xbar"]),
# ... but in single quote it's not an escape -> double backslash in result
(r"'foo\\xbar'", "/", [r"foo\\xbar"]),
- # No quotes, single backslash preceeding "/" --> "/" is escaped
- (r"foo\/bar", "/", ["foo/bar"]),
- # No quotes, but *double* backslash preceeding "/" --> backslash itself is escaped, slash is delimiter
- (r"foo\\/bar", "/", ["foo\\", "bar"]),
- # With quotes/double quotes, no backslashes -> slash is inside quoted string -> it's not a delimiter
- ('"foo/bar"/bla', "/", ["foo/bar", "bla"]),
("'foo/bar'/bla", "/", ["foo/bar", "bla"]),
# With quotes/double quotes, with one backslash
(r'"foo\/bar"/bla', "/", [r"foo\/bar", "bla"]),
@@ -98,7 +89,7 @@ def test_basic(value: str, seperator: str, expected: list[str]) -> None:
@pytest.mark.parametrize(
- "value,match",
+ ("value", "match"),
(
(r"'foo\'bar'", "^No closing quotation$"),
(r"foo'bar", "^No closing quotation$"),
diff --git a/ca/django_ca/tests/utils/test_validate_hostname.py b/ca/django_ca/tests/utils/test_validate_hostname.py
index 4bd2efdd2..06eeeed19 100644
--- a/ca/django_ca/tests/utils/test_validate_hostname.py
+++ b/ca/django_ca/tests/utils/test_validate_hostname.py
@@ -55,7 +55,7 @@ def test_no_allow_port(value: str) -> None:
@pytest.mark.parametrize(
- "value,error",
+ ("value", "error"),
(
("localhost:no-int", "^no-int: Port must be an integer$"),
("localhost:0", "^0: Port must be between 1 and 65535$"),
diff --git a/ca/django_ca/tests/views/test_certificate_revocation_list_view.py b/ca/django_ca/tests/views/test_certificate_revocation_list_view.py
index e33a92db5..497e3dfd7 100644
--- a/ca/django_ca/tests/views/test_certificate_revocation_list_view.py
+++ b/ca/django_ca/tests/views/test_certificate_revocation_list_view.py
@@ -92,7 +92,7 @@
@pytest.fixture
-def default_url(root: CertificateAuthority) -> Iterator[str]:
+def default_url(root: CertificateAuthority) -> str:
"""Fixture for the default URL for the root CA."""
return reverse("default", kwargs={"serial": root.serial})
diff --git a/ca/django_ca/views.py b/ca/django_ca/views.py
index c8a00b4ea..90cc09fcb 100644
--- a/ca/django_ca/views.py
+++ b/ca/django_ca/views.py
@@ -147,9 +147,7 @@ def get_key_backend_options(self, ca: CertificateAuthority) -> BaseModel:
def fetch_crl(self, ca: CertificateAuthority, encoding: CertificateRevocationListEncodings) -> bytes:
"""Actually fetch the CRL (nested function so that we can easily catch any exception)."""
- print(self.scope)
if self.scope is not _NOT_SET:
- print(2)
warnings.warn(
"The scope parameter is deprecated and will be removed in django-ca 2.3.0, use "
"`only_contains_{ca,user,attribute}_cert` instead.",