Skip to content

Commit

Permalink
fix: improve MX and SPF domain handling
Browse files Browse the repository at this point in the history
  • Loading branch information
cquintana92 committed Oct 2, 2024
1 parent ed5e62d commit 23fadf0
Show file tree
Hide file tree
Showing 7 changed files with 234 additions and 37 deletions.
36 changes: 33 additions & 3 deletions app/custom_domain_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from app.constants import DMARC_RECORD
from app.db import Session
from app.dns_utils import (
MxRecord,
DNSClient,
is_mx_equivalent,
get_network_dns_client,
Expand Down Expand Up @@ -43,6 +44,29 @@ def get_ownership_verification_record(self, domain: CustomDomain) -> str:
prefix = self._partner_domain_validation_prefixes[domain.partner_id]
return f"{prefix}-verification={domain.ownership_txt_token}"

def get_expected_mx_records(self, domain: CustomDomain) -> list[MxRecord]:
records = []
if domain.partner_id is not None and domain.partner_id in self._partner_domains:
domain = self._partner_domains[domain.partner_id]
records.append(MxRecord(10, f"mx1.{domain}."))
records.append(MxRecord(20, f"mx2.{domain}."))
else:
# Default ones
for priority, domain in config.EMAIL_SERVERS_WITH_PRIORITY:
records.append(MxRecord(priority, domain))

return records

def get_expected_spf_domain(self, domain: CustomDomain) -> str:
if domain.partner_id is not None and domain.partner_id in self._partner_domains:
return self._partner_domains[domain.partner_id]
else:
return config.EMAIL_DOMAIN

def get_expected_spf_record(self, domain: CustomDomain) -> str:
spf_domain = self.get_expected_spf_domain(domain)
return f"v=spf1 include:{spf_domain} ~all"

def get_dkim_records(self, domain: CustomDomain) -> {str: str}:
"""
Get a list of dkim records to set up. Depending on the custom_domain, whether if it's from a partner or not,
Expand Down Expand Up @@ -116,11 +140,12 @@ def validate_mx_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
mx_domains = self._dns_client.get_mx_domains(custom_domain.domain)
expected_mx_records = self.get_expected_mx_records(custom_domain)

if not is_mx_equivalent(mx_domains, config.EMAIL_SERVERS_WITH_PRIORITY):
if not is_mx_equivalent(mx_domains, expected_mx_records):
return DomainValidationResult(
success=False,
errors=[f"{priority} {domain}" for (priority, domain) in mx_domains],
errors=[f"{record.priority} {record.domain}" for record in mx_domains],
)
else:
custom_domain.verified = True
Expand All @@ -131,7 +156,8 @@ def validate_spf_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
spf_domains = self._dns_client.get_spf_domain(custom_domain.domain)
if config.EMAIL_DOMAIN in spf_domains:
expected_spf_domain = self.get_expected_spf_domain(custom_domain)
if expected_spf_domain in spf_domains:
custom_domain.spf_verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])
Expand All @@ -155,3 +181,7 @@ def validate_dmarc_records(
custom_domain.dmarc_verified = False
Session.commit()
return DomainValidationResult(success=False, errors=txt_records)

@staticmethod
def get_instance() -> "CustomDomainValidation":
return CustomDomainValidation(dkim_domain=config.EMAIL_DOMAIN)
4 changes: 2 additions & 2 deletions app/dashboard/views/domain_detail.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ def domain_detail_dns(custom_domain_id):
custom_domain.ownership_txt_token = random_string(30)
Session.commit()

spf_record = f"v=spf1 include:{EMAIL_DOMAIN} ~all"

domain_validator = CustomDomainValidation(EMAIL_DOMAIN)
csrf_form = CSRFValidationForm()

Expand Down Expand Up @@ -141,7 +139,9 @@ def domain_detail_dns(custom_domain_id):
ownership_record=domain_validator.get_ownership_verification_record(
custom_domain
),
expected_mx_records=domain_validator.get_expected_mx_records(custom_domain),
dkim_records=domain_validator.get_dkim_records(custom_domain),
spf_record=domain_validator.get_expected_spf_record(custom_domain),
dmarc_record=DMARC_RECORD,
**locals(),
)
Expand Down
37 changes: 22 additions & 15 deletions app/dns_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional
from dataclasses import dataclass
from typing import List, Optional

import dns.resolver

Expand All @@ -8,8 +9,14 @@
_include_spf = "include:"


@dataclass
class MxRecord:
priority: int
domain: str


def is_mx_equivalent(
mx_domains: List[Tuple[int, str]], ref_mx_domains: List[Tuple[int, str]]
mx_domains: List[MxRecord], ref_mx_domains: List[MxRecord]
) -> bool:
"""
Compare mx_domains with ref_mx_domains to see if they are equivalent.
Expand All @@ -18,14 +25,14 @@ def is_mx_equivalent(
The priority order is taken into account but not the priority number.
For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)]
"""
mx_domains = sorted(mx_domains, key=lambda x: x[0])
ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x[0])
mx_domains = sorted(mx_domains, key=lambda x: x.priority)
ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x.priority)

if len(mx_domains) < len(ref_mx_domains):
return False

for i in range(len(ref_mx_domains)):
if mx_domains[i][1] != ref_mx_domains[i][1]:
for actual, expected in zip(mx_domains, ref_mx_domains):
if actual.domain != expected.domain:
return False

return True
Expand All @@ -37,7 +44,7 @@ def get_cname_record(self, hostname: str) -> Optional[str]:
pass

@abstractmethod
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
def get_mx_domains(self, hostname: str) -> List[MxRecord]:
pass

def get_spf_domain(self, hostname: str) -> List[str]:
Expand Down Expand Up @@ -81,7 +88,7 @@ def get_cname_record(self, hostname: str) -> Optional[str]:
except Exception:
return None

def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
def get_mx_domains(self, hostname: str) -> List[MxRecord]:
"""
return list of (priority, domain name) sorted by priority (lowest priority first)
domain name ends with a "." at the end.
Expand All @@ -92,8 +99,8 @@ def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
for a in answers:
record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.'
parts = record.split(" ")
ret.append((int(parts[0]), parts[1]))
return sorted(ret, key=lambda x: x[0])
ret.append(MxRecord(priority=int(parts[0]), domain=parts[1]))
return sorted(ret, key=lambda x: x.priority)
except Exception:
return []

Expand All @@ -112,14 +119,14 @@ def get_txt_record(self, hostname: str) -> List[str]:
class InMemoryDNSClient(DNSClient):
def __init__(self):
self.cname_records: dict[str, Optional[str]] = {}
self.mx_records: dict[str, List[Tuple[int, str]]] = {}
self.mx_records: dict[str, List[MxRecord]] = {}
self.spf_records: dict[str, List[str]] = {}
self.txt_records: dict[str, List[str]] = {}

def set_cname_record(self, hostname: str, cname: str):
self.cname_records[hostname] = cname

def set_mx_records(self, hostname: str, mx_list: List[Tuple[int, str]]):
def set_mx_records(self, hostname: str, mx_list: List[MxRecord]):
self.mx_records[hostname] = mx_list

def set_txt_record(self, hostname: str, txt_list: List[str]):
Expand All @@ -128,9 +135,9 @@ def set_txt_record(self, hostname: str, txt_list: List[str]):
def get_cname_record(self, hostname: str) -> Optional[str]:
return self.cname_records.get(hostname)

def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
def get_mx_domains(self, hostname: str) -> List[MxRecord]:
mx_list = self.mx_records.get(hostname, [])
return sorted(mx_list, key=lambda x: x[0])
return sorted(mx_list, key=lambda x: x.priority)

def get_txt_record(self, hostname: str) -> List[str]:
return self.txt_records.get(hostname, [])
Expand All @@ -140,5 +147,5 @@ def get_network_dns_client() -> NetworkDNSClient:
return NetworkDNSClient(NAMESERVERS)


def get_mx_domains(hostname: str) -> [(int, str)]:
def get_mx_domains(hostname: str) -> List[MxRecord]:
return get_network_dns_client().get_mx_domains(hostname)
7 changes: 5 additions & 2 deletions cron.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from app import s3, config
from app.alias_utils import nb_email_log_for_mailbox
from app.api.views.apple import verify_receipt
from app.custom_domain_validation import CustomDomainValidation
from app.db import Session
from app.dns_utils import get_mx_domains, is_mx_equivalent
from app.email_utils import (
Expand Down Expand Up @@ -905,9 +906,11 @@ def check_custom_domain():
LOG.i("custom domain has been deleted")


def check_single_custom_domain(custom_domain):
def check_single_custom_domain(custom_domain: CustomDomain):
mx_domains = get_mx_domains(custom_domain.domain)
if not is_mx_equivalent(mx_domains, config.EMAIL_SERVERS_WITH_PRIORITY):
validator = CustomDomainValidation.get_instance()
expected_custom_domains = validator.get_expected_mx_records(custom_domain)
if not is_mx_equivalent(mx_domains, expected_custom_domains):
user = custom_domain.user
LOG.w(
"The MX record is not correctly set for %s %s %s",
Expand Down
8 changes: 5 additions & 3 deletions templates/dashboard/domain_detail/dns.html
Original file line number Diff line number Diff line change
Expand Up @@ -91,22 +91,24 @@ <h1 class="h2">{{ custom_domain.domain }}</h1>
<br />
Some domain registrars (Namecheap, CloudFlare, etc) might also use <em>@</em> for the root domain.
</div>
{% for priority, email_server in EMAIL_SERVERS_WITH_PRIORITY %}

{% for record in expected_mx_records %}

<div class="mb-3 p-3 dns-record">
Record: MX
<br />
Domain: {{ custom_domain.domain }} or
<b>@</b>
<br />
Priority: {{ priority }}
Priority: {{ record.priority }}
<br />
Target: <em data-toggle="tooltip"
title="Click to copy"
class="clipboard"
data-clipboard-text="{{ email_server }}">{{ email_server }}</em>
data-clipboard-text="{{ record.domain }}">{{ record.domain }}</em>
</div>
{% endfor %}

<form method="post" action="#mx-form">
{{ csrf_form.csrf_token }}
<input type="hidden" name="form-name" value="check-mx">
Expand Down
Loading

0 comments on commit 23fadf0

Please sign in to comment.