diff --git a/app/custom_domain_validation.py b/app/custom_domain_validation.py index 58b698b13..e9b26825e 100644 --- a/app/custom_domain_validation.py +++ b/app/custom_domain_validation.py @@ -23,7 +23,7 @@ def __init__( self.dkim_domain = dkim_domain self._dns_client = dns_client self._dkim_records = { - (f"{key}._domainkey", f"{key}._domainkey.{self.dkim_domain}") + f"{key}._domainkey": f"{key}._domainkey.{self.dkim_domain}" for key in ("dkim", "dkim02", "dkim03") } @@ -40,7 +40,7 @@ def validate_dkim_records(self, custom_domain: CustomDomain) -> dict[str, str]: Returns empty list if all records are ok. Other-wise return the records that aren't properly configured """ invalid_records = {} - for prefix, expected_record in self.get_dkim_records(): + for prefix, expected_record in self.get_dkim_records().items(): custom_record = f"{prefix}.{custom_domain.domain}" dkim_record = self._dns_client.get_cname_record(custom_record) if dkim_record != expected_record: diff --git a/app/dns_utils.py b/app/dns_utils.py index d55fd6f4f..2ce699340 100644 --- a/app/dns_utils.py +++ b/app/dns_utils.py @@ -40,9 +40,24 @@ def get_cname_record(self, hostname: str) -> Optional[str]: def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]: pass - @abstractmethod def get_spf_domain(self, hostname: str) -> List[str]: - pass + """ + return all domains listed in *include:* + """ + try: + records = self.get_txt_record(hostname) + ret = [] + for record in records: + if record.startswith("v=spf1"): + parts = record.split(" ") + for part in parts: + if part.startswith(_include_spf): + ret.append( + part[part.find(_include_spf) + len(_include_spf) :] + ) + return ret + except Exception: + return [] @abstractmethod def get_txt_record(self, hostname: str) -> List[str]: @@ -82,27 +97,6 @@ def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]: except Exception: return [] - def get_spf_domain(self, hostname: str) -> List[str]: - """ - return all domains listed in *include:* - """ - try: - answers = self._resolver.resolve(hostname, "TXT", search=True) - ret = [] - for a in answers: # type: dns.rdtypes.ANY.TXT.TXT - for record in a.strings: - record_str = record.decode() # record is bytes - if record_str.startswith("v=spf1"): - parts = record_str.split(" ") - for part in parts: - if part.startswith(_include_spf): - ret.append( - part[part.find(_include_spf) + len(_include_spf) :] - ) - return ret - except Exception: - return [] - def get_txt_record(self, hostname: str) -> List[str]: try: answers = self._resolver.resolve(hostname, "TXT", search=True) @@ -128,9 +122,6 @@ def set_cname_record(self, hostname: str, cname: str): def set_mx_records(self, hostname: str, mx_list: List[Tuple[int, str]]): self.mx_records[hostname] = mx_list - def set_spf_domain(self, hostname: str, spf_list: List[str]): - self.spf_records[hostname] = spf_list - def set_txt_record(self, hostname: str, txt_list: List[str]): self.txt_records[hostname] = txt_list @@ -141,9 +132,6 @@ def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]: mx_list = self.mx_records.get(hostname, []) return sorted(mx_list, key=lambda x: x[0]) - def get_spf_domain(self, hostname: str) -> List[str]: - return self.spf_records.get(hostname, []) - def get_txt_record(self, hostname: str) -> List[str]: return self.txt_records.get(hostname, []) diff --git a/templates/dashboard/domain_detail/dns.html b/templates/dashboard/domain_detail/dns.html index 15ef346f7..810aa302a 100644 --- a/templates/dashboard/domain_detail/dns.html +++ b/templates/dashboard/domain_detail/dns.html @@ -237,7 +237,7 @@

{{ custom_domain.domain }}

folder.
Add the following CNAME DNS records to your domain.
- {% for dkim_prefix, dkim_cname_value in dkim_records %} + {% for dkim_prefix, dkim_cname_value in dkim_records.items() %}
Record: CNAME diff --git a/tests/test_custom_domain_validation.py b/tests/test_custom_domain_validation.py new file mode 100644 index 000000000..0a5cafd75 --- /dev/null +++ b/tests/test_custom_domain_validation.py @@ -0,0 +1,297 @@ +from typing import Optional + +from app import config +from app.constants import DMARC_RECORD +from app.custom_domain_validation import CustomDomainValidation +from app.db import Session +from app.models import CustomDomain, User +from app.dns_utils import InMemoryDNSClient +from app.utils import random_string +from tests.utils import create_new_user, random_domain + +user: Optional[User] = None + + +def setup_module(): + global user + config.SKIP_MX_LOOKUP_ON_CHECK = True + user = create_new_user() + user.trial_end = None + user.lifetime = True + Session.commit() + + +def create_custom_domain(domain: str) -> CustomDomain: + return CustomDomain.create(user_id=user.id, domain=domain, commit=True) + + +def test_custom_domain_validation_get_dkim_records(): + domain = random_domain() + validator = CustomDomainValidation(domain) + records = validator.get_dkim_records() + + assert len(records) == 3 + assert records["dkim02._domainkey"] == f"dkim02._domainkey.{domain}" + assert records["dkim03._domainkey"] == f"dkim03._domainkey.{domain}" + assert records["dkim._domainkey"] == f"dkim._domainkey.{domain}" + + +# validate_dkim_records +def test_custom_domain_validation_validate_dkim_records_empty_records_failure(): + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(random_domain(), dns_client) + + domain = create_custom_domain(random_domain()) + res = validator.validate_dkim_records(domain) + + assert len(res) == 3 + for record_value in res.values(): + assert record_value == "empty" + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.dkim_verified is False + + +def test_custom_domain_validation_validate_dkim_records_wrong_records_failure(): + dkim_domain = random_domain() + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(dkim_domain, dns_client) + + user_domain = random_domain() + + # One domain right, two domains wrong + dns_client.set_cname_record( + f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}" + ) + dns_client.set_cname_record(f"dkim02._domainkey.{user_domain}", "wrong") + dns_client.set_cname_record(f"dkim03._domainkey.{user_domain}", "wrong") + + domain = create_custom_domain(user_domain) + res = validator.validate_dkim_records(domain) + + assert len(res) == 2 + for record_value in res.values(): + assert record_value == "wrong" + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.dkim_verified is False + + +def test_custom_domain_validation_validate_dkim_records_success(): + dkim_domain = random_domain() + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(dkim_domain, dns_client) + + user_domain = random_domain() + + # One domain right, two domains wrong + dns_client.set_cname_record( + f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}" + ) + dns_client.set_cname_record( + f"dkim02._domainkey.{user_domain}", f"dkim02._domainkey.{dkim_domain}" + ) + dns_client.set_cname_record( + f"dkim03._domainkey.{user_domain}", f"dkim03._domainkey.{dkim_domain}" + ) + + domain = create_custom_domain(user_domain) + res = validator.validate_dkim_records(domain) + assert len(res) == 0 + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.dkim_verified is True + + +# validate_ownership +def test_custom_domain_validation_validate_ownership_empty_records_failure(): + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(random_domain(), dns_client) + + domain = create_custom_domain(random_domain()) + res = validator.validate_domain_ownership(domain) + + assert res.success is False + assert len(res.errors) == 0 + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.ownership_verified is False + + +def test_custom_domain_validation_validate_ownership_wrong_records_failure(): + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(random_domain(), dns_client) + + domain = create_custom_domain(random_domain()) + + wrong_records = [random_string()] + dns_client.set_txt_record(domain.domain, wrong_records) + res = validator.validate_domain_ownership(domain) + + assert res.success is False + assert res.errors == wrong_records + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.ownership_verified is False + + +def test_custom_domain_validation_validate_ownership_success(): + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(random_domain(), dns_client) + + domain = create_custom_domain(random_domain()) + + dns_client.set_txt_record(domain.domain, [domain.get_ownership_dns_txt_value()]) + res = validator.validate_domain_ownership(domain) + + assert res.success is True + assert len(res.errors) == 0 + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.ownership_verified is True + + +# validate_mx_records +def test_custom_domain_validation_validate_mx_records_empty_failure(): + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(random_domain(), dns_client) + + domain = create_custom_domain(random_domain()) + res = validator.validate_mx_records(domain) + + assert res.success is False + assert len(res.errors) == 0 + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.verified is False + + +def test_custom_domain_validation_validate_mx_records_wrong_records_failure(): + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(random_domain(), dns_client) + + domain = create_custom_domain(random_domain()) + + wrong_record_1 = random_string() + wrong_record_2 = random_string() + wrong_records = [(10, wrong_record_1), (20, wrong_record_2)] + dns_client.set_mx_records(domain.domain, wrong_records) + res = validator.validate_mx_records(domain) + + assert res.success is False + assert res.errors == [f"10 {wrong_record_1}", f"20 {wrong_record_2}"] + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.verified is False + + +def test_custom_domain_validation_validate_mx_records_success(): + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(random_domain(), dns_client) + + domain = create_custom_domain(random_domain()) + + dns_client.set_mx_records(domain.domain, config.EMAIL_SERVERS_WITH_PRIORITY) + res = validator.validate_mx_records(domain) + + assert res.success is True + assert len(res.errors) == 0 + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.verified is True + + +# validate_spf_records +def test_custom_domain_validation_validate_spf_records_empty_failure(): + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(random_domain(), dns_client) + + domain = create_custom_domain(random_domain()) + res = validator.validate_spf_records(domain) + + assert res.success is False + assert len(res.errors) == 0 + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.spf_verified is False + + +def test_custom_domain_validation_validate_spf_records_wrong_records_failure(): + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(random_domain(), dns_client) + + domain = create_custom_domain(random_domain()) + + wrong_records = [random_string()] + dns_client.set_txt_record(domain.domain, wrong_records) + res = validator.validate_spf_records(domain) + + assert res.success is False + assert res.errors == wrong_records + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.spf_verified is False + + +def test_custom_domain_validation_validate_spf_records_success(): + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(random_domain(), dns_client) + + domain = create_custom_domain(random_domain()) + + dns_client.set_txt_record(domain.domain, [f"v=spf1 include:{config.EMAIL_DOMAIN}"]) + res = validator.validate_spf_records(domain) + + assert res.success is True + assert len(res.errors) == 0 + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.spf_verified is True + + +# validate_dmarc_records +def test_custom_domain_validation_validate_dmarc_records_empty_failure(): + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(random_domain(), dns_client) + + domain = create_custom_domain(random_domain()) + res = validator.validate_dmarc_records(domain) + + assert res.success is False + assert len(res.errors) == 0 + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.dmarc_verified is False + + +def test_custom_domain_validation_validate_dmarc_records_wrong_records_failure(): + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(random_domain(), dns_client) + + domain = create_custom_domain(random_domain()) + + wrong_records = [random_string()] + dns_client.set_txt_record(f"_dmarc.{domain.domain}", wrong_records) + res = validator.validate_dmarc_records(domain) + + assert res.success is False + assert res.errors == wrong_records + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.dmarc_verified is False + + +def test_custom_domain_validation_validate_dmarc_records_success(): + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(random_domain(), dns_client) + + domain = create_custom_domain(random_domain()) + + dns_client.set_txt_record(f"_dmarc.{domain.domain}", [DMARC_RECORD]) + res = validator.validate_dmarc_records(domain) + + assert res.success is True + assert len(res.errors) == 0 + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.dmarc_verified is True diff --git a/tests/test_dns_utils.py b/tests/test_dns_utils.py index d946a9b2a..374983c84 100644 --- a/tests/test_dns_utils.py +++ b/tests/test_dns_utils.py @@ -2,8 +2,11 @@ get_mx_domains, get_network_dns_client, is_mx_equivalent, + InMemoryDNSClient, ) +from tests.utils import random_domain + # use our own domain for test _DOMAIN = "simplelogin.io" @@ -45,3 +48,15 @@ def test_is_mx_equivalent(): [(5, "domain1"), (10, "domain2")], [(10, "domain1"), (20, "domain2"), (20, "domain3")], ) + + +def test_get_spf_record(): + client = InMemoryDNSClient() + + sl_domain = random_domain() + domain = random_domain() + + spf_record = f"v=spf1 include:{sl_domain}" + client.set_txt_record(domain, [spf_record, "another record"]) + res = client.get_spf_domain(domain) + assert res == [sl_domain]