From bf57f20c40c4612594d2eb6d201ab806cafe3c4d Mon Sep 17 00:00:00 2001 From: Nils Wisiol Date: Sun, 19 Sep 2021 22:44:25 +0200 Subject: [PATCH] TLSA tests pass --- api/desecapi/models.py | 34 +++++++++++++---------- api/desecapi/tests/test_identities.py | 40 +++++++++++++++++++++++---- 2 files changed, 55 insertions(+), 19 deletions(-) diff --git a/api/desecapi/models.py b/api/desecapi/models.py index 70c1cf0e3..29168fe14 100644 --- a/api/desecapi/models.py +++ b/api/desecapi/models.py @@ -220,7 +220,10 @@ def filter_qname(self, qname: str, **kwargs) -> models.query.QuerySet: ).filter(dotted_qname__endswith=F('dotted_name'), **kwargs) def most_specific_zone(self, fqdn: str) -> Tuple[Domain, str]: - domain = self.filter_qname(fqdn).order_by('-name_length').first() + try: + domain = self.filter_qname(fqdn).order_by('-name_length')[0] + except IndexError: + raise Domain.DoesNotExist subname = fqdn[:-len(domain.name)].rstrip('.') return domain, subname @@ -999,7 +1002,7 @@ class Identity(models.Model): created = models.DateTimeField(auto_now_add=True) owner = models.ForeignKey(User, on_delete=models.PROTECT, related_name='identities') default_ttl = models.PositiveIntegerField(default=300) - rrs = models.ManyToManyField(to=RR) + rrs = models.ManyToManyField(to=RR, related_name='identities') scheduled_removal = models.DateTimeField(null=True) class Meta: @@ -1010,14 +1013,14 @@ def get_rrs(self) -> List[RR]: def save(self, *args, **kwargs): for rr in self.get_rrs(): - self.rrs.add(rr) rr.rrset.save() rr.save() + self.rrs.add(rr) return super().save(*args, **kwargs) def delete(self, using=None, keep_parents=False): for rr in self.rrs.all(): # TODO use one query - if len(rr.identities) == 1: + if len(rr.identities.all()) == 1: rr.delete() return super().delete(using, keep_parents) @@ -1076,7 +1079,7 @@ def __init__(self, *args, **kwargs): if 'not_valid_after' not in kwargs: self.scheduled_removal = self.not_valid_after - def get_record_contents(self) -> List[str]: + def get_record_content(self) -> str: # choose hash function if self.tlsa_matching_type == self.MatchingType.SHA256: hash_function = hazmat.primitives.hashes.SHA256() @@ -1100,7 +1103,7 @@ def get_record_contents(self) -> List[str]: hash = h.finalize().hex() # create TLSA record content - return [f"{self.tlsa_certificate_usage} {self.tlsa_selector} {self.tlsa_matching_type} {hash}"] + return f"{self.tlsa_certificate_usage} {self.tlsa_selector} {self.tlsa_matching_type} {hash}" @property def _cert(self) -> x509.Certificate: @@ -1145,14 +1148,17 @@ def subject_names_clean(self) -> Set[str]: return clean def get_rrs(self) -> List[RR]: - return [ - self.get_or_create_rr( - fqdn=f"_{self.port:n}._{self.protocol}.{qname}", - content=content, - ) - for qname in self.subject_names_clean - for content in self.get_record_contents() - ] + rrs = [] + content = self.get_record_content() + for qname in self.subject_names_clean: + try: + rrs.append(self.get_or_create_rr( + fqdn=f"_{self.port:n}._{self.protocol}.{qname}", + content=content, + )) + except Domain.DoesNotExist: + pass + return rrs @property def not_valid_before(self): diff --git a/api/desecapi/tests/test_identities.py b/api/desecapi/tests/test_identities.py index 2d00cbccd..3ce7bec1f 100644 --- a/api/desecapi/tests/test_identities.py +++ b/api/desecapi/tests/test_identities.py @@ -47,10 +47,7 @@ def test_generated_rrs_many_rrsets(self): id = models.TLSIdentity(certificate=CERTIFICATE, owner=self.user, protocol=models.TLSIdentity.Protocol.SCTP) - self.assertEqual( - id.domains_subnames(), - {(domain, '_443._sctp'), (domain, '_443._sctp.desec'), (domain, '_443._sctp.dedyn')}, - ) + self.assertEqual(id.subject_names, SUBJECT_NAMES) rrs = id.get_rrs() self.assertEqual(len(rrs), 3) @@ -69,7 +66,6 @@ def test_generated_rrs_one_rrset(self): domain.save() id = models.TLSIdentity(certificate=CERTIFICATE, owner=self.user, port=123) - self.assertEqual(id.domains_subnames(), {(domain, '_123._tcp')}) rrs = id.get_rrs() self.assertEqual(len(rrs), 1) @@ -115,3 +111,37 @@ def test_create_delete_rrs(self): id.delete() rrset = models.RRset.objects.get(domain__name='desec.example.dedyn.io', type='TLSA', subname='_123._tcp') self.assertEqual(len(rrset.records.all()), 1) + + def test_duplicate_record(self): + def count_tlsa_records(): + return models.RRset.objects.get( + domain__name='desec.example.dedyn.io', + type='TLSA', subname='_443._tcp' + ).records.count() + + domain = models.Domain(name='desec.example.dedyn.io', owner=self.user) + domain.save() + + # insert first cert, insert second, delete first, delete second + id1 = models.TLSIdentity(certificate=CERTIFICATE, owner=self.user) + id2 = models.TLSIdentity(certificate=CERTIFICATE, owner=self.user) + id1.save() + self.assertEqual(count_tlsa_records(), 1) + id2.save() + self.assertEqual(count_tlsa_records(), 1) + id1.delete() + self.assertEqual(count_tlsa_records(), 1) + id2.delete() + self.assertEqual(count_tlsa_records(), 0) + + # insert first cert, insert second, delete second, delete first + id1 = models.TLSIdentity(certificate=CERTIFICATE, owner=self.user) + id2 = models.TLSIdentity(certificate=CERTIFICATE, owner=self.user) + id1.save() + self.assertEqual(count_tlsa_records(), 1) + id2.save() + self.assertEqual(count_tlsa_records(), 1) + id2.delete() + self.assertEqual(count_tlsa_records(), 1) + id1.delete() + self.assertEqual(count_tlsa_records(), 0)