Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: refactor dns to improve testability #2224

Merged
merged 3 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 34 additions & 16 deletions app/custom_domain_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
from app.constants import DMARC_RECORD
from app.db import Session
from app.dns_utils import (
get_cname_record,
get_mx_domains,
get_txt_record,
DNSClient,
is_mx_equivalent,
get_spf_domain,
get_network_dns_client,
)
from app.models import CustomDomain
from dataclasses import dataclass
Expand All @@ -19,10 +17,13 @@ class DomainValidationResult:


class CustomDomainValidation:
def __init__(self, dkim_domain: str):
def __init__(
self, dkim_domain: str, dns_client: DNSClient = get_network_dns_client()
):
self.dkim_domain = dkim_domain
self._dns_client = dns_client
self._dkim_records = {
(f"{key}._domainkey", f"{key}._domainkey.{self.dkim_domain}")
f"{key}._domainkey": f"{key}._domainkey.{self.dkim_domain}"
for key in ("dkim", "dkim02", "dkim03")
}

Expand All @@ -38,15 +39,31 @@ def validate_dkim_records(self, custom_domain: CustomDomain) -> dict[str, str]:
Check if dkim records are properly set for this custom domain.
Returns empty list if all records are ok. Other-wise return the records that aren't properly configured
"""
correct_records = {}
invalid_records = {}
for prefix, expected_record in self.get_dkim_records():
expected_records = self.get_dkim_records()
for prefix, expected_record in expected_records.items():
custom_record = f"{prefix}.{custom_domain.domain}"
dkim_record = get_cname_record(custom_record)
if dkim_record != expected_record:
dkim_record = self._dns_client.get_cname_record(custom_record)
if dkim_record == expected_record:
correct_records[prefix] = custom_record
else:
invalid_records[custom_record] = dkim_record or "empty"
# HACK: If dkim is enabled, don't disable it to give users time to update their CNAMES

# HACK
# As initially we only had one dkim record, we want to allow users that had only the original dkim record and
# the domain validated to continue seeing it as validated (although showing them the missing records).
# However, if not even the original dkim record is right, even if the domain was dkim_verified in the past,
# we will remove the dkim_verified flag.
# This is done in order to give users with the old dkim config (only one) to update their CNAMEs
if custom_domain.dkim_verified:
return invalid_records
# Check if at least the original dkim is there
if correct_records.get("dkim._domainkey") is not None:
# Original dkim record is there. Return the missing records (if any) and don't clear the flag
return invalid_records

# Original DKIM record is not there, which means the DKIM config is not finished. Proceed with the
# rest of the code path, returning the invalid records and clearing the flag
custom_domain.dkim_verified = len(invalid_records) == 0
Session.commit()
return invalid_records
Expand All @@ -57,7 +74,7 @@ def validate_domain_ownership(
"""
Check if the custom_domain has added the ownership verification records
"""
txt_records = get_txt_record(custom_domain.domain)
txt_records = self._dns_client.get_txt_record(custom_domain.domain)

if custom_domain.get_ownership_dns_txt_value() in txt_records:
custom_domain.ownership_verified = True
Expand All @@ -69,7 +86,7 @@ def validate_domain_ownership(
def validate_mx_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
mx_domains = get_mx_domains(custom_domain.domain)
mx_domains = self._dns_client.get_mx_domains(custom_domain.domain)

if not is_mx_equivalent(mx_domains, EMAIL_SERVERS_WITH_PRIORITY):
return DomainValidationResult(
Expand All @@ -84,7 +101,7 @@ def validate_mx_records(
def validate_spf_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
spf_domains = get_spf_domain(custom_domain.domain)
spf_domains = self._dns_client.get_spf_domain(custom_domain.domain)
if EMAIL_DOMAIN in spf_domains:
custom_domain.spf_verified = True
Session.commit()
Expand All @@ -93,13 +110,14 @@ def validate_spf_records(
custom_domain.spf_verified = False
Session.commit()
return DomainValidationResult(
success=False, errors=get_txt_record(custom_domain.domain)
success=False,
errors=self._dns_client.get_txt_record(custom_domain.domain),
)

def validate_dmarc_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
txt_records = get_txt_record("_dmarc." + custom_domain.domain)
txt_records = self._dns_client.get_txt_record("_dmarc." + custom_domain.domain)
if DMARC_RECORD in txt_records:
custom_domain.dmarc_verified = True
Session.commit()
Expand Down
210 changes: 117 additions & 93 deletions app/dns_utils.py
Original file line number Diff line number Diff line change
@@ -1,120 +1,144 @@
from app import config
from typing import Optional, List, Tuple
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional

import dns.resolver

from app.config import NAMESERVERS

def _get_dns_resolver():
my_resolver = dns.resolver.Resolver()
my_resolver.nameservers = config.NAMESERVERS

return my_resolver


def get_ns(hostname) -> [str]:
try:
answers = _get_dns_resolver().resolve(hostname, "NS", search=True)
except Exception:
return []
return [a.to_text() for a in answers]


def get_cname_record(hostname) -> Optional[str]:
"""Return the CNAME record if exists for a domain, WITHOUT the trailing period at the end"""
try:
answers = _get_dns_resolver().resolve(hostname, "CNAME", search=True)
except Exception:
return None

for a in answers:
ret = a.to_text()
return ret[:-1]

return None
_include_spf = "include:"


def get_mx_domains(hostname) -> [(int, str)]:
"""return list of (priority, domain name) sorted by priority (lowest priority first)
domain name ends with a "." at the end.
def is_mx_equivalent(
mx_domains: List[Tuple[int, str]], ref_mx_domains: List[Tuple[int, str]]
) -> bool:
"""
try:
answers = _get_dns_resolver().resolve(hostname, "MX", search=True)
except Exception:
return []

ret = []

for a in answers:
record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.'
parts = record.split(" ")

ret.append((int(parts[0]), parts[1]))

return sorted(ret, key=lambda prio_domain: prio_domain[0])

Compare mx_domains with ref_mx_domains to see if they are equivalent.
mx_domains and ref_mx_domains are list of (priority, domain)

_include_spf = "include:"
The priority order is taken into account but not the priority number.
For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)]
"""
mx_domains = sorted(mx_domains, key=lambda x: x[0])
ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x[0])

if len(mx_domains) < len(ref_mx_domains):
return False

def get_spf_domain(hostname) -> [str]:
"""return all domains listed in *include:*"""
try:
answers = _get_dns_resolver().resolve(hostname, "TXT", search=True)
except Exception:
return []
for i in range(len(ref_mx_domains)):
if mx_domains[i][1] != ref_mx_domains[i][1]:
return False

ret = []
return True

for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
for record in a.strings:
record = record.decode() # record is bytes

if record.startswith("v=spf1"):
class DNSClient(ABC):
@abstractmethod
def get_cname_record(self, hostname: str) -> Optional[str]:
pass

@abstractmethod
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
pass

def get_spf_domain(self, hostname: str) -> List[str]:
"""
return all domains listed in *include:*
"""
try:
records = self.get_txt_record(hostname)
ret = []
for record in records:
if record.startswith("v=spf1"):
parts = record.split(" ")
for part in parts:
if part.startswith(_include_spf):
ret.append(
part[part.find(_include_spf) + len(_include_spf) :]
)
return ret
except Exception:
return []

@abstractmethod
def get_txt_record(self, hostname: str) -> List[str]:
pass


class NetworkDNSClient(DNSClient):
def __init__(self, nameservers: List[str]):
self._resolver = dns.resolver.Resolver()
self._resolver.nameservers = nameservers

def get_cname_record(self, hostname: str) -> Optional[str]:
"""
Return the CNAME record if exists for a domain, WITHOUT the trailing period at the end
"""
try:
answers = self._resolver.resolve(hostname, "CNAME", search=True)
for a in answers:
ret = a.to_text()
return ret[:-1]
except Exception:
return None

def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
"""
return list of (priority, domain name) sorted by priority (lowest priority first)
domain name ends with a "." at the end.
"""
try:
answers = self._resolver.resolve(hostname, "MX", search=True)
ret = []
for a in answers:
record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.'
parts = record.split(" ")
for part in parts:
if part.startswith(_include_spf):
ret.append(part[part.find(_include_spf) + len(_include_spf) :])
ret.append((int(parts[0]), parts[1]))
return sorted(ret, key=lambda x: x[0])
except Exception:
return []

return ret
def get_txt_record(self, hostname: str) -> List[str]:
try:
answers = self._resolver.resolve(hostname, "TXT", search=True)
ret = []
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
for record in a.strings:
ret.append(record.decode())
return ret
except Exception:
return []


def get_txt_record(hostname) -> [str]:
try:
answers = _get_dns_resolver().resolve(hostname, "TXT", search=True)
except Exception:
return []
class InMemoryDNSClient(DNSClient):
def __init__(self):
self.cname_records: dict[str, Optional[str]] = {}
self.mx_records: dict[str, List[Tuple[int, str]]] = {}
self.spf_records: dict[str, List[str]] = {}
self.txt_records: dict[str, List[str]] = {}

ret = []
def set_cname_record(self, hostname: str, cname: str):
self.cname_records[hostname] = cname

for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
for record in a.strings:
record = record.decode() # record is bytes
def set_mx_records(self, hostname: str, mx_list: List[Tuple[int, str]]):
self.mx_records[hostname] = mx_list

ret.append(record)
def set_txt_record(self, hostname: str, txt_list: List[str]):
self.txt_records[hostname] = txt_list

return ret
def get_cname_record(self, hostname: str) -> Optional[str]:
return self.cname_records.get(hostname)

def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
mx_list = self.mx_records.get(hostname, [])
return sorted(mx_list, key=lambda x: x[0])

def is_mx_equivalent(
mx_domains: List[Tuple[int, str]], ref_mx_domains: List[Tuple[int, str]]
) -> bool:
"""
Compare mx_domains with ref_mx_domains to see if they are equivalent.
mx_domains and ref_mx_domains are list of (priority, domain)
def get_txt_record(self, hostname: str) -> List[str]:
return self.txt_records.get(hostname, [])

The priority order is taken into account but not the priority number.
For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)]
"""
mx_domains = sorted(mx_domains, key=lambda priority_domain: priority_domain[0])
ref_mx_domains = sorted(
ref_mx_domains, key=lambda priority_domain: priority_domain[0]
)

if len(mx_domains) < len(ref_mx_domains):
return False
def get_network_dns_client() -> NetworkDNSClient:
return NetworkDNSClient(NAMESERVERS)

for i in range(0, len(ref_mx_domains)):
if mx_domains[i][1] != ref_mx_domains[i][1]:
return False

return True
def get_mx_domains(hostname: str) -> [(int, str)]:
return get_network_dns_client().get_mx_domains(hostname)
2 changes: 1 addition & 1 deletion templates/dashboard/domain_detail/dns.html
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ <h1 class="h2">{{ custom_domain.domain }}</h1>
folder.
</div>
<div class="mb-2">Add the following CNAME DNS records to your domain.</div>
{% for dkim_prefix, dkim_cname_value in dkim_records %}
{% for dkim_prefix, dkim_cname_value in dkim_records.items() %}

<div class="mb-2 p-3 dns-record">
Record: CNAME
Expand Down
Loading
Loading