Skip to content

Commit

Permalink
TLSA tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
nils-wisiol committed Sep 19, 2021
1 parent 2f9e910 commit bf57f20
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 19 deletions.
34 changes: 20 additions & 14 deletions api/desecapi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,10 @@ def filter_qname(self, qname: str, **kwargs) -> models.query.QuerySet:
).filter(dotted_qname__endswith=F('dotted_name'), **kwargs)

def most_specific_zone(self, fqdn: str) -> Tuple[Domain, str]:
domain = self.filter_qname(fqdn).order_by('-name_length').first()
try:
domain = self.filter_qname(fqdn).order_by('-name_length')[0]
except IndexError:
raise Domain.DoesNotExist
subname = fqdn[:-len(domain.name)].rstrip('.')
return domain, subname

Expand Down Expand Up @@ -999,7 +1002,7 @@ class Identity(models.Model):
created = models.DateTimeField(auto_now_add=True)
owner = models.ForeignKey(User, on_delete=models.PROTECT, related_name='identities')
default_ttl = models.PositiveIntegerField(default=300)
rrs = models.ManyToManyField(to=RR)
rrs = models.ManyToManyField(to=RR, related_name='identities')
scheduled_removal = models.DateTimeField(null=True)

class Meta:
Expand All @@ -1010,14 +1013,14 @@ def get_rrs(self) -> List[RR]:

def save(self, *args, **kwargs):
for rr in self.get_rrs():
self.rrs.add(rr)
rr.rrset.save()
rr.save()
self.rrs.add(rr)
return super().save(*args, **kwargs)

def delete(self, using=None, keep_parents=False):
for rr in self.rrs.all(): # TODO use one query
if len(rr.identities) == 1:
if len(rr.identities.all()) == 1:
rr.delete()
return super().delete(using, keep_parents)

Expand Down Expand Up @@ -1076,7 +1079,7 @@ def __init__(self, *args, **kwargs):
if 'not_valid_after' not in kwargs:
self.scheduled_removal = self.not_valid_after

def get_record_contents(self) -> List[str]:
def get_record_content(self) -> str:
# choose hash function
if self.tlsa_matching_type == self.MatchingType.SHA256:
hash_function = hazmat.primitives.hashes.SHA256()
Expand All @@ -1100,7 +1103,7 @@ def get_record_contents(self) -> List[str]:
hash = h.finalize().hex()

# create TLSA record content
return [f"{self.tlsa_certificate_usage} {self.tlsa_selector} {self.tlsa_matching_type} {hash}"]
return f"{self.tlsa_certificate_usage} {self.tlsa_selector} {self.tlsa_matching_type} {hash}"

@property
def _cert(self) -> x509.Certificate:
Expand Down Expand Up @@ -1145,14 +1148,17 @@ def subject_names_clean(self) -> Set[str]:
return clean

def get_rrs(self) -> List[RR]:
return [
self.get_or_create_rr(
fqdn=f"_{self.port:n}._{self.protocol}.{qname}",
content=content,
)
for qname in self.subject_names_clean
for content in self.get_record_contents()
]
rrs = []
content = self.get_record_content()
for qname in self.subject_names_clean:
try:
rrs.append(self.get_or_create_rr(
fqdn=f"_{self.port:n}._{self.protocol}.{qname}",
content=content,
))
except Domain.DoesNotExist:
pass
return rrs

@property
def not_valid_before(self):
Expand Down
40 changes: 35 additions & 5 deletions api/desecapi/tests/test_identities.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,7 @@ def test_generated_rrs_many_rrsets(self):

id = models.TLSIdentity(certificate=CERTIFICATE, owner=self.user, protocol=models.TLSIdentity.Protocol.SCTP)

self.assertEqual(
id.domains_subnames(),
{(domain, '_443._sctp'), (domain, '_443._sctp.desec'), (domain, '_443._sctp.dedyn')},
)
self.assertEqual(id.subject_names, SUBJECT_NAMES)

rrs = id.get_rrs()
self.assertEqual(len(rrs), 3)
Expand All @@ -69,7 +66,6 @@ def test_generated_rrs_one_rrset(self):
domain.save()

id = models.TLSIdentity(certificate=CERTIFICATE, owner=self.user, port=123)
self.assertEqual(id.domains_subnames(), {(domain, '_123._tcp')})

rrs = id.get_rrs()
self.assertEqual(len(rrs), 1)
Expand Down Expand Up @@ -115,3 +111,37 @@ def test_create_delete_rrs(self):
id.delete()
rrset = models.RRset.objects.get(domain__name='desec.example.dedyn.io', type='TLSA', subname='_123._tcp')
self.assertEqual(len(rrset.records.all()), 1)

def test_duplicate_record(self):
def count_tlsa_records():
return models.RRset.objects.get(
domain__name='desec.example.dedyn.io',
type='TLSA', subname='_443._tcp'
).records.count()

domain = models.Domain(name='desec.example.dedyn.io', owner=self.user)
domain.save()

# insert first cert, insert second, delete first, delete second
id1 = models.TLSIdentity(certificate=CERTIFICATE, owner=self.user)
id2 = models.TLSIdentity(certificate=CERTIFICATE, owner=self.user)
id1.save()
self.assertEqual(count_tlsa_records(), 1)
id2.save()
self.assertEqual(count_tlsa_records(), 1)
id1.delete()
self.assertEqual(count_tlsa_records(), 1)
id2.delete()
self.assertEqual(count_tlsa_records(), 0)

# insert first cert, insert second, delete second, delete first
id1 = models.TLSIdentity(certificate=CERTIFICATE, owner=self.user)
id2 = models.TLSIdentity(certificate=CERTIFICATE, owner=self.user)
id1.save()
self.assertEqual(count_tlsa_records(), 1)
id2.save()
self.assertEqual(count_tlsa_records(), 1)
id2.delete()
self.assertEqual(count_tlsa_records(), 1)
id1.delete()
self.assertEqual(count_tlsa_records(), 0)

0 comments on commit bf57f20

Please sign in to comment.