Skip to content

Commit

Permalink
chore: DNS validation improvements (#2248)
Browse files Browse the repository at this point in the history
* chore: DNS validation improvements

* fix: do not show domains pending deletion

* fix: generate verification token if null

* revert: dmarc cleanup
  • Loading branch information
cquintana92 authored Oct 3, 2024
1 parent 06ab116 commit 9d5697b
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 4 deletions.
22 changes: 20 additions & 2 deletions app/custom_domain_validation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Optional
from typing import List, Optional

from app import config
from app.constants import DMARC_RECORD
Expand All @@ -11,6 +11,7 @@
get_network_dns_client,
)
from app.models import CustomDomain
from app.utils import random_string


@dataclass
Expand Down Expand Up @@ -42,6 +43,11 @@ def get_ownership_verification_record(self, domain: CustomDomain) -> str:
and domain.partner_id in self._partner_domain_validation_prefixes
):
prefix = self._partner_domain_validation_prefixes[domain.partner_id]

if not domain.ownership_txt_token:
domain.ownership_txt_token = random_string(30)
Session.commit()

return f"{prefix}-verification={domain.ownership_txt_token}"

def get_expected_mx_records(self, domain: CustomDomain) -> list[MxRecord]:
Expand Down Expand Up @@ -164,9 +170,11 @@ def validate_spf_records(
else:
custom_domain.spf_verified = False
Session.commit()
txt_records = self._dns_client.get_txt_record(custom_domain.domain)
cleaned_records = self.__clean_spf_records(txt_records, custom_domain)
return DomainValidationResult(
success=False,
errors=self._dns_client.get_txt_record(custom_domain.domain),
errors=cleaned_records,
)

def validate_dmarc_records(
Expand All @@ -181,3 +189,13 @@ def validate_dmarc_records(
custom_domain.dmarc_verified = False
Session.commit()
return DomainValidationResult(success=False, errors=txt_records)

def __clean_spf_records(
self, txt_records: List[str], custom_domain: CustomDomain
) -> List[str]:
final_records = []
verification_record = self.get_ownership_verification_record(custom_domain)
for record in txt_records:
if record != verification_record:
final_records.append(record)
return final_records
4 changes: 3 additions & 1 deletion app/dashboard/views/custom_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ class NewCustomDomainForm(FlaskForm):
@parallel_limiter.lock(only_when=lambda: request.method == "POST")
def custom_domain():
custom_domains = CustomDomain.filter_by(
user_id=current_user.id, is_sl_subdomain=False
user_id=current_user.id,
is_sl_subdomain=False,
pending_deletion=False,
).all()
new_custom_domain_form = NewCustomDomainForm()

Expand Down
2 changes: 1 addition & 1 deletion app/dns_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def get_mx_domains(self, hostname: str) -> List[MxRecord]:

def get_txt_record(self, hostname: str) -> List[str]:
try:
answers = self._resolver.resolve(hostname, "TXT", search=True)
answers = self._resolver.resolve(hostname, "TXT", search=False)
ret = []
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
for record in a.strings:
Expand Down
27 changes: 27 additions & 0 deletions tests/test_custom_domain_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,33 @@ def test_custom_domain_validation_validate_spf_records_partner_domain_success():
assert db_domain.spf_verified is True


def test_custom_domain_validation_validate_spf_cleans_verification_record():
dns_client = InMemoryDNSClient()
proton_partner_id = get_proton_partner().id

expected_domain = random_domain()
validator = CustomDomainValidation(
dkim_domain=random_domain(),
dns_client=dns_client,
partner_domains={proton_partner_id: expected_domain},
)

domain = create_custom_domain(random_domain())
domain.partner_id = proton_partner_id
Session.commit()

wrong_record = random_string()
dns_client.set_txt_record(
hostname=domain.domain,
txt_list=[wrong_record, validator.get_ownership_verification_record(domain)],
)
res = validator.validate_spf_records(domain)

assert res.success is False
assert len(res.errors) == 1
assert res.errors[0] == wrong_record


# validate_dmarc_records
def test_custom_domain_validation_validate_dmarc_records_empty_failure():
dns_client = InMemoryDNSClient()
Expand Down

0 comments on commit 9d5697b

Please sign in to comment.