Skip to content

Commit

Permalink
chore: adapt code calling DNS and add tests to it
Browse files Browse the repository at this point in the history
  • Loading branch information
cquintana92 committed Sep 17, 2024
1 parent 0dd202d commit 4cd3a40
Show file tree
Hide file tree
Showing 5 changed files with 332 additions and 32 deletions.
4 changes: 2 additions & 2 deletions app/custom_domain_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(
self.dkim_domain = dkim_domain
self._dns_client = dns_client
self._dkim_records = {
(f"{key}._domainkey", f"{key}._domainkey.{self.dkim_domain}")
f"{key}._domainkey": f"{key}._domainkey.{self.dkim_domain}"
for key in ("dkim", "dkim02", "dkim03")
}

Expand All @@ -40,7 +40,7 @@ def validate_dkim_records(self, custom_domain: CustomDomain) -> dict[str, str]:
Returns empty list if all records are ok. Other-wise return the records that aren't properly configured
"""
invalid_records = {}
for prefix, expected_record in self.get_dkim_records():
for prefix, expected_record in self.get_dkim_records().items():
custom_record = f"{prefix}.{custom_domain.domain}"
dkim_record = self._dns_client.get_cname_record(custom_record)
if dkim_record != expected_record:
Expand Down
46 changes: 17 additions & 29 deletions app/dns_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,24 @@ def get_cname_record(self, hostname: str) -> Optional[str]:
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
pass

@abstractmethod
def get_spf_domain(self, hostname: str) -> List[str]:
pass
"""
return all domains listed in *include:*
"""
try:
records = self.get_txt_record(hostname)
ret = []
for record in records:
if record.startswith("v=spf1"):
parts = record.split(" ")
for part in parts:
if part.startswith(_include_spf):
ret.append(
part[part.find(_include_spf) + len(_include_spf) :]
)
return ret
except Exception:
return []

@abstractmethod
def get_txt_record(self, hostname: str) -> List[str]:
Expand Down Expand Up @@ -82,27 +97,6 @@ def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
except Exception:
return []

def get_spf_domain(self, hostname: str) -> List[str]:
"""
return all domains listed in *include:*
"""
try:
answers = self._resolver.resolve(hostname, "TXT", search=True)
ret = []
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
for record in a.strings:
record_str = record.decode() # record is bytes
if record_str.startswith("v=spf1"):
parts = record_str.split(" ")
for part in parts:
if part.startswith(_include_spf):
ret.append(
part[part.find(_include_spf) + len(_include_spf) :]
)
return ret
except Exception:
return []

def get_txt_record(self, hostname: str) -> List[str]:
try:
answers = self._resolver.resolve(hostname, "TXT", search=True)
Expand All @@ -128,9 +122,6 @@ def set_cname_record(self, hostname: str, cname: str):
def set_mx_records(self, hostname: str, mx_list: List[Tuple[int, str]]):
self.mx_records[hostname] = mx_list

def set_spf_domain(self, hostname: str, spf_list: List[str]):
self.spf_records[hostname] = spf_list

def set_txt_record(self, hostname: str, txt_list: List[str]):
self.txt_records[hostname] = txt_list

Expand All @@ -141,9 +132,6 @@ def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
mx_list = self.mx_records.get(hostname, [])
return sorted(mx_list, key=lambda x: x[0])

def get_spf_domain(self, hostname: str) -> List[str]:
return self.spf_records.get(hostname, [])

def get_txt_record(self, hostname: str) -> List[str]:
return self.txt_records.get(hostname, [])

Expand Down
2 changes: 1 addition & 1 deletion templates/dashboard/domain_detail/dns.html
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ <h1 class="h2">{{ custom_domain.domain }}</h1>
folder.
</div>
<div class="mb-2">Add the following CNAME DNS records to your domain.</div>
{% for dkim_prefix, dkim_cname_value in dkim_records %}
{% for dkim_prefix, dkim_cname_value in dkim_records.items() %}

<div class="mb-2 p-3 dns-record">
Record: CNAME
Expand Down
297 changes: 297 additions & 0 deletions tests/test_custom_domain_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,297 @@
from typing import Optional

from app import config
from app.constants import DMARC_RECORD
from app.custom_domain_validation import CustomDomainValidation
from app.db import Session
from app.models import CustomDomain, User
from app.dns_utils import InMemoryDNSClient
from app.utils import random_string
from tests.utils import create_new_user, random_domain

user: Optional[User] = None


def setup_module():
global user
config.SKIP_MX_LOOKUP_ON_CHECK = True
user = create_new_user()
user.trial_end = None
user.lifetime = True
Session.commit()


def create_custom_domain(domain: str) -> CustomDomain:
return CustomDomain.create(user_id=user.id, domain=domain, commit=True)


def test_custom_domain_validation_get_dkim_records():
domain = random_domain()
validator = CustomDomainValidation(domain)
records = validator.get_dkim_records()

assert len(records) == 3
assert records["dkim02._domainkey"] == f"dkim02._domainkey.{domain}"
assert records["dkim03._domainkey"] == f"dkim03._domainkey.{domain}"
assert records["dkim._domainkey"] == f"dkim._domainkey.{domain}"


# validate_dkim_records
def test_custom_domain_validation_validate_dkim_records_empty_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)

domain = create_custom_domain(random_domain())
res = validator.validate_dkim_records(domain)

assert len(res) == 3
for record_value in res.values():
assert record_value == "empty"

db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dkim_verified is False


def test_custom_domain_validation_validate_dkim_records_wrong_records_failure():
dkim_domain = random_domain()
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(dkim_domain, dns_client)

user_domain = random_domain()

# One domain right, two domains wrong
dns_client.set_cname_record(
f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}"
)
dns_client.set_cname_record(f"dkim02._domainkey.{user_domain}", "wrong")
dns_client.set_cname_record(f"dkim03._domainkey.{user_domain}", "wrong")

domain = create_custom_domain(user_domain)
res = validator.validate_dkim_records(domain)

assert len(res) == 2
for record_value in res.values():
assert record_value == "wrong"

db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dkim_verified is False


def test_custom_domain_validation_validate_dkim_records_success():
dkim_domain = random_domain()
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(dkim_domain, dns_client)

user_domain = random_domain()

# One domain right, two domains wrong
dns_client.set_cname_record(
f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}"
)
dns_client.set_cname_record(
f"dkim02._domainkey.{user_domain}", f"dkim02._domainkey.{dkim_domain}"
)
dns_client.set_cname_record(
f"dkim03._domainkey.{user_domain}", f"dkim03._domainkey.{dkim_domain}"
)

domain = create_custom_domain(user_domain)
res = validator.validate_dkim_records(domain)
assert len(res) == 0

db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dkim_verified is True


# validate_ownership
def test_custom_domain_validation_validate_ownership_empty_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)

domain = create_custom_domain(random_domain())
res = validator.validate_domain_ownership(domain)

assert res.success is False
assert len(res.errors) == 0

db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.ownership_verified is False


def test_custom_domain_validation_validate_ownership_wrong_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)

domain = create_custom_domain(random_domain())

wrong_records = [random_string()]
dns_client.set_txt_record(domain.domain, wrong_records)
res = validator.validate_domain_ownership(domain)

assert res.success is False
assert res.errors == wrong_records

db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.ownership_verified is False


def test_custom_domain_validation_validate_ownership_success():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)

domain = create_custom_domain(random_domain())

dns_client.set_txt_record(domain.domain, [domain.get_ownership_dns_txt_value()])
res = validator.validate_domain_ownership(domain)

assert res.success is True
assert len(res.errors) == 0

db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.ownership_verified is True


# validate_mx_records
def test_custom_domain_validation_validate_mx_records_empty_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)

domain = create_custom_domain(random_domain())
res = validator.validate_mx_records(domain)

assert res.success is False
assert len(res.errors) == 0

db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.verified is False


def test_custom_domain_validation_validate_mx_records_wrong_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)

domain = create_custom_domain(random_domain())

wrong_record_1 = random_string()
wrong_record_2 = random_string()
wrong_records = [(10, wrong_record_1), (20, wrong_record_2)]
dns_client.set_mx_records(domain.domain, wrong_records)
res = validator.validate_mx_records(domain)

assert res.success is False
assert res.errors == [f"10 {wrong_record_1}", f"20 {wrong_record_2}"]

db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.verified is False


def test_custom_domain_validation_validate_mx_records_success():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)

domain = create_custom_domain(random_domain())

dns_client.set_mx_records(domain.domain, config.EMAIL_SERVERS_WITH_PRIORITY)
res = validator.validate_mx_records(domain)

assert res.success is True
assert len(res.errors) == 0

db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.verified is True


# validate_spf_records
def test_custom_domain_validation_validate_spf_records_empty_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)

domain = create_custom_domain(random_domain())
res = validator.validate_spf_records(domain)

assert res.success is False
assert len(res.errors) == 0

db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.spf_verified is False


def test_custom_domain_validation_validate_spf_records_wrong_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)

domain = create_custom_domain(random_domain())

wrong_records = [random_string()]
dns_client.set_txt_record(domain.domain, wrong_records)
res = validator.validate_spf_records(domain)

assert res.success is False
assert res.errors == wrong_records

db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.spf_verified is False


def test_custom_domain_validation_validate_spf_records_success():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)

domain = create_custom_domain(random_domain())

dns_client.set_txt_record(domain.domain, [f"v=spf1 include:{config.EMAIL_DOMAIN}"])
res = validator.validate_spf_records(domain)

assert res.success is True
assert len(res.errors) == 0

db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.spf_verified is True


# validate_dmarc_records
def test_custom_domain_validation_validate_dmarc_records_empty_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)

domain = create_custom_domain(random_domain())
res = validator.validate_dmarc_records(domain)

assert res.success is False
assert len(res.errors) == 0

db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dmarc_verified is False


def test_custom_domain_validation_validate_dmarc_records_wrong_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)

domain = create_custom_domain(random_domain())

wrong_records = [random_string()]
dns_client.set_txt_record(f"_dmarc.{domain.domain}", wrong_records)
res = validator.validate_dmarc_records(domain)

assert res.success is False
assert res.errors == wrong_records

db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dmarc_verified is False


def test_custom_domain_validation_validate_dmarc_records_success():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)

domain = create_custom_domain(random_domain())

dns_client.set_txt_record(f"_dmarc.{domain.domain}", [DMARC_RECORD])
res = validator.validate_dmarc_records(domain)

assert res.success is True
assert len(res.errors) == 0

db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dmarc_verified is True
Loading

0 comments on commit 4cd3a40

Please sign in to comment.