diff --git a/app/custom_domain_validation.py b/app/custom_domain_validation.py
index 58b698b13..e9b26825e 100644
--- a/app/custom_domain_validation.py
+++ b/app/custom_domain_validation.py
@@ -23,7 +23,7 @@ def __init__(
self.dkim_domain = dkim_domain
self._dns_client = dns_client
self._dkim_records = {
- (f"{key}._domainkey", f"{key}._domainkey.{self.dkim_domain}")
+ f"{key}._domainkey": f"{key}._domainkey.{self.dkim_domain}"
for key in ("dkim", "dkim02", "dkim03")
}
@@ -40,7 +40,7 @@ def validate_dkim_records(self, custom_domain: CustomDomain) -> dict[str, str]:
Returns empty list if all records are ok. Other-wise return the records that aren't properly configured
"""
invalid_records = {}
- for prefix, expected_record in self.get_dkim_records():
+ for prefix, expected_record in self.get_dkim_records().items():
custom_record = f"{prefix}.{custom_domain.domain}"
dkim_record = self._dns_client.get_cname_record(custom_record)
if dkim_record != expected_record:
diff --git a/app/dns_utils.py b/app/dns_utils.py
index d55fd6f4f..2ce699340 100644
--- a/app/dns_utils.py
+++ b/app/dns_utils.py
@@ -40,9 +40,24 @@ def get_cname_record(self, hostname: str) -> Optional[str]:
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
pass
- @abstractmethod
def get_spf_domain(self, hostname: str) -> List[str]:
- pass
+ """
+ return all domains listed in *include:*
+ """
+ try:
+ records = self.get_txt_record(hostname)
+ ret = []
+ for record in records:
+ if record.startswith("v=spf1"):
+ parts = record.split(" ")
+ for part in parts:
+ if part.startswith(_include_spf):
+ ret.append(
+ part[part.find(_include_spf) + len(_include_spf) :]
+ )
+ return ret
+ except Exception:
+ return []
@abstractmethod
def get_txt_record(self, hostname: str) -> List[str]:
@@ -82,27 +97,6 @@ def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
except Exception:
return []
- def get_spf_domain(self, hostname: str) -> List[str]:
- """
- return all domains listed in *include:*
- """
- try:
- answers = self._resolver.resolve(hostname, "TXT", search=True)
- ret = []
- for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
- for record in a.strings:
- record_str = record.decode() # record is bytes
- if record_str.startswith("v=spf1"):
- parts = record_str.split(" ")
- for part in parts:
- if part.startswith(_include_spf):
- ret.append(
- part[part.find(_include_spf) + len(_include_spf) :]
- )
- return ret
- except Exception:
- return []
-
def get_txt_record(self, hostname: str) -> List[str]:
try:
answers = self._resolver.resolve(hostname, "TXT", search=True)
@@ -128,9 +122,6 @@ def set_cname_record(self, hostname: str, cname: str):
def set_mx_records(self, hostname: str, mx_list: List[Tuple[int, str]]):
self.mx_records[hostname] = mx_list
- def set_spf_domain(self, hostname: str, spf_list: List[str]):
- self.spf_records[hostname] = spf_list
-
def set_txt_record(self, hostname: str, txt_list: List[str]):
self.txt_records[hostname] = txt_list
@@ -141,9 +132,6 @@ def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
mx_list = self.mx_records.get(hostname, [])
return sorted(mx_list, key=lambda x: x[0])
- def get_spf_domain(self, hostname: str) -> List[str]:
- return self.spf_records.get(hostname, [])
-
def get_txt_record(self, hostname: str) -> List[str]:
return self.txt_records.get(hostname, [])
diff --git a/templates/dashboard/domain_detail/dns.html b/templates/dashboard/domain_detail/dns.html
index 15ef346f7..810aa302a 100644
--- a/templates/dashboard/domain_detail/dns.html
+++ b/templates/dashboard/domain_detail/dns.html
@@ -237,7 +237,7 @@
{{ custom_domain.domain }}
folder.
Add the following CNAME DNS records to your domain.
- {% for dkim_prefix, dkim_cname_value in dkim_records %}
+ {% for dkim_prefix, dkim_cname_value in dkim_records.items() %}
Record: CNAME
diff --git a/tests/test_custom_domain_validation.py b/tests/test_custom_domain_validation.py
new file mode 100644
index 000000000..0a5cafd75
--- /dev/null
+++ b/tests/test_custom_domain_validation.py
@@ -0,0 +1,297 @@
+from typing import Optional
+
+from app import config
+from app.constants import DMARC_RECORD
+from app.custom_domain_validation import CustomDomainValidation
+from app.db import Session
+from app.models import CustomDomain, User
+from app.dns_utils import InMemoryDNSClient
+from app.utils import random_string
+from tests.utils import create_new_user, random_domain
+
+user: Optional[User] = None
+
+
+def setup_module():
+ global user
+ config.SKIP_MX_LOOKUP_ON_CHECK = True
+ user = create_new_user()
+ user.trial_end = None
+ user.lifetime = True
+ Session.commit()
+
+
+def create_custom_domain(domain: str) -> CustomDomain:
+ return CustomDomain.create(user_id=user.id, domain=domain, commit=True)
+
+
+def test_custom_domain_validation_get_dkim_records():
+ domain = random_domain()
+ validator = CustomDomainValidation(domain)
+ records = validator.get_dkim_records()
+
+ assert len(records) == 3
+ assert records["dkim02._domainkey"] == f"dkim02._domainkey.{domain}"
+ assert records["dkim03._domainkey"] == f"dkim03._domainkey.{domain}"
+ assert records["dkim._domainkey"] == f"dkim._domainkey.{domain}"
+
+
+# validate_dkim_records
+def test_custom_domain_validation_validate_dkim_records_empty_records_failure():
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(random_domain(), dns_client)
+
+ domain = create_custom_domain(random_domain())
+ res = validator.validate_dkim_records(domain)
+
+ assert len(res) == 3
+ for record_value in res.values():
+ assert record_value == "empty"
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.dkim_verified is False
+
+
+def test_custom_domain_validation_validate_dkim_records_wrong_records_failure():
+ dkim_domain = random_domain()
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(dkim_domain, dns_client)
+
+ user_domain = random_domain()
+
+ # One domain right, two domains wrong
+ dns_client.set_cname_record(
+ f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}"
+ )
+ dns_client.set_cname_record(f"dkim02._domainkey.{user_domain}", "wrong")
+ dns_client.set_cname_record(f"dkim03._domainkey.{user_domain}", "wrong")
+
+ domain = create_custom_domain(user_domain)
+ res = validator.validate_dkim_records(domain)
+
+ assert len(res) == 2
+ for record_value in res.values():
+ assert record_value == "wrong"
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.dkim_verified is False
+
+
+def test_custom_domain_validation_validate_dkim_records_success():
+ dkim_domain = random_domain()
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(dkim_domain, dns_client)
+
+ user_domain = random_domain()
+
+ # One domain right, two domains wrong
+ dns_client.set_cname_record(
+ f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}"
+ )
+ dns_client.set_cname_record(
+ f"dkim02._domainkey.{user_domain}", f"dkim02._domainkey.{dkim_domain}"
+ )
+ dns_client.set_cname_record(
+ f"dkim03._domainkey.{user_domain}", f"dkim03._domainkey.{dkim_domain}"
+ )
+
+ domain = create_custom_domain(user_domain)
+ res = validator.validate_dkim_records(domain)
+ assert len(res) == 0
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.dkim_verified is True
+
+
+# validate_ownership
+def test_custom_domain_validation_validate_ownership_empty_records_failure():
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(random_domain(), dns_client)
+
+ domain = create_custom_domain(random_domain())
+ res = validator.validate_domain_ownership(domain)
+
+ assert res.success is False
+ assert len(res.errors) == 0
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.ownership_verified is False
+
+
+def test_custom_domain_validation_validate_ownership_wrong_records_failure():
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(random_domain(), dns_client)
+
+ domain = create_custom_domain(random_domain())
+
+ wrong_records = [random_string()]
+ dns_client.set_txt_record(domain.domain, wrong_records)
+ res = validator.validate_domain_ownership(domain)
+
+ assert res.success is False
+ assert res.errors == wrong_records
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.ownership_verified is False
+
+
+def test_custom_domain_validation_validate_ownership_success():
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(random_domain(), dns_client)
+
+ domain = create_custom_domain(random_domain())
+
+ dns_client.set_txt_record(domain.domain, [domain.get_ownership_dns_txt_value()])
+ res = validator.validate_domain_ownership(domain)
+
+ assert res.success is True
+ assert len(res.errors) == 0
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.ownership_verified is True
+
+
+# validate_mx_records
+def test_custom_domain_validation_validate_mx_records_empty_failure():
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(random_domain(), dns_client)
+
+ domain = create_custom_domain(random_domain())
+ res = validator.validate_mx_records(domain)
+
+ assert res.success is False
+ assert len(res.errors) == 0
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.verified is False
+
+
+def test_custom_domain_validation_validate_mx_records_wrong_records_failure():
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(random_domain(), dns_client)
+
+ domain = create_custom_domain(random_domain())
+
+ wrong_record_1 = random_string()
+ wrong_record_2 = random_string()
+ wrong_records = [(10, wrong_record_1), (20, wrong_record_2)]
+ dns_client.set_mx_records(domain.domain, wrong_records)
+ res = validator.validate_mx_records(domain)
+
+ assert res.success is False
+ assert res.errors == [f"10 {wrong_record_1}", f"20 {wrong_record_2}"]
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.verified is False
+
+
+def test_custom_domain_validation_validate_mx_records_success():
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(random_domain(), dns_client)
+
+ domain = create_custom_domain(random_domain())
+
+ dns_client.set_mx_records(domain.domain, config.EMAIL_SERVERS_WITH_PRIORITY)
+ res = validator.validate_mx_records(domain)
+
+ assert res.success is True
+ assert len(res.errors) == 0
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.verified is True
+
+
+# validate_spf_records
+def test_custom_domain_validation_validate_spf_records_empty_failure():
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(random_domain(), dns_client)
+
+ domain = create_custom_domain(random_domain())
+ res = validator.validate_spf_records(domain)
+
+ assert res.success is False
+ assert len(res.errors) == 0
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.spf_verified is False
+
+
+def test_custom_domain_validation_validate_spf_records_wrong_records_failure():
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(random_domain(), dns_client)
+
+ domain = create_custom_domain(random_domain())
+
+ wrong_records = [random_string()]
+ dns_client.set_txt_record(domain.domain, wrong_records)
+ res = validator.validate_spf_records(domain)
+
+ assert res.success is False
+ assert res.errors == wrong_records
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.spf_verified is False
+
+
+def test_custom_domain_validation_validate_spf_records_success():
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(random_domain(), dns_client)
+
+ domain = create_custom_domain(random_domain())
+
+ dns_client.set_txt_record(domain.domain, [f"v=spf1 include:{config.EMAIL_DOMAIN}"])
+ res = validator.validate_spf_records(domain)
+
+ assert res.success is True
+ assert len(res.errors) == 0
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.spf_verified is True
+
+
+# validate_dmarc_records
+def test_custom_domain_validation_validate_dmarc_records_empty_failure():
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(random_domain(), dns_client)
+
+ domain = create_custom_domain(random_domain())
+ res = validator.validate_dmarc_records(domain)
+
+ assert res.success is False
+ assert len(res.errors) == 0
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.dmarc_verified is False
+
+
+def test_custom_domain_validation_validate_dmarc_records_wrong_records_failure():
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(random_domain(), dns_client)
+
+ domain = create_custom_domain(random_domain())
+
+ wrong_records = [random_string()]
+ dns_client.set_txt_record(f"_dmarc.{domain.domain}", wrong_records)
+ res = validator.validate_dmarc_records(domain)
+
+ assert res.success is False
+ assert res.errors == wrong_records
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.dmarc_verified is False
+
+
+def test_custom_domain_validation_validate_dmarc_records_success():
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(random_domain(), dns_client)
+
+ domain = create_custom_domain(random_domain())
+
+ dns_client.set_txt_record(f"_dmarc.{domain.domain}", [DMARC_RECORD])
+ res = validator.validate_dmarc_records(domain)
+
+ assert res.success is True
+ assert len(res.errors) == 0
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.dmarc_verified is True
diff --git a/tests/test_dns_utils.py b/tests/test_dns_utils.py
index d946a9b2a..374983c84 100644
--- a/tests/test_dns_utils.py
+++ b/tests/test_dns_utils.py
@@ -2,8 +2,11 @@
get_mx_domains,
get_network_dns_client,
is_mx_equivalent,
+ InMemoryDNSClient,
)
+from tests.utils import random_domain
+
# use our own domain for test
_DOMAIN = "simplelogin.io"
@@ -45,3 +48,15 @@ def test_is_mx_equivalent():
[(5, "domain1"), (10, "domain2")],
[(10, "domain1"), (20, "domain2"), (20, "domain3")],
)
+
+
+def test_get_spf_record():
+ client = InMemoryDNSClient()
+
+ sl_domain = random_domain()
+ domain = random_domain()
+
+ spf_record = f"v=spf1 include:{sl_domain}"
+ client.set_txt_record(domain, [spf_record, "another record"])
+ res = client.get_spf_domain(domain)
+ assert res == [sl_domain]