diff --git a/auth/lib/charms/data_platform_libs/v0/data_interfaces.py b/auth/lib/charms/data_platform_libs/v0/data_interfaces.py index 83e1cf3..aa79814 100644 --- a/auth/lib/charms/data_platform_libs/v0/data_interfaces.py +++ b/auth/lib/charms/data_platform_libs/v0/data_interfaces.py @@ -453,7 +453,7 @@ def _on_subject_requested(self, event: SubjectRequestedEvent): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 55 +LIBPATCH = 56 PYDEPS = ["ops>=2.0.0"] @@ -2071,6 +2071,7 @@ def __init__( requested_entity_secret: Optional[str] = None, requested_entity_name: Optional[str] = None, requested_entity_password: Optional[str] = None, + prefix_matching: Optional[str] = None, ): """Manager of base client relations.""" super().__init__(model, relation_name) @@ -2081,6 +2082,7 @@ def __init__( self.requested_entity_secret = requested_entity_secret self.requested_entity_name = requested_entity_name self.requested_entity_password = requested_entity_password + self.prefix_matching = prefix_matching if ( self.requested_entity_secret or self.requested_entity_name @@ -3258,6 +3260,14 @@ def requested_entity_secret_content(self) -> Optional[Dict[str, Optional[str]]]: logger.warning("Invalid requested-entity-secret: no entity name") return names + @property + def prefix_matching(self) -> Optional[str]: + """Returns the prefix matching strategy that were requested.""" + if not self.relation.app: + return None + + return self.relation.data[self.relation.app].get("prefix-matching") + class DatabaseEntityRequestedEvent(DatabaseProvidesEvent, EntityProvidesEvent): """Event emitted when a new entity is requested for use on this relation.""" @@ -3364,6 +3374,16 @@ def version(self) -> Optional[str]: return self.relation.data[self.relation.app].get("version") + @property + def prefix_databases(self) -> Optional[List[str]]: + """Returns a list of databases matching a prefix.""" + if not self.relation.app: + return None + + if prefixed_databases := self.relation.data[self.relation.app].get("prefix-databases"): + return prefixed_databases.split(",") + return [] + class DatabaseCreatedEvent(AuthenticationEvent, DatabaseRequiresEvent): """Event emitted when a new database is created for use on this relation.""" @@ -3381,6 +3401,10 @@ class DatabaseReadOnlyEndpointsChangedEvent(AuthenticationEvent, DatabaseRequire """Event emitted when the read only endpoints are changed.""" +class DatabasePrefixDatabasesChangedEvent(AuthenticationEvent, DatabaseRequiresEvent): + """Event emitted when the prefix databases are changed.""" + + class DatabaseRequiresEvents(RequirerCharmEvents): """Database events. @@ -3391,6 +3415,7 @@ class DatabaseRequiresEvents(RequirerCharmEvents): database_entity_created = EventSource(DatabaseEntityCreatedEvent) endpoints_changed = EventSource(DatabaseEndpointsChangedEvent) read_only_endpoints_changed = EventSource(DatabaseReadOnlyEndpointsChangedEvent) + prefix_databases_changed = EventSource(DatabasePrefixDatabasesChangedEvent) # Database Provider and Requires @@ -3416,6 +3441,18 @@ def set_database(self, relation_id: int, database_name: str) -> None: """ self.update_relation_data(relation_id, {"database": database_name}) + def set_prefix_databases(self, relation_id: int, databases: List[str]) -> None: + """Set a coma separated list of databases matching a prefix. + + This function writes in the application data bag, therefore, + only the leader unit can call it. + + Args: + relation_id: the identifier for a particular relation. + databases: list of database names matching the requested prefix. + """ + self.update_relation_data(relation_id, {"prefix-databases": ",".join(sorted(databases))}) + def set_endpoints(self, relation_id: int, connection_strings: str) -> None: """Set database primary connections. @@ -3588,6 +3625,7 @@ def __init__( requested_entity_secret: Optional[str] = None, requested_entity_name: Optional[str] = None, requested_entity_password: Optional[str] = None, + prefix_matching: Optional[str] = None, ): """Manager of database client relations.""" super().__init__( @@ -3601,6 +3639,7 @@ def __init__( requested_entity_secret, requested_entity_name, requested_entity_password, + prefix_matching, ) self.database = database_name self.relations_aliases = relations_aliases @@ -3700,6 +3739,10 @@ def __init__( f"{relation_alias}_read_only_endpoints_changed", DatabaseReadOnlyEndpointsChangedEvent, ) + self.on.define_event( + f"{relation_alias}_prefix_databases_changed", + DatabasePrefixDatabasesChangedEvent, + ) def _on_secret_changed_event(self, event: SecretChangedEvent): """Event notifying about a new value of a secret.""" @@ -3792,6 +3835,8 @@ def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: event_data["entity-permissions"] = self.relation_data.entity_permissions if self.relation_data.requested_entity_secret: event_data["requested-entity-secret"] = self.relation_data.requested_entity_secret + if self.relation_data.prefix_matching: + event_data["prefix-matching"] = self.relation_data.prefix_matching # Create helper secret if needed if ( @@ -3884,32 +3929,22 @@ def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: # To avoid unnecessary application restarts do not trigger other events. return - # Emit an endpoints changed event if the database - # added or changed this info in the relation databag. - if "endpoints" in diff.added or "endpoints" in diff.changed: - # Emit the default event (the one without an alias). - logger.info("endpoints changed on %s", datetime.now()) - getattr(self.on, "endpoints_changed").emit( - event.relation, app=event.app, unit=event.unit - ) - - # Emit the aliased event (if any). - self._emit_aliased_event(event, "endpoints_changed") - - # To avoid unnecessary application restarts do not trigger other events. - return - - # Emit a read only endpoints changed event if the database - # added or changed this info in the relation databag. - if "read-only-endpoints" in diff.added or "read-only-endpoints" in diff.changed: - # Emit the default event (the one without an alias). - logger.info("read-only-endpoints changed on %s", datetime.now()) - getattr(self.on, "read_only_endpoints_changed").emit( - event.relation, app=event.app, unit=event.unit - ) - - # Emit the aliased event (if any). - self._emit_aliased_event(event, "read_only_endpoints_changed") + for key, event_name in [ + ("endpoints", "endpoints_changed"), + ("read-only-endpoints", "read_only_endpoints_changed"), + ("prefix-databases", "prefix_databases_changed"), + ]: + # Emit a change event if the key changed. + if key in diff.added or key in diff.changed: + # Emit the default event (the one without an alias). + logger.info("%s changed on %s", key, datetime.now()) + getattr(self.on, event_name).emit(event.relation, app=event.app, unit=event.unit) + + # Emit the aliased event (if any). + self._emit_aliased_event(event, event_name) + + # To avoid unnecessary application restarts do not trigger other events. + return class DatabaseRequires(DatabaseRequirerData, DatabaseRequirerEventHandlers): @@ -3930,6 +3965,7 @@ def __init__( requested_entity_secret: Optional[str] = None, requested_entity_name: Optional[str] = None, requested_entity_password: Optional[str] = None, + prefix_matching: Optional[str] = None, ): DatabaseRequirerData.__init__( self, @@ -3946,6 +3982,7 @@ def __init__( requested_entity_secret, requested_entity_name, requested_entity_password, + prefix_matching, ) DatabaseRequirerEventHandlers.__init__(self, charm, self) diff --git a/auth/lib/charms/tls_certificates_interface/v4/tls_certificates.py b/auth/lib/charms/tls_certificates_interface/v4/tls_certificates.py index 4f80e0c..32b3b15 100644 --- a/auth/lib/charms/tls_certificates_interface/v4/tls_certificates.py +++ b/auth/lib/charms/tls_certificates_interface/v4/tls_certificates.py @@ -40,17 +40,29 @@ import json import logging import uuid +import warnings from contextlib import suppress from dataclasses import asdict, dataclass, field from datetime import datetime, timedelta, timezone from enum import Enum -from typing import Dict, FrozenSet, List, MutableMapping, Optional, Tuple, Union +from typing import ( + Collection, + Dict, + FrozenSet, + List, + MutableMapping, + Optional, + Set, + Tuple, + Union, +) import pydantic from cryptography import x509 from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.asymmetric.types import CertificateIssuerPrivateKeyTypes from cryptography.x509.oid import ExtensionOID, NameOID from ops import BoundEvent, CharmBase, CharmEvents, Secret, SecretExpiredEvent, SecretRemoveEvent from ops.framework import EventBase, EventSource, Handle, Object @@ -65,7 +77,7 @@ # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 24 +LIBPATCH = 27 PYDEPS = [ "cryptography>=43.0.0", @@ -319,34 +331,60 @@ class Mode(Enum): APP = 2 -@dataclass(frozen=True) class PrivateKey: """This class represents a private key.""" - raw: str + def __init__( + self, raw: Optional[str] = None, x509_object: Optional[rsa.RSAPrivateKey] = None + ) -> None: + """Initialize the PrivateKey object. + + If both raw and x509_object are provided, x509_object takes precedence. + """ + if x509_object: + self._private_key = x509_object + elif raw: + self._private_key = serialization.load_pem_private_key( + raw.encode(), + password=None, + ) + else: + raise ValueError("Either raw private key string or x509_object must be provided") + + @property + def raw(self) -> str: + """Return the PEM-formatted string representation of the private key.""" + return str(self) def __str__(self): - """Return the private key as a string.""" - return self.raw + """Return the private key as a string in PEM format.""" + return ( + self._private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + .decode() + .strip() + ) + + def __hash__(self): + """Return the hash of the private key.""" + return hash(self.raw) @classmethod def from_string(cls, private_key: str) -> "PrivateKey": """Create a PrivateKey object from a private key.""" - return cls(raw=private_key.strip()) + return cls(raw=private_key) def is_valid(self) -> bool: """Validate that the private key is PEM-formatted, RSA, and at least 2048 bits.""" try: - key = serialization.load_pem_private_key( - self.raw.encode(), - password=None, - ) - - if not isinstance(key, rsa.RSAPrivateKey): + if not isinstance(self._private_key, rsa.RSAPrivateKey): logger.warning("Private key is not an RSA key") return False - if key.key_size < 2048: + if self._private_key.key_size < 2048: logger.warning("RSA key size is less than 2048 bits") return False @@ -355,29 +393,180 @@ def is_valid(self) -> bool: logger.warning("Invalid private key format") return False + @classmethod + def generate(cls, key_size: int = 2048, public_exponent: int = 65537) -> "PrivateKey": + """Generate a new RSA private key. + + Args: + key_size: The size of the key in bits. + public_exponent: The public exponent of the key. + + Returns: + PrivateKey: The generated private key. + """ + private_key = rsa.generate_private_key( + public_exponent=public_exponent, + key_size=key_size, + ) + _OWASPLogger().log_event( + event="private_key_generated", + level=logging.INFO, + description="Private key generated", + key_size=str(key_size), + ) + return PrivateKey(x509_object=private_key) + + def __eq__(self, other: object) -> bool: + """Check if two PrivateKey objects are equal.""" + if not isinstance(other, PrivateKey): + return NotImplemented + return self.raw == other.raw + -@dataclass(frozen=True) class Certificate: """This class represents a certificate.""" - raw: str - common_name: str - expiry_time: datetime - validity_start_time: datetime - is_ca: bool = False - sans_dns: Optional[FrozenSet[str]] = frozenset() - sans_ip: Optional[FrozenSet[str]] = frozenset() - sans_oid: Optional[FrozenSet[str]] = frozenset() - email_address: Optional[str] = None - organization: Optional[str] = None - organizational_unit: Optional[str] = None - country_name: Optional[str] = None - state_or_province_name: Optional[str] = None - locality_name: Optional[str] = None + _cert: x509.Certificate + + def __init__( + self, + raw: Optional[str] = None, # Must remain first argument for backwards compatibility + # Old Interface fields (ignored) + common_name: Optional[str] = None, + expiry_time: Optional[datetime] = None, + validity_start_time: Optional[datetime] = None, + is_ca: Optional[bool] = None, + sans_dns: Optional[Set[str]] = None, + sans_ip: Optional[Set[str]] = None, + sans_oid: Optional[Set[str]] = None, + email_address: Optional[str] = None, + organization: Optional[str] = None, + organizational_unit: Optional[str] = None, + country_name: Optional[str] = None, + state_or_province_name: Optional[str] = None, + locality_name: Optional[str] = None, + # End Old Interface fields + x509_object: Optional[x509.Certificate] = None, + ) -> None: + """Initialize the Certificate object. + + This initializer must maintain the old interface while also allowing + instantiation from an existing x509_object. It ignores all fields + other than raw and x509_object, preferring x509_object. + """ + if x509_object: + self._cert = x509_object + elif raw: + self._cert = x509.load_pem_x509_certificate(data=raw.encode()) + else: + raise ValueError("Either raw certificate string or x509_object must be provided") + + @property + def raw(self) -> str: + """Return the PEM-formatted string representation of the certificate.""" + return str(self) + + @property + def common_name(self) -> str: + """Return the common name of the certificate.""" + # We maintain compatibility with the old interface by returning + # an empty string if no common name is set. + common_name = self._cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME) + return str(common_name[0].value) if common_name else "" + + @property + def expiry_time(self) -> datetime: + """Return the expiry time of the certificate.""" + return self._cert.not_valid_after_utc + + @property + def validity_start_time(self) -> datetime: + """Return the validity start time of the certificate.""" + return self._cert.not_valid_before_utc + + @property + def is_ca(self) -> bool: + """Return whether the certificate is a CA certificate.""" + try: + return self._cert.extensions.get_extension_for_oid( + ExtensionOID.BASIC_CONSTRAINTS + ).value.ca # type: ignore[reportAttributeAccessIssue] + except x509.ExtensionNotFound: + return False + + @property + def sans_dns(self) -> Optional[Set[str]]: + """Return the DNS Subject Alternative Names of the certificate.""" + with suppress(x509.ExtensionNotFound): + sans = self._cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value + return {str(san) for san in sans.get_values_for_type(x509.DNSName)} + return None + + @property + def sans_ip(self) -> Optional[Set[str]]: + """Return the IP Subject Alternative Names of the certificate.""" + with suppress(x509.ExtensionNotFound): + sans = self._cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value + return {str(san) for san in sans.get_values_for_type(x509.IPAddress)} + return None + + @property + def sans_oid(self) -> Optional[Set[str]]: + """Return the OID Subject Alternative Names of the certificate.""" + with suppress(x509.ExtensionNotFound): + sans = self._cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value + return {str(san.dotted_string) for san in sans.get_values_for_type(x509.RegisteredID)} + return None + + @property + def email_address(self) -> Optional[str]: + """Return the email address of the certificate.""" + email_address = self._cert.subject.get_attributes_for_oid(NameOID.EMAIL_ADDRESS) + return str(email_address[0].value) if email_address else None + + @property + def organization(self) -> Optional[str]: + """Return the organization name of the certificate.""" + organization = self._cert.subject.get_attributes_for_oid(NameOID.ORGANIZATION_NAME) + return str(organization[0].value) if organization else None + + @property + def organizational_unit(self) -> Optional[str]: + """Return the organizational unit name of the certificate.""" + organizational_unit = self._cert.subject.get_attributes_for_oid( + NameOID.ORGANIZATIONAL_UNIT_NAME + ) + return str(organizational_unit[0].value) if organizational_unit else None + + @property + def country_name(self) -> Optional[str]: + """Return the country name of the certificate.""" + country_name = self._cert.subject.get_attributes_for_oid(NameOID.COUNTRY_NAME) + return str(country_name[0].value) if country_name else None + + @property + def state_or_province_name(self) -> Optional[str]: + """Return the state or province name of the certificate.""" + state_or_province_name = self._cert.subject.get_attributes_for_oid( + NameOID.STATE_OR_PROVINCE_NAME + ) + return str(state_or_province_name[0].value) if state_or_province_name else None + + @property + def locality_name(self) -> Optional[str]: + """Return the locality name of the certificate.""" + locality_name = self._cert.subject.get_attributes_for_oid(NameOID.LOCALITY_NAME) + return str(locality_name[0].value) if locality_name else None def __str__(self) -> str: """Return the certificate as a string.""" - return self.raw + return self._cert.public_bytes(serialization.Encoding.PEM).decode().strip() + + def __eq__(self, other: object) -> bool: + """Check if two Certificate objects are equal.""" + if not isinstance(other, Certificate): + return NotImplemented + return self.raw == other.raw @classmethod def from_string(cls, certificate: str) -> "Certificate": @@ -388,66 +577,7 @@ def from_string(cls, certificate: str) -> "Certificate": logger.error("Could not load certificate: %s", e) raise TLSCertificatesError("Could not load certificate") - common_name = certificate_object.subject.get_attributes_for_oid(NameOID.COMMON_NAME) - country_name = certificate_object.subject.get_attributes_for_oid(NameOID.COUNTRY_NAME) - state_or_province_name = certificate_object.subject.get_attributes_for_oid( - NameOID.STATE_OR_PROVINCE_NAME - ) - locality_name = certificate_object.subject.get_attributes_for_oid(NameOID.LOCALITY_NAME) - organization_name = certificate_object.subject.get_attributes_for_oid( - NameOID.ORGANIZATION_NAME - ) - organizational_unit = certificate_object.subject.get_attributes_for_oid( - NameOID.ORGANIZATIONAL_UNIT_NAME - ) - email_address = certificate_object.subject.get_attributes_for_oid(NameOID.EMAIL_ADDRESS) - sans_dns: List[str] = [] - sans_ip: List[str] = [] - sans_oid: List[str] = [] - try: - sans = certificate_object.extensions.get_extension_for_class( - x509.SubjectAlternativeName - ).value - for san in sans: - if isinstance(san, x509.DNSName): - sans_dns.append(san.value) - if isinstance(san, x509.IPAddress): - sans_ip.append(str(san.value)) - if isinstance(san, x509.RegisteredID): - sans_oid.append(str(san.value)) - except x509.ExtensionNotFound: - logger.debug("No SANs found in certificate") - sans_dns = [] - sans_ip = [] - sans_oid = [] - expiry_time = certificate_object.not_valid_after_utc - validity_start_time = certificate_object.not_valid_before_utc - is_ca = False - try: - is_ca = certificate_object.extensions.get_extension_for_oid( - ExtensionOID.BASIC_CONSTRAINTS - ).value.ca # type: ignore[reportAttributeAccessIssue] - except x509.ExtensionNotFound: - pass - - return cls( - raw=certificate.strip(), - common_name=str(common_name[0].value), - is_ca=is_ca, - country_name=str(country_name[0].value) if country_name else None, - state_or_province_name=str(state_or_province_name[0].value) - if state_or_province_name - else None, - locality_name=str(locality_name[0].value) if locality_name else None, - organization=str(organization_name[0].value) if organization_name else None, - organizational_unit=str(organizational_unit[0].value) if organizational_unit else None, - email_address=str(email_address[0].value) if email_address else None, - sans_dns=frozenset(sans_dns), - sans_ip=frozenset(sans_ip), - sans_oid=frozenset(sans_oid), - expiry_time=expiry_time, - validity_start_time=validity_start_time, - ) + return cls(x509_object=certificate_object) def matches_private_key(self, private_key: PrivateKey) -> bool: """Check if this certificate matches a given private key. @@ -459,13 +589,8 @@ def matches_private_key(self, private_key: PrivateKey) -> bool: bool: True if the certificate matches the private key, False otherwise. """ try: - cert_object = x509.load_pem_x509_certificate(self.raw.encode()) - key_object = serialization.load_pem_private_key( - private_key.raw.encode(), password=None - ) - - cert_public_key = cert_object.public_key() - key_public_key = key_object.public_key() + cert_public_key = self._cert.public_key() + key_public_key = private_key._private_key.public_key() if not isinstance(cert_public_key, rsa.RSAPublicKey): logger.warning("Certificate does not use RSA public key") @@ -480,84 +605,316 @@ def matches_private_key(self, private_key: PrivateKey) -> bool: logger.warning("Failed to validate certificate and private key match: %s", e) return False + @classmethod + def generate( + cls, + csr: "CertificateSigningRequest", + ca: "Certificate", + ca_private_key: "PrivateKey", + validity: timedelta, + is_ca: bool = False, + ) -> "Certificate": + """Generate a certificate from a CSR signed by the given CA and CA private key. -@dataclass(frozen=True) -class CertificateSigningRequest: - """This class represents a certificate signing request.""" - - raw: str - common_name: str - sans_dns: Optional[FrozenSet[str]] = None - sans_ip: Optional[FrozenSet[str]] = None - sans_oid: Optional[FrozenSet[str]] = None - email_address: Optional[str] = None - organization: Optional[str] = None - organizational_unit: Optional[str] = None - country_name: Optional[str] = None - state_or_province_name: Optional[str] = None - locality_name: Optional[str] = None - has_unique_identifier: bool = False + Args: + csr: The certificate signing request. + ca: The CA certificate. + ca_private_key: The CA private key. + validity: The validity period of the certificate. + is_ca: Whether the generated certificate is a CA certificate. - def __eq__(self, other: object) -> bool: - """Check if two CertificateSigningRequest objects are equal.""" - if not isinstance(other, CertificateSigningRequest): - return NotImplemented - return self.raw.strip() == other.raw.strip() + Returns: + Certificate: The generated certificate. + """ + # Ideally, this would be the constructor, but we can't add new + # required parameters to the constructor without breaking backwards + # compatibility. + private_key = serialization.load_pem_private_key( + str(ca_private_key).encode(), password=None + ) + assert isinstance(private_key, CertificateIssuerPrivateKeyTypes) + + # Create a certificate builder + cert_builder = x509.CertificateBuilder( + subject_name=csr._csr.subject, + # issuer_name=ca._cert.subject, # TODO: Validate this is correct, the old code used `issuer` + issuer_name=ca._cert.issuer, + public_key=csr._csr.public_key(), + serial_number=x509.random_serial_number(), + not_valid_before=datetime.now(timezone.utc), + not_valid_after=datetime.now(timezone.utc) + validity, + ) + extensions = _generate_certificate_request_extensions( + authority_key_identifier=ca._cert.extensions.get_extension_for_class( + x509.SubjectKeyIdentifier + ).value.key_identifier, + csr=csr._csr, + is_ca=is_ca, + ) + for extension in extensions: + try: + cert_builder = cert_builder.add_extension(extension.value, extension.critical) + except ValueError as e: + logger.error("Could not add extension to certificate: %s", e) + raise TLSCertificatesError("Could not add extension to certificate") from e + + # Sign the certificate with the CA's private key + cert = cert_builder.sign(private_key=private_key, algorithm=hashes.SHA256()) + _OWASPLogger().log_event( + event="certificate_generated", + level=logging.INFO, + description="Certificate generated from CSR", + common_name=csr.common_name, + is_ca=str(is_ca), + validity_days=str(validity.days), + ) - def __str__(self) -> str: - """Return the CSR as a string.""" - return self.raw + return cls(x509_object=cert) @classmethod - def from_string(cls, csr: str) -> "CertificateSigningRequest": - """Create a CertificateSigningRequest object from a CSR.""" - try: - csr_object = x509.load_pem_x509_csr(csr.encode()) - except ValueError as e: - logger.error("Could not load CSR: %s", e) - raise TLSCertificatesError("Could not load CSR") - common_name = csr_object.subject.get_attributes_for_oid(NameOID.COMMON_NAME) - country_name = csr_object.subject.get_attributes_for_oid(NameOID.COUNTRY_NAME) - state_or_province_name = csr_object.subject.get_attributes_for_oid( - NameOID.STATE_OR_PROVINCE_NAME + def generate_self_signed_ca( + cls, + attributes: "CertificateRequestAttributes", + private_key: PrivateKey, + validity: timedelta, + ) -> "Certificate": + """Generate a self-signed CA certificate. + + Args: + attributes: The certificate request attributes. + private_key: The private key to sign the CA certificate. + validity: The validity period of the CA certificate. + + Returns: + Certificate: The generated CA certificate. + """ + assert isinstance(private_key._private_key, rsa.RSAPrivateKey) + + public_key = private_key._private_key.public_key() + + builder = x509.CertificateBuilder( + public_key=public_key, + serial_number=x509.random_serial_number(), + not_valid_before=datetime.now(timezone.utc), + not_valid_after=datetime.now(timezone.utc) + validity, + ) + + if subject_name := _extract_subject_name_attributes(attributes): + builder = builder.subject_name(subject_name).issuer_name(subject_name) + + builder = ( + builder.add_extension( + x509.SubjectKeyIdentifier.from_public_key(public_key), critical=False + ) + .add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True) + .add_extension( + x509.KeyUsage( + digital_signature=True, + key_encipherment=True, + key_cert_sign=True, + key_agreement=False, + content_commitment=False, + data_encipherment=False, + crl_sign=False, + encipher_only=False, + decipher_only=False, + ), + critical=True, + ) + ) + + if san_extension := _san_extension( + email_address=attributes.email_address, + sans_dns=attributes.sans_dns, + sans_ip=attributes.sans_ip, + sans_oid=attributes.sans_oid, + ): + builder = builder.add_extension(san_extension, critical=False) + + cert = cls(x509_object=builder.sign(private_key._private_key, algorithm=hashes.SHA256())) + + _OWASPLogger().log_event( + event="ca_certificate_generated", + level=logging.INFO, + description="CA certificate generated", + common_name=cert.common_name, + validity_days=str(validity.days), ) - locality_name = csr_object.subject.get_attributes_for_oid(NameOID.LOCALITY_NAME) - organization_name = csr_object.subject.get_attributes_for_oid(NameOID.ORGANIZATION_NAME) - organizational_unit = csr_object.subject.get_attributes_for_oid( + + return cert + + def __hash__(self): + """Return the hash of the private key.""" + return hash(self.raw) + + +class CertificateSigningRequest: + """A representation of the certificate signing request.""" + + _csr: x509.CertificateSigningRequest + + def __init__( + self, + raw: Optional[str] = None, # Must remain first argument for backwards compatibility + # Old Interface fields (ignored) + common_name: Optional[str] = None, + sans_dns: Optional[Set[str]] = None, + sans_ip: Optional[Set[str]] = None, + sans_oid: Optional[Set[str]] = None, + email_address: Optional[str] = None, + organization: Optional[str] = None, + organizational_unit: Optional[str] = None, + country_name: Optional[str] = None, + state_or_province_name: Optional[str] = None, + locality_name: Optional[str] = None, + has_unique_identifier: Optional[bool] = None, + # End Old Interface fields + x509_object: Optional[x509.CertificateSigningRequest] = None, + ): + """Initialize the CertificateSigningRequest object. + + This initializer must maintain the old interface while also allowing + instantiation from an existing x509_object. It ignores all fields + other than raw and x509_object, preferring x509_object. + """ + if x509_object: + self._csr = x509_object + return + elif raw: + try: + self._csr = x509.load_pem_x509_csr(raw.encode()) + except ValueError as e: + logger.error("Could not load CSR: %s", e) + raise TLSCertificatesError("Could not load CSR") + return + raise ValueError("Either raw CSR string or x509_object must be provided") + + @property + def common_name(self) -> str: + """Return the common name of the CSR.""" + common_name = self._csr.subject.get_attributes_for_oid(NameOID.COMMON_NAME) + return str(common_name[0].value) if common_name else "" + + @property + def sans_dns(self) -> Set[str]: + """Return the DNS Subject Alternative Names of the CSR.""" + with suppress(x509.ExtensionNotFound): + sans = self._csr.extensions.get_extension_for_class(x509.SubjectAlternativeName).value + return {str(san) for san in sans.get_values_for_type(x509.DNSName)} + return set() + + @property + def sans_ip(self) -> Set[str]: + """Return the IP Subject Alternative Names of the CSR.""" + with suppress(x509.ExtensionNotFound): + sans = self._csr.extensions.get_extension_for_class(x509.SubjectAlternativeName).value + return {str(san) for san in sans.get_values_for_type(x509.IPAddress)} + return set() + + @property + def sans_oid(self) -> Set[str]: + """Return the OID Subject Alternative Names of the CSR.""" + with suppress(x509.ExtensionNotFound): + sans = self._csr.extensions.get_extension_for_class(x509.SubjectAlternativeName).value + return {str(san.dotted_string) for san in sans.get_values_for_type(x509.RegisteredID)} + return set() + + @property + def email_address(self) -> Optional[str]: + """Return the email address of the CSR.""" + email_address = self._csr.subject.get_attributes_for_oid(NameOID.EMAIL_ADDRESS) + return str(email_address[0].value) if email_address else None + + @property + def organization(self) -> Optional[str]: + """Return the organization name of the CSR.""" + organization = self._csr.subject.get_attributes_for_oid(NameOID.ORGANIZATION_NAME) + return str(organization[0].value) if organization else None + + @property + def organizational_unit(self) -> Optional[str]: + """Return the organizational unit name of the CSR.""" + organizational_unit = self._csr.subject.get_attributes_for_oid( NameOID.ORGANIZATIONAL_UNIT_NAME ) - email_address = csr_object.subject.get_attributes_for_oid(NameOID.EMAIL_ADDRESS) - unique_identifier = csr_object.subject.get_attributes_for_oid( - NameOID.X500_UNIQUE_IDENTIFIER + return str(organizational_unit[0].value) if organizational_unit else None + + @property + def country_name(self) -> Optional[str]: + """Return the country name of the CSR.""" + country_name = self._csr.subject.get_attributes_for_oid(NameOID.COUNTRY_NAME) + return str(country_name[0].value) if country_name else None + + @property + def state_or_province_name(self) -> Optional[str]: + """Return the state or province name of the CSR.""" + state_or_province_name = self._csr.subject.get_attributes_for_oid( + NameOID.STATE_OR_PROVINCE_NAME ) - try: - sans = csr_object.extensions.get_extension_for_class(x509.SubjectAlternativeName).value - sans_dns = frozenset(sans.get_values_for_type(x509.DNSName)) - sans_ip = frozenset([str(san) for san in sans.get_values_for_type(x509.IPAddress)]) - sans_oid = frozenset( - [san.dotted_string for san in sans.get_values_for_type(x509.RegisteredID)] - ) - except x509.ExtensionNotFound: - sans = frozenset() - sans_dns = frozenset() - sans_ip = frozenset() - sans_oid = frozenset() - return cls( - raw=csr.strip(), - common_name=str(common_name[0].value), - country_name=str(country_name[0].value) if country_name else None, - state_or_province_name=str(state_or_province_name[0].value) - if state_or_province_name - else None, - locality_name=str(locality_name[0].value) if locality_name else None, - organization=str(organization_name[0].value) if organization_name else None, - organizational_unit=str(organizational_unit[0].value) if organizational_unit else None, - email_address=str(email_address[0].value) if email_address else None, - sans_dns=sans_dns, - sans_ip=sans_ip, - sans_oid=sans_oid, - has_unique_identifier=bool(unique_identifier), + return str(state_or_province_name[0].value) if state_or_province_name else None + + @property + def locality_name(self) -> Optional[str]: + """Return the locality name of the CSR.""" + locality_name = self._csr.subject.get_attributes_for_oid(NameOID.LOCALITY_NAME) + return str(locality_name[0].value) if locality_name else None + + @property + def has_unique_identifier(self) -> bool: + """Return whether the CSR has a unique identifier.""" + unique_identifier = self._csr.subject.get_attributes_for_oid( + NameOID.X500_UNIQUE_IDENTIFIER ) + return bool(unique_identifier) + + @property + def raw(self) -> str: + """Return the PEM-formatted string representation of the CSR.""" + return self.__str__() + + def __str__(self) -> str: + """Return the CSR as a string.""" + return self._csr.public_bytes(serialization.Encoding.PEM).decode().strip() + + @property + def additional_critical_extensions(self) -> List[x509.ExtensionType]: + """Return additional critical extensions present on the CSR (excluding SAN).""" + extensions: List[x509.ExtensionType] = [] + for extension in self._csr.extensions: + if extension.critical and extension.oid != ExtensionOID.SUBJECT_ALTERNATIVE_NAME: + extensions.append(extension.value) + return extensions + + @classmethod + def from_string(cls, csr: str) -> "CertificateSigningRequest": + """Create a CertificateSigningRequest object from a CSR.""" + return cls(raw=csr) + + @classmethod + def from_csr(cls, csr: x509.CertificateSigningRequest) -> "CertificateSigningRequest": + """Create a CertificateSigningRequest object from a CSR.""" + return cls(x509_object=csr) + + def __eq__(self, other: object) -> bool: + """Check if two CertificateSigningRequest objects are equal.""" + if not isinstance(other, CertificateSigningRequest): + return NotImplemented + return self.raw == other.raw + + def __hash__(self): + """Return the hash of the private key.""" + return hash(self.raw) + + def matches_certificate(self, certificate: Certificate) -> bool: + """Check if this CSR matches a given certificate. + + Args: + certificate (Certificate): The certificate to validate against. + + Returns: + bool: True if the CSR matches the certificate, False otherwise. + """ + return self._csr.public_key() == certificate._cert.public_key() def matches_private_key(self, key: PrivateKey) -> bool: """Check if a CSR matches a private key. @@ -570,12 +927,8 @@ def matches_private_key(self, key: PrivateKey) -> bool: bool: True/False depending on whether the CSR matches the private key. """ try: - csr_object = x509.load_pem_x509_csr(self.raw.encode("utf-8")) - key_object = serialization.load_pem_private_key( - data=key.raw.encode("utf-8"), password=None - ) - key_object_public_key = key_object.public_key() - csr_object_public_key = csr_object.public_key() + key_object_public_key = key._private_key.public_key() + csr_object_public_key = self._csr.public_key() if not isinstance(key_object_public_key, rsa.RSAPublicKey): logger.warning("Key is not an RSA key") return False @@ -593,82 +946,191 @@ def matches_private_key(self, key: PrivateKey) -> bool: return False return True - def matches_certificate(self, certificate: Certificate) -> bool: - """Check if a CSR matches a certificate. - - Args: - certificate (Certificate): Certificate - Returns: - bool: True/False depending on whether the CSR matches the certificate. - """ - csr_object = x509.load_pem_x509_csr(self.raw.encode("utf-8")) - cert_object = x509.load_pem_x509_certificate(certificate.raw.encode("utf-8")) - return csr_object.public_key() == cert_object.public_key() - def get_sha256_hex(self) -> str: """Calculate the hash of the provided data and return the hexadecimal representation.""" digest = hashes.Hash(hashes.SHA256()) digest.update(self.raw.encode()) return digest.finalize().hex() + def sign( + self, ca: Certificate, ca_private_key: PrivateKey, validity: timedelta, is_ca: bool = False + ) -> Certificate: + """Sign this CSR with the given CA and CA private key. -@dataclass(frozen=True) -class CertificateRequestAttributes: - """A representation of the certificate request attributes. - - This class should be used inside the requirer charm to specify the requested - attributes for the certificate. - """ - - common_name: str - sans_dns: Optional[FrozenSet[str]] = frozenset() - sans_ip: Optional[FrozenSet[str]] = frozenset() - sans_oid: Optional[FrozenSet[str]] = frozenset() - email_address: Optional[str] = None - organization: Optional[str] = None - organizational_unit: Optional[str] = None - country_name: Optional[str] = None - state_or_province_name: Optional[str] = None - locality_name: Optional[str] = None - is_ca: bool = False - add_unique_id_to_subject_name: bool = True + Args: + ca: The CA certificate. + ca_private_key: The CA private key. + validity: The validity period of the certificate. + is_ca: Whether the generated certificate is a CA certificate. - def is_valid(self) -> bool: - """Check whether the certificate request is valid.""" - if not self.common_name: - return False - return True + Returns: + Certificate: The signed certificate. + """ + return Certificate.generate( + csr=self, + ca=ca, + ca_private_key=ca_private_key, + validity=validity, + is_ca=is_ca, + ) - def generate_csr( - self, + @classmethod + def generate( + cls, + attributes: "CertificateRequestAttributes", private_key: PrivateKey, - ) -> CertificateSigningRequest: - """Generate a CSR using private key and subject. + ) -> "CertificateSigningRequest": + """Generate a CSR using the supplied attributes and private key. Args: + attributes (CertificateRequestAttributes): Certificate request attributes private_key (PrivateKey): Private key - Returns: CertificateSigningRequest: CSR """ - return generate_csr( - private_key=private_key, - common_name=self.common_name, - sans_dns=self.sans_dns, - sans_ip=self.sans_ip, - sans_oid=self.sans_oid, - email_address=self.email_address, - organization=self.organization, - organizational_unit=self.organizational_unit, - country_name=self.country_name, - state_or_province_name=self.state_or_province_name, - locality_name=self.locality_name, - add_unique_id_to_subject_name=self.add_unique_id_to_subject_name, - ) + signing_key = private_key._private_key + assert isinstance(signing_key, CertificateIssuerPrivateKeyTypes) + + csr_builder = x509.CertificateSigningRequestBuilder() + if subject_name := _extract_subject_name_attributes(attributes): + csr_builder = csr_builder.subject_name(subject_name) + + _sans: List[x509.GeneralName] = [] + if attributes.sans_oid: + _sans.extend( + [x509.RegisteredID(x509.ObjectIdentifier(san)) for san in attributes.sans_oid] + ) + if attributes.sans_ip: + _sans.extend([x509.IPAddress(ipaddress.ip_address(san)) for san in attributes.sans_ip]) + if attributes.sans_dns: + _sans.extend([x509.DNSName(san) for san in attributes.sans_dns]) + if _sans: + csr_builder = csr_builder.add_extension( + x509.SubjectAlternativeName(set(_sans)), critical=False + ) + if attributes.additional_critical_extensions: + for extension in attributes.additional_critical_extensions: + csr_builder = csr_builder.add_extension(extension, critical=True) + signed_certificate_request = csr_builder.sign(signing_key, hashes.SHA256()) + return cls(x509_object=signed_certificate_request) + + +class CertificateRequestAttributes: + """A representation of the certificate request attributes.""" + + def __init__( + self, + common_name: Optional[str] = None, + sans_dns: Optional[Collection[str]] = None, + sans_ip: Optional[Collection[str]] = None, + sans_oid: Optional[Collection[str]] = None, + email_address: Optional[str] = None, + organization: Optional[str] = None, + organizational_unit: Optional[str] = None, + country_name: Optional[str] = None, + state_or_province_name: Optional[str] = None, + locality_name: Optional[str] = None, + is_ca: bool = False, + add_unique_id_to_subject_name: bool = True, + additional_critical_extensions: Optional[Collection[x509.ExtensionType]] = None, + ): + if not common_name and not sans_dns and not sans_ip and not sans_oid: + raise ValueError( + "At least one of common_name, sans_dns, sans_ip, or sans_oid must be provided" + ) + self._common_name = common_name + self._sans_dns = set(sans_dns) if sans_dns else None + self._sans_ip = set(sans_ip) if sans_ip else None + self._sans_oid = set(sans_oid) if sans_oid else None + self._email_address = email_address + self._organization = organization + self._organizational_unit = organizational_unit + self._country_name = country_name + self._state_or_province_name = state_or_province_name + self._locality_name = locality_name + self._is_ca = is_ca + self._add_unique_id_to_subject_name = add_unique_id_to_subject_name + self._additional_critical_extensions = list(additional_critical_extensions or []) + + @property + def common_name(self) -> str: + """Return the common name.""" + # For legacy interface compatibility, return empty string if not set + return self._common_name if self._common_name else "" + + @property + def sans_dns(self) -> Optional[Set[str]]: + """Return the DNS Subject Alternative Names.""" + return self._sans_dns + + @property + def sans_ip(self) -> Optional[Set[str]]: + """Return the IP Subject Alternative Names.""" + return self._sans_ip + + @property + def sans_oid(self) -> Optional[Set[str]]: + """Return the OID Subject Alternative Names.""" + return self._sans_oid + + @property + def email_address(self) -> Optional[str]: + """Return the email address.""" + return self._email_address + + @property + def organization(self) -> Optional[str]: + """Return the organization name.""" + return self._organization + + @property + def organizational_unit(self) -> Optional[str]: + """Return the organizational unit name.""" + return self._organizational_unit + + @property + def country_name(self) -> Optional[str]: + """Return the country name.""" + return self._country_name + + @property + def state_or_province_name(self) -> Optional[str]: + """Return the state or province name.""" + return self._state_or_province_name + + @property + def locality_name(self) -> Optional[str]: + """Return the locality name.""" + return self._locality_name + + @property + def is_ca(self) -> bool: + """Return whether the certificate is a CA certificate.""" + return self._is_ca + + @property + def add_unique_id_to_subject_name(self) -> bool: + """Return whether to add a unique identifier to the subject name.""" + return self._add_unique_id_to_subject_name + + @property + def additional_critical_extensions(self) -> List[x509.ExtensionType]: + """Return additional critical extensions to be added to the CSR.""" + return self._additional_critical_extensions @classmethod - def from_csr(cls, csr: CertificateSigningRequest, is_ca: bool): - """Create a CertificateRequestAttributes object from a CSR.""" + def from_csr( + cls, csr: CertificateSigningRequest, is_ca: bool + ) -> "CertificateRequestAttributes": + """Create CertificateRequestAttributes from a CertificateSigningRequest. + + Args: + csr: The CSR to extract attributes from. + is_ca: Whether a CA certificate is being requested. + + Returns: + CertificateRequestAttributes: The extracted attributes. + """ return cls( common_name=csr.common_name, sans_dns=csr.sans_dns, @@ -682,8 +1144,56 @@ def from_csr(cls, csr: CertificateSigningRequest, is_ca: bool): locality_name=csr.locality_name, is_ca=is_ca, add_unique_id_to_subject_name=csr.has_unique_identifier, + additional_critical_extensions=csr.additional_critical_extensions, + ) + + def __eq__(self, other: object) -> bool: + """Check if two CertificateRequestAttributes objects are equal.""" + if not isinstance(other, CertificateRequestAttributes): + return NotImplemented + return ( + self.common_name == other.common_name + and self.sans_dns == other.sans_dns + and self.sans_ip == other.sans_ip + and self.sans_oid == other.sans_oid + and self.email_address == other.email_address + and self.organization == other.organization + and self.organizational_unit == other.organizational_unit + and self.country_name == other.country_name + and self.state_or_province_name == other.state_or_province_name + and self.locality_name == other.locality_name + and self.is_ca == other.is_ca + and self.add_unique_id_to_subject_name == other.add_unique_id_to_subject_name + and self.additional_critical_extensions == other.additional_critical_extensions ) + def is_valid(self) -> bool: + """Validate the attributes of the certificate request. + + Returns: + bool: True if the attributes are valid, False otherwise. + """ + if not self.common_name and not self.sans_dns and not self.sans_ip and not self.sans_oid: + logger.warning( + "At least one of common_name, sans_dns, sans_ip, or sans_oid must be provided" + ) + return False + return True + + def generate_csr( + self, + private_key: PrivateKey, + ) -> CertificateSigningRequest: + """Generate a CSR using the current attributes and a private key. + + Args: + private_key (PrivateKey): Private key to sign the CSR. + + Returns: + CertificateSigningRequest: The generated CSR. + """ + return CertificateSigningRequest.generate(self, private_key) + @dataclass(frozen=True) class ProviderCertificate: @@ -776,25 +1286,11 @@ def generate_private_key( Returns: PrivateKey: Private Key """ - if key_size < 2048: - raise ValueError("Key size must be at least 2048 bits for RSA security") - private_key = rsa.generate_private_key( - public_exponent=public_exponent, - key_size=key_size, + warnings.warn( + "generate_private_key() is deprecated. Use PrivateKey.generate() instead.", + DeprecationWarning, ) - key_bytes = private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption(), - ) - key = PrivateKey.from_string(key_bytes.decode()) - _OWASPLogger().log_event( - event="private_key_generated", - level=logging.INFO, - description="Private key generated", - key_size=str(key_size), - ) - return key + return PrivateKey.generate(key_size=key_size, public_exponent=public_exponent) def calculate_relative_datetime(target_time: datetime, fraction: float) -> datetime: @@ -875,43 +1371,23 @@ def generate_csr( # noqa: C901 Returns: CertificateSigningRequest: CSR """ - signing_key = serialization.load_pem_private_key(str(private_key).encode(), password=None) - subject_name = [x509.NameAttribute(x509.NameOID.COMMON_NAME, common_name)] - if add_unique_id_to_subject_name: - unique_identifier = uuid.uuid4() - subject_name.append( - x509.NameAttribute(x509.NameOID.X500_UNIQUE_IDENTIFIER, str(unique_identifier)) - ) - if organization: - subject_name.append(x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, organization)) - if organizational_unit: - subject_name.append( - x509.NameAttribute(x509.NameOID.ORGANIZATIONAL_UNIT_NAME, organizational_unit) - ) - if email_address: - subject_name.append(x509.NameAttribute(x509.NameOID.EMAIL_ADDRESS, email_address)) - if country_name: - subject_name.append(x509.NameAttribute(x509.NameOID.COUNTRY_NAME, country_name)) - if state_or_province_name: - subject_name.append( - x509.NameAttribute(x509.NameOID.STATE_OR_PROVINCE_NAME, state_or_province_name) - ) - if locality_name: - subject_name.append(x509.NameAttribute(x509.NameOID.LOCALITY_NAME, locality_name)) - csr = x509.CertificateSigningRequestBuilder(subject_name=x509.Name(subject_name)) - - _sans: List[x509.GeneralName] = [] - if sans_oid: - _sans.extend([x509.RegisteredID(x509.ObjectIdentifier(san)) for san in sans_oid]) - if sans_ip: - _sans.extend([x509.IPAddress(ipaddress.ip_address(san)) for san in sans_ip]) - if sans_dns: - _sans.extend([x509.DNSName(san) for san in sans_dns]) - if _sans: - csr = csr.add_extension(x509.SubjectAlternativeName(set(_sans)), critical=False) - signed_certificate = csr.sign(signing_key, hashes.SHA256()) # type: ignore[arg-type] - csr_str = signed_certificate.public_bytes(serialization.Encoding.PEM).decode() - return CertificateSigningRequest.from_string(csr_str) + warnings.warn( + "generate_csr() is deprecated. Use CertificateRequestAttributes.generate_csr() or CertificateSigningRequest.generate() instead.", + DeprecationWarning, + ) + return CertificateRequestAttributes( + common_name=common_name, + sans_dns=sans_dns, + sans_ip=sans_ip, + sans_oid=sans_oid, + organization=organization, + organizational_unit=organizational_unit, + email_address=email_address, + country_name=country_name, + state_or_province_name=state_or_province_name, + locality_name=locality_name, + add_unique_id_to_subject_name=add_unique_id_to_subject_name, + ).generate_csr(private_key=private_key) def generate_ca( @@ -931,108 +1407,47 @@ def generate_ca( """Generate a self signed CA Certificate. Args: - private_key (PrivateKey): Private key - validity (timedelta): Certificate validity time - common_name (str): Common Name that can be an IP or a Full Qualified Domain Name (FQDN). - sans_dns (FrozenSet[str]): DNS Subject Alternative Names - sans_ip (FrozenSet[str]): IP Subject Alternative Names - sans_oid (FrozenSet[str]): OID Subject Alternative Names - organization (Optional[str]): Organization name - organizational_unit (Optional[str]): Organizational unit name - email_address (Optional[str]): Email address - country_name (str): Certificate Issuing country - state_or_province_name (str): Certificate Issuing state or province - locality_name (str): Certificate Issuing locality + private_key: Private key + validity: Certificate validity time + common_name: Common Name that can be an IP or a Full Qualified Domain Name (FQDN). + sans_dns: DNS Subject Alternative Names + sans_ip: IP Subject Alternative Names + sans_oid: OID Subject Alternative Names + organization: Organization name + organizational_unit: Organizational unit name + email_address: Email address + country_name: Certificate Issuing country + state_or_province_name: Certificate Issuing state or province + locality_name: Certificate Issuing locality Returns: - Certificate: CA Certificate. + CA Certificate. """ - private_key_object = serialization.load_pem_private_key( - str(private_key).encode(), password=None - ) - assert isinstance(private_key_object, rsa.RSAPrivateKey) - subject_name = [x509.NameAttribute(x509.NameOID.COMMON_NAME, common_name)] - if organization: - subject_name.append(x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, organization)) - if organizational_unit: - subject_name.append( - x509.NameAttribute(x509.NameOID.ORGANIZATIONAL_UNIT_NAME, organizational_unit) - ) - if email_address: - subject_name.append(x509.NameAttribute(x509.NameOID.EMAIL_ADDRESS, email_address)) - if country_name: - subject_name.append(x509.NameAttribute(x509.NameOID.COUNTRY_NAME, country_name)) - if state_or_province_name: - subject_name.append( - x509.NameAttribute(x509.NameOID.STATE_OR_PROVINCE_NAME, state_or_province_name) - ) - if locality_name: - subject_name.append(x509.NameAttribute(x509.NameOID.LOCALITY_NAME, locality_name)) - - subject_identifier_object = x509.SubjectKeyIdentifier.from_public_key( - private_key_object.public_key() - ) - subject_identifier = key_identifier = subject_identifier_object.public_bytes() - key_usage = x509.KeyUsage( - digital_signature=True, - key_encipherment=True, - key_cert_sign=True, - key_agreement=False, - content_commitment=False, - data_encipherment=False, - crl_sign=False, - encipher_only=False, - decipher_only=False, + warnings.warn( + "generate_ca() is deprecated. Use Certificate.generate_self_signed_ca() instead.", + DeprecationWarning, ) - - builder = ( - x509.CertificateBuilder() - .subject_name(x509.Name(subject_name)) - .issuer_name(x509.Name(subject_name)) - .public_key(private_key_object.public_key()) - .serial_number(x509.random_serial_number()) - .not_valid_before(datetime.now(timezone.utc)) - .not_valid_after(datetime.now(timezone.utc) + validity) - .add_extension(x509.SubjectKeyIdentifier(digest=subject_identifier), critical=False) - .add_extension( - x509.AuthorityKeyIdentifier( - key_identifier=key_identifier, - authority_cert_issuer=None, - authority_cert_serial_number=None, - ), - critical=False, - ) - .add_extension(key_usage, critical=True) - .add_extension( - x509.BasicConstraints(ca=True, path_length=None), - critical=True, - ) - ) - san_extension = _san_extension( - email_address=email_address, + attributes = CertificateRequestAttributes( + common_name=common_name, sans_dns=sans_dns, sans_ip=sans_ip, sans_oid=sans_oid, + organization=organization, + organizational_unit=organizational_unit, + email_address=email_address, + country_name=country_name, + state_or_province_name=state_or_province_name, + locality_name=locality_name, + is_ca=True, ) - if san_extension: - builder = builder.add_extension(san_extension, critical=False) - cert = builder.sign(private_key_object, hashes.SHA256()) # type: ignore[arg-type] - ca_cert_str = cert.public_bytes(serialization.Encoding.PEM).decode().strip() - _OWASPLogger().log_event( - event="ca_certificate_generated", - level=logging.INFO, - description="CA certificate generated", - common_name=common_name, - validity_days=str(validity.days), - ) - return Certificate.from_string(ca_cert_str) + return Certificate.generate_self_signed_ca(attributes, private_key, validity) def _san_extension( email_address: Optional[str] = None, - sans_dns: Optional[FrozenSet[str]] = frozenset(), - sans_ip: Optional[FrozenSet[str]] = frozenset(), - sans_oid: Optional[FrozenSet[str]] = frozenset(), + sans_dns: Optional[Collection[str]] = frozenset(), + sans_ip: Optional[Collection[str]] = frozenset(), + sans_oid: Optional[Collection[str]] = frozenset(), ) -> Optional[x509.SubjectAlternativeName]: sans: List[x509.GeneralName] = [] if email_address: @@ -1068,48 +1483,67 @@ def generate_certificate( Returns: Certificate: Certificate """ - csr_object = x509.load_pem_x509_csr(str(csr).encode()) - subject = csr_object.subject - ca_pem = x509.load_pem_x509_certificate(str(ca).encode()) - issuer = ca_pem.issuer - private_key = serialization.load_pem_private_key(str(ca_private_key).encode(), password=None) - - certificate_builder = ( - x509.CertificateBuilder() - .subject_name(subject) - .issuer_name(issuer) - .public_key(csr_object.public_key()) - .serial_number(x509.random_serial_number()) - .not_valid_before(datetime.now(timezone.utc)) - .not_valid_after(datetime.now(timezone.utc) + validity) + warnings.warn( + "generate_certificate() is deprecated. Use Certificate.generate() instead.", + DeprecationWarning, ) - extensions = _generate_certificate_request_extensions( - authority_key_identifier=ca_pem.extensions.get_extension_for_class( - x509.SubjectKeyIdentifier - ).value.key_identifier, - csr=csr_object, + return Certificate.generate( + csr=csr, + ca=ca, + ca_private_key=ca_private_key, + validity=validity, is_ca=is_ca, ) - for extension in extensions: - try: - certificate_builder = certificate_builder.add_extension( - extval=extension.value, - critical=extension.critical, + + +def _extract_subject_name_attributes( + attributes: CertificateRequestAttributes, +) -> Optional[x509.Name]: + subject_name_attributes = [] + if attributes.common_name: + subject_name_attributes.append( + x509.NameAttribute(x509.NameOID.COMMON_NAME, attributes.common_name) + ) + if attributes.add_unique_id_to_subject_name: + unique_identifier = uuid.uuid4() + subject_name_attributes.append( + x509.NameAttribute(x509.NameOID.X500_UNIQUE_IDENTIFIER, str(unique_identifier)) + ) + if attributes.organization: + subject_name_attributes.append( + x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, attributes.organization) + ) + if attributes.organizational_unit: + subject_name_attributes.append( + x509.NameAttribute( + x509.NameOID.ORGANIZATIONAL_UNIT_NAME, + attributes.organizational_unit, ) - except ValueError as e: - logger.warning("Failed to add extension %s: %s", extension.oid, e) - - cert = certificate_builder.sign(private_key, hashes.SHA256()) # type: ignore[arg-type] - cert_bytes = cert.public_bytes(serialization.Encoding.PEM) - _OWASPLogger().log_event( - event="certificate_generated", - level=logging.INFO, - description="Certificate generated from CSR", - common_name=csr.common_name, - is_ca=str(is_ca), - validity_days=str(validity.days), - ) - return Certificate.from_string(cert_bytes.decode().strip()) + ) + if attributes.email_address: + subject_name_attributes.append( + x509.NameAttribute(x509.NameOID.EMAIL_ADDRESS, attributes.email_address) + ) + if attributes.country_name: + subject_name_attributes.append( + x509.NameAttribute(x509.NameOID.COUNTRY_NAME, attributes.country_name) + ) + if attributes.state_or_province_name: + subject_name_attributes.append( + x509.NameAttribute( + x509.NameOID.STATE_OR_PROVINCE_NAME, + attributes.state_or_province_name, + ) + ) + if attributes.locality_name: + subject_name_attributes.append( + x509.NameAttribute(x509.NameOID.LOCALITY_NAME, attributes.locality_name) + ) + + if subject_name_attributes: + return x509.Name(subject_name_attributes) + + return None def _generate_certificate_request_extensions(