Skip to content

Commit

Permalink
feat: allow to define partner records (#2225)
Browse files Browse the repository at this point in the history
  • Loading branch information
cquintana92 authored Sep 18, 2024
1 parent f6708dd commit b5866fa
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 26 deletions.
15 changes: 9 additions & 6 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,19 +638,22 @@ def read_webhook_enabled_user_ids() -> Optional[List[int]]:
EVENT_LISTENER_DB_URI = os.environ.get("EVENT_LISTENER_DB_URI", DB_URI)


def read_partner_domains() -> dict[int, str]:
partner_domains_dict = get_env_dict("PARTNER_DOMAINS")
if len(partner_domains_dict) == 0:
def read_partner_dict(var: str) -> dict[int, str]:
partner_value = get_env_dict(var)
if len(partner_value) == 0:
return {}

res: dict[int, str] = {}
for partner_id in partner_domains_dict.keys():
for partner_id in partner_value.keys():
try:
partner_id_int = int(partner_id.strip())
res[partner_id_int] = partner_domains_dict[partner_id]
res[partner_id_int] = partner_value[partner_id]
except ValueError:
pass
return res


PARTNER_DOMAINS: dict[int, str] = read_partner_domains()
PARTNER_DOMAINS: dict[int, str] = read_partner_dict("PARTNER_DOMAINS")
PARTNER_DOMAIN_VALIDATION_PREFIXES: dict[int, str] = read_partner_dict(
"PARTNER_DOMAIN_VALIDATION_PREFIXES"
)
61 changes: 48 additions & 13 deletions app/custom_domain_validation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN
from dataclasses import dataclass
from typing import Optional

from app.config import (
EMAIL_SERVERS_WITH_PRIORITY,
EMAIL_DOMAIN,
PARTNER_DOMAINS,
PARTNER_DOMAIN_VALIDATION_PREFIXES,
)
from app.constants import DMARC_RECORD
from app.db import Session
from app.dns_utils import (
Expand All @@ -7,7 +15,6 @@
get_network_dns_client,
)
from app.models import CustomDomain
from dataclasses import dataclass


@dataclass
Expand All @@ -18,21 +25,46 @@ class DomainValidationResult:

class CustomDomainValidation:
def __init__(
self, dkim_domain: str, dns_client: DNSClient = get_network_dns_client()
self,
dkim_domain: str,
dns_client: DNSClient = get_network_dns_client(),
partner_domains: Optional[dict[int, str]] = None,
partner_domains_validation_prefixes: Optional[dict[int, str]] = None,
):
self.dkim_domain = dkim_domain
self._dns_client = dns_client
self._dkim_records = {
f"{key}._domainkey": f"{key}._domainkey.{self.dkim_domain}"
for key in ("dkim", "dkim02", "dkim03")
}
self._partner_domains = partner_domains or PARTNER_DOMAINS
self._partner_domain_validation_prefixes = (
partner_domains_validation_prefixes or PARTNER_DOMAIN_VALIDATION_PREFIXES
)

def get_dkim_records(self) -> {str: str}:
"""
Get a list of dkim records to set up. It will be
def get_ownership_verification_record(self, domain: CustomDomain) -> str:
prefix = "sl-verification"
if (
domain.partner_id is not None
and domain.partner_id in self._partner_domain_validation_prefixes
):
prefix = self._partner_domain_validation_prefixes[domain.partner_id]
return f"{prefix}={domain.ownership_txt_token}"

def get_dkim_records(self, domain: CustomDomain) -> {str: str}:
"""
Get a list of dkim records to set up. Depending on the custom_domain, whether if it's from a partner or not,
it will return the default ones or the partner ones.
"""
return self._dkim_records

# By default use the default domain
dkim_domain = self.dkim_domain
if domain.partner_id is not None:
# Domain is from a partner. Retrieve the partner config and use that domain if exists
partner_domain = self._partner_domains.get(domain.partner_id)
if partner_domain is not None:
dkim_domain = partner_domain

return {
f"{key}._domainkey": f"{key}._domainkey.{dkim_domain}"
for key in ("dkim", "dkim02", "dkim03")
}

def validate_dkim_records(self, custom_domain: CustomDomain) -> dict[str, str]:
"""
Expand All @@ -41,7 +73,7 @@ def validate_dkim_records(self, custom_domain: CustomDomain) -> dict[str, str]:
"""
correct_records = {}
invalid_records = {}
expected_records = self.get_dkim_records()
expected_records = self.get_dkim_records(custom_domain)
for prefix, expected_record in expected_records.items():
custom_record = f"{prefix}.{custom_domain.domain}"
dkim_record = self._dns_client.get_cname_record(custom_record)
Expand Down Expand Up @@ -75,8 +107,11 @@ def validate_domain_ownership(
Check if the custom_domain has added the ownership verification records
"""
txt_records = self._dns_client.get_txt_record(custom_domain.domain)
expected_verification_record = self.get_ownership_verification_record(
custom_domain
)

if custom_domain.get_ownership_dns_txt_value() in txt_records:
if expected_verification_record in txt_records:
custom_domain.ownership_verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])
Expand Down
5 changes: 4 additions & 1 deletion app/dashboard/views/domain_detail.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,10 @@ def domain_detail_dns(custom_domain_id):
return render_template(
"dashboard/domain_detail/dns.html",
EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY,
dkim_records=domain_validator.get_dkim_records(),
ownership_record=domain_validator.get_ownership_verification_record(
custom_domain
),
dkim_records=domain_validator.get_dkim_records(custom_domain),
dmarc_record=DMARC_RECORD,
**locals(),
)
Expand Down
3 changes: 0 additions & 3 deletions app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2451,9 +2451,6 @@ def nb_alias(self):
def get_trash_url(self):
return config.URL + f"/dashboard/domains/{self.id}/trash"

def get_ownership_dns_txt_value(self):
return f"sl-verification={self.ownership_txt_token}"

@classmethod
def create(cls, **kwargs):
domain = kwargs.get("domain")
Expand Down
2 changes: 1 addition & 1 deletion templates/dashboard/domain_detail/dns.html
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ <h1 class="h2">{{ custom_domain.domain }}</h1>
Value: <em data-toggle="tooltip"
title="Click to copy"
class="clipboard"
data-clipboard-text="{{ custom_domain.get_ownership_dns_txt_value() }}">{{ custom_domain.get_ownership_dns_txt_value() }}</em>
data-clipboard-text="{{ ownership_record }}">{{ ownership_record }}</em>
</div>
<form method="post" action="#ownership-form">
{{ csrf_form.csrf_token }}
Expand Down
55 changes: 53 additions & 2 deletions tests/test_custom_domain_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from app.db import Session
from app.models import CustomDomain, User
from app.dns_utils import InMemoryDNSClient
from app.proton.utils import get_proton_partner
from app.utils import random_string
from tests.utils import create_new_user, random_domain

Expand All @@ -27,15 +28,36 @@ def create_custom_domain(domain: str) -> CustomDomain:

def test_custom_domain_validation_get_dkim_records():
domain = random_domain()
custom_domain = create_custom_domain(domain)
validator = CustomDomainValidation(domain)
records = validator.get_dkim_records()
records = validator.get_dkim_records(custom_domain)

assert len(records) == 3
assert records["dkim02._domainkey"] == f"dkim02._domainkey.{domain}"
assert records["dkim03._domainkey"] == f"dkim03._domainkey.{domain}"
assert records["dkim._domainkey"] == f"dkim._domainkey.{domain}"


def test_custom_domain_validation_get_dkim_records_for_partner():
domain = random_domain()
custom_domain = create_custom_domain(domain)

partner_id = get_proton_partner().id
custom_domain.partner_id = partner_id
Session.commit()

dkim_domain = random_domain()
validator = CustomDomainValidation(
domain, partner_domains={partner_id: dkim_domain}
)
records = validator.get_dkim_records(custom_domain)

assert len(records) == 3
assert records["dkim02._domainkey"] == f"dkim02._domainkey.{dkim_domain}"
assert records["dkim03._domainkey"] == f"dkim03._domainkey.{dkim_domain}"
assert records["dkim._domainkey"] == f"dkim._domainkey.{dkim_domain}"


# validate_dkim_records
def test_custom_domain_validation_validate_dkim_records_empty_records_failure():
dns_client = InMemoryDNSClient()
Expand Down Expand Up @@ -169,7 +191,36 @@ def test_custom_domain_validation_validate_ownership_success():

domain = create_custom_domain(random_domain())

dns_client.set_txt_record(domain.domain, [domain.get_ownership_dns_txt_value()])
dns_client.set_txt_record(
domain.domain, [validator.get_ownership_verification_record(domain)]
)
res = validator.validate_domain_ownership(domain)

assert res.success is True
assert len(res.errors) == 0

db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.ownership_verified is True


def test_custom_domain_validation_validate_ownership_from_partner_success():
dns_client = InMemoryDNSClient()
partner_id = get_proton_partner().id

prefix = random_string()
validator = CustomDomainValidation(
random_domain(),
dns_client,
partner_domains_validation_prefixes={partner_id: prefix},
)

domain = create_custom_domain(random_domain())
domain.partner_id = partner_id
Session.commit()

dns_client.set_txt_record(
domain.domain, [validator.get_ownership_verification_record(domain)]
)
res = validator.validate_domain_ownership(domain)

assert res.success is True
Expand Down

0 comments on commit b5866fa

Please sign in to comment.