Skip to content

Commit

Permalink
fix: improve MX and SPF domain handling (#2246)
Browse files Browse the repository at this point in the history
* fix: improve MX and SPF domain handling

* fix: do not use get_instance function

* fix: legacy use of tuples instead of MxRecord

* refactor: rename partner custom domain variables
  • Loading branch information
cquintana92 authored Oct 2, 2024
1 parent ed5e62d commit b97a1dd
Show file tree
Hide file tree
Showing 10 changed files with 240 additions and 45 deletions.
8 changes: 5 additions & 3 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,9 +653,11 @@ def read_partner_dict(var: str) -> dict[int, str]:
return res


PARTNER_DOMAINS: dict[int, str] = read_partner_dict("PARTNER_DOMAINS")
PARTNER_DOMAIN_VALIDATION_PREFIXES: dict[int, str] = read_partner_dict(
"PARTNER_DOMAIN_VALIDATION_PREFIXES"
PARTNER_DNS_CUSTOM_DOMAINS: dict[int, str] = read_partner_dict(
"PARTNER_DNS_CUSTOM_DOMAINS"
)
PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES: dict[int, str] = read_partner_dict(
"PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES"
)

MAILBOX_VERIFICATION_OVERRIDE_CODE: Optional[str] = os.environ.get(
Expand Down
36 changes: 31 additions & 5 deletions app/custom_domain_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from app.constants import DMARC_RECORD
from app.db import Session
from app.dns_utils import (
MxRecord,
DNSClient,
is_mx_equivalent,
get_network_dns_client,
Expand All @@ -28,10 +29,10 @@ def __init__(
):
self.dkim_domain = dkim_domain
self._dns_client = dns_client
self._partner_domains = partner_domains or config.PARTNER_DOMAINS
self._partner_domains = partner_domains or config.PARTNER_DNS_CUSTOM_DOMAINS
self._partner_domain_validation_prefixes = (
partner_domains_validation_prefixes
or config.PARTNER_DOMAIN_VALIDATION_PREFIXES
or config.PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES
)

def get_ownership_verification_record(self, domain: CustomDomain) -> str:
Expand All @@ -43,6 +44,29 @@ def get_ownership_verification_record(self, domain: CustomDomain) -> str:
prefix = self._partner_domain_validation_prefixes[domain.partner_id]
return f"{prefix}-verification={domain.ownership_txt_token}"

def get_expected_mx_records(self, domain: CustomDomain) -> list[MxRecord]:
records = []
if domain.partner_id is not None and domain.partner_id in self._partner_domains:
domain = self._partner_domains[domain.partner_id]
records.append(MxRecord(10, f"mx1.{domain}."))
records.append(MxRecord(20, f"mx2.{domain}."))
else:
# Default ones
for priority, domain in config.EMAIL_SERVERS_WITH_PRIORITY:
records.append(MxRecord(priority, domain))

return records

def get_expected_spf_domain(self, domain: CustomDomain) -> str:
if domain.partner_id is not None and domain.partner_id in self._partner_domains:
return self._partner_domains[domain.partner_id]
else:
return config.EMAIL_DOMAIN

def get_expected_spf_record(self, domain: CustomDomain) -> str:
spf_domain = self.get_expected_spf_domain(domain)
return f"v=spf1 include:{spf_domain} ~all"

def get_dkim_records(self, domain: CustomDomain) -> {str: str}:
"""
Get a list of dkim records to set up. Depending on the custom_domain, whether if it's from a partner or not,
Expand Down Expand Up @@ -116,11 +140,12 @@ def validate_mx_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
mx_domains = self._dns_client.get_mx_domains(custom_domain.domain)
expected_mx_records = self.get_expected_mx_records(custom_domain)

if not is_mx_equivalent(mx_domains, config.EMAIL_SERVERS_WITH_PRIORITY):
if not is_mx_equivalent(mx_domains, expected_mx_records):
return DomainValidationResult(
success=False,
errors=[f"{priority} {domain}" for (priority, domain) in mx_domains],
errors=[f"{record.priority} {record.domain}" for record in mx_domains],
)
else:
custom_domain.verified = True
Expand All @@ -131,7 +156,8 @@ def validate_spf_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
spf_domains = self._dns_client.get_spf_domain(custom_domain.domain)
if config.EMAIL_DOMAIN in spf_domains:
expected_spf_domain = self.get_expected_spf_domain(custom_domain)
if expected_spf_domain in spf_domains:
custom_domain.spf_verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])
Expand Down
4 changes: 2 additions & 2 deletions app/dashboard/views/domain_detail.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ def domain_detail_dns(custom_domain_id):
custom_domain.ownership_txt_token = random_string(30)
Session.commit()

spf_record = f"v=spf1 include:{EMAIL_DOMAIN} ~all"

domain_validator = CustomDomainValidation(EMAIL_DOMAIN)
csrf_form = CSRFValidationForm()

Expand Down Expand Up @@ -141,7 +139,9 @@ def domain_detail_dns(custom_domain_id):
ownership_record=domain_validator.get_ownership_verification_record(
custom_domain
),
expected_mx_records=domain_validator.get_expected_mx_records(custom_domain),
dkim_records=domain_validator.get_dkim_records(custom_domain),
spf_record=domain_validator.get_expected_spf_record(custom_domain),
dmarc_record=DMARC_RECORD,
**locals(),
)
Expand Down
37 changes: 22 additions & 15 deletions app/dns_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional
from dataclasses import dataclass
from typing import List, Optional

import dns.resolver

Expand All @@ -8,8 +9,14 @@
_include_spf = "include:"


@dataclass
class MxRecord:
priority: int
domain: str


def is_mx_equivalent(
mx_domains: List[Tuple[int, str]], ref_mx_domains: List[Tuple[int, str]]
mx_domains: List[MxRecord], ref_mx_domains: List[MxRecord]
) -> bool:
"""
Compare mx_domains with ref_mx_domains to see if they are equivalent.
Expand All @@ -18,14 +25,14 @@ def is_mx_equivalent(
The priority order is taken into account but not the priority number.
For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)]
"""
mx_domains = sorted(mx_domains, key=lambda x: x[0])
ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x[0])
mx_domains = sorted(mx_domains, key=lambda x: x.priority)
ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x.priority)

if len(mx_domains) < len(ref_mx_domains):
return False

for i in range(len(ref_mx_domains)):
if mx_domains[i][1] != ref_mx_domains[i][1]:
for actual, expected in zip(mx_domains, ref_mx_domains):
if actual.domain != expected.domain:
return False

return True
Expand All @@ -37,7 +44,7 @@ def get_cname_record(self, hostname: str) -> Optional[str]:
pass

@abstractmethod
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
def get_mx_domains(self, hostname: str) -> List[MxRecord]:
pass

def get_spf_domain(self, hostname: str) -> List[str]:
Expand Down Expand Up @@ -81,7 +88,7 @@ def get_cname_record(self, hostname: str) -> Optional[str]:
except Exception:
return None

def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
def get_mx_domains(self, hostname: str) -> List[MxRecord]:
"""
return list of (priority, domain name) sorted by priority (lowest priority first)
domain name ends with a "." at the end.
Expand All @@ -92,8 +99,8 @@ def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
for a in answers:
record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.'
parts = record.split(" ")
ret.append((int(parts[0]), parts[1]))
return sorted(ret, key=lambda x: x[0])
ret.append(MxRecord(priority=int(parts[0]), domain=parts[1]))
return sorted(ret, key=lambda x: x.priority)
except Exception:
return []

Expand All @@ -112,14 +119,14 @@ def get_txt_record(self, hostname: str) -> List[str]:
class InMemoryDNSClient(DNSClient):
def __init__(self):
self.cname_records: dict[str, Optional[str]] = {}
self.mx_records: dict[str, List[Tuple[int, str]]] = {}
self.mx_records: dict[str, List[MxRecord]] = {}
self.spf_records: dict[str, List[str]] = {}
self.txt_records: dict[str, List[str]] = {}

def set_cname_record(self, hostname: str, cname: str):
self.cname_records[hostname] = cname

def set_mx_records(self, hostname: str, mx_list: List[Tuple[int, str]]):
def set_mx_records(self, hostname: str, mx_list: List[MxRecord]):
self.mx_records[hostname] = mx_list

def set_txt_record(self, hostname: str, txt_list: List[str]):
Expand All @@ -128,9 +135,9 @@ def set_txt_record(self, hostname: str, txt_list: List[str]):
def get_cname_record(self, hostname: str) -> Optional[str]:
return self.cname_records.get(hostname)

def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
def get_mx_domains(self, hostname: str) -> List[MxRecord]:
mx_list = self.mx_records.get(hostname, [])
return sorted(mx_list, key=lambda x: x[0])
return sorted(mx_list, key=lambda x: x.priority)

def get_txt_record(self, hostname: str) -> List[str]:
return self.txt_records.get(hostname, [])
Expand All @@ -140,5 +147,5 @@ def get_network_dns_client() -> NetworkDNSClient:
return NetworkDNSClient(NAMESERVERS)


def get_mx_domains(hostname: str) -> [(int, str)]:
def get_mx_domains(hostname: str) -> List[MxRecord]:
return get_network_dns_client().get_mx_domains(hostname)
2 changes: 1 addition & 1 deletion app/email_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ def get_mx_domain_list(domain) -> [str]:
"""
priority_domains = get_mx_domains(domain)

return [d[:-1] for _, d in priority_domains]
return [d.domain[:-1] for d in priority_domains]


def personal_email_already_used(email_address: str) -> bool:
Expand Down
4 changes: 2 additions & 2 deletions app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2766,9 +2766,9 @@ def is_proton(self) -> bool:

from app.email_utils import get_email_local_part

mx_domains: [(int, str)] = get_mx_domains(get_email_local_part(self.email))
mx_domains = get_mx_domains(get_email_local_part(self.email))
# Proton is the first domain
if mx_domains and mx_domains[0][1] in (
if mx_domains and mx_domains[0].domain in (
"mail.protonmail.ch.",
"mailsec.protonmail.ch.",
):
Expand Down
7 changes: 5 additions & 2 deletions cron.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from app import s3, config
from app.alias_utils import nb_email_log_for_mailbox
from app.api.views.apple import verify_receipt
from app.custom_domain_validation import CustomDomainValidation
from app.db import Session
from app.dns_utils import get_mx_domains, is_mx_equivalent
from app.email_utils import (
Expand Down Expand Up @@ -905,9 +906,11 @@ def check_custom_domain():
LOG.i("custom domain has been deleted")


def check_single_custom_domain(custom_domain):
def check_single_custom_domain(custom_domain: CustomDomain):
mx_domains = get_mx_domains(custom_domain.domain)
if not is_mx_equivalent(mx_domains, config.EMAIL_SERVERS_WITH_PRIORITY):
validator = CustomDomainValidation(dkim_domain=config.EMAIL_DOMAIN)
expected_custom_domains = validator.get_expected_mx_records(custom_domain)
if not is_mx_equivalent(mx_domains, expected_custom_domains):
user = custom_domain.user
LOG.w(
"The MX record is not correctly set for %s %s %s",
Expand Down
8 changes: 5 additions & 3 deletions templates/dashboard/domain_detail/dns.html
Original file line number Diff line number Diff line change
Expand Up @@ -91,22 +91,24 @@ <h1 class="h2">{{ custom_domain.domain }}</h1>
<br />
Some domain registrars (Namecheap, CloudFlare, etc) might also use <em>@</em> for the root domain.
</div>
{% for priority, email_server in EMAIL_SERVERS_WITH_PRIORITY %}

{% for record in expected_mx_records %}

<div class="mb-3 p-3 dns-record">
Record: MX
<br />
Domain: {{ custom_domain.domain }} or
<b>@</b>
<br />
Priority: {{ priority }}
Priority: {{ record.priority }}
<br />
Target: <em data-toggle="tooltip"
title="Click to copy"
class="clipboard"
data-clipboard-text="{{ email_server }}">{{ email_server }}</em>
data-clipboard-text="{{ record.domain }}">{{ record.domain }}</em>
</div>
{% endfor %}

<form method="post" action="#mx-form">
{{ csrf_form.csrf_token }}
<input type="hidden" name="form-name" value="check-mx">
Expand Down
Loading

0 comments on commit b97a1dd

Please sign in to comment.