Skip to content

Commit

Permalink
feat: extract custom domain utils to a service
Browse files Browse the repository at this point in the history
  • Loading branch information
cquintana92 committed Sep 13, 2024
1 parent 025d4fe commit 0d4b100
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 42 deletions.
45 changes: 45 additions & 0 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,33 @@ def sl_getenv(env_var: str, default_factory: Callable = None):
return literal_eval(value)


def get_env_dict(env_var: str) -> dict[str, str]:
"""
Get an env variable and convert it into a python dictionary with keys and values as strings.
Args:
env_var (str): env var, example: SL_DB
Syntax is: key1=value1;key2=value2
Components separated by ;
key and value separated by =
"""
value = os.getenv(env_var)
if not value:
return {}

components = value.split(";")
result = {}
for component in components:
if component == "":
continue
parts = component.split("=")
if len(parts) != 2:
raise Exception(f"Invalid config for env var {env_var}")
result[parts[0].strip()] = parts[1].strip()

return result


config_file = os.environ.get("CONFIG")
if config_file:
config_file = get_abs_path(config_file)
Expand Down Expand Up @@ -609,3 +636,21 @@ def read_webhook_enabled_user_ids() -> Optional[List[int]]:
# Allow to define a different DB_URI for the event listener, in case we want to skip the connection pool
# It defaults to the regular DB_URI in case it's needed
EVENT_LISTENER_DB_URI = os.environ.get("EVENT_LISTENER_DB_URI", DB_URI)


def read_partner_domains() -> dict[int, str]:
partner_domains_dict = get_env_dict("PARTNER_DOMAINS")
if len(partner_domains_dict) == 0:
return {}

res: dict[int, str] = {}
for partner_id in partner_domains_dict.keys():
try:
partner_id_int = int(partner_id.strip())
res[partner_id_int] = partner_domains_dict[partner_id]
except ValueError:
pass
return res


PARTNER_DOMAINS: dict[int, str] = read_partner_domains()
1 change: 1 addition & 0 deletions app/constants.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
HEADER_ALLOW_API_COOKIES = "X-Sl-Allowcookies"
DMARC_RECORD = "v=DMARC1; p=quarantine; pct=100; adkim=s; aspf=s"
Empty file added app/custom_domain_utils.py
Empty file.
75 changes: 74 additions & 1 deletion app/custom_domain_validation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,21 @@
from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN
from app.constants import DMARC_RECORD
from app.db import Session
from app.dns_utils import get_cname_record
from app.dns_utils import (
get_cname_record,
get_mx_domains,
get_txt_record,
is_mx_equivalent,
get_spf_domain,
)
from app.models import CustomDomain
from dataclasses import dataclass


@dataclass
class DomainValidationResult:
success: bool
errors: [str]


class CustomDomainValidation:
Expand Down Expand Up @@ -35,3 +50,61 @@ def validate_dkim_records(self, custom_domain: CustomDomain) -> dict[str, str]:
custom_domain.dkim_verified = len(invalid_records) == 0
Session.commit()
return invalid_records

def validate_domain_ownership(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
"""
Check if the custom_domain has added the ownership verification records
"""
txt_records = get_txt_record(custom_domain.domain)

if custom_domain.get_ownership_dns_txt_value() in txt_records:
custom_domain.ownership_verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])
else:
return DomainValidationResult(success=False, errors=txt_records)

def validate_mx_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
mx_domains = get_mx_domains(custom_domain.domain)

if not is_mx_equivalent(mx_domains, EMAIL_SERVERS_WITH_PRIORITY):
return DomainValidationResult(
success=False,
errors=[f"{priority} {domain}" for (priority, domain) in mx_domains],
)
else:
custom_domain.verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])

def validate_spf_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
spf_domains = get_spf_domain(custom_domain.domain)
if EMAIL_DOMAIN in spf_domains:
custom_domain.spf_verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])
else:
custom_domain.spf_verified = False
Session.commit()
return DomainValidationResult(
success=False, errors=get_txt_record(custom_domain.domain)
)

def validate_dmarc_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
txt_records = get_txt_record("_dmarc." + custom_domain.domain)
if DMARC_RECORD in txt_records:
custom_domain.dmarc_verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])
else:
custom_domain.dmarc_verified = False
Session.commit()
return DomainValidationResult(success=False, errors=txt_records)
63 changes: 22 additions & 41 deletions app/dashboard/views/domain_detail.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,11 @@
from flask_wtf import FlaskForm
from wtforms import StringField, validators, IntegerField

from app.constants import DMARC_RECORD
from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN, JOB_DELETE_DOMAIN
from app.custom_domain_validation import CustomDomainValidation
from app.dashboard.base import dashboard_bp
from app.db import Session
from app.dns_utils import (
get_mx_domains,
get_spf_domain,
get_txt_record,
is_mx_equivalent,
)
from app.log import LOG
from app.models import (
CustomDomain,
Expand Down Expand Up @@ -49,8 +44,6 @@ def domain_detail_dns(custom_domain_id):
domain_validator = CustomDomainValidation(EMAIL_DOMAIN)
csrf_form = CSRFValidationForm()

dmarc_record = "v=DMARC1; p=quarantine; pct=100; adkim=s; aspf=s"

mx_ok = spf_ok = dkim_ok = dmarc_ok = ownership_ok = True
mx_errors = spf_errors = dkim_errors = dmarc_errors = ownership_errors = []

Expand All @@ -59,15 +52,14 @@ def domain_detail_dns(custom_domain_id):
flash("Invalid request", "warning")
return redirect(request.url)
if request.form.get("form-name") == "check-ownership":
txt_records = get_txt_record(custom_domain.domain)

if custom_domain.get_ownership_dns_txt_value() in txt_records:
ownership_validation_result = domain_validator.validate_domain_ownership(
custom_domain
)
if ownership_validation_result.success:
flash(
"Domain ownership is verified. Please proceed to the other records setup",
"success",
)
custom_domain.ownership_verified = True
Session.commit()
return redirect(
url_for(
"dashboard.domain_detail_dns",
Expand All @@ -78,51 +70,41 @@ def domain_detail_dns(custom_domain_id):
else:
flash("We can't find the needed TXT record", "error")
ownership_ok = False
ownership_errors = txt_records
ownership_errors = ownership_validation_result.errors

elif request.form.get("form-name") == "check-mx":
mx_domains = get_mx_domains(custom_domain.domain)

if not is_mx_equivalent(mx_domains, EMAIL_SERVERS_WITH_PRIORITY):
flash("The MX record is not correctly set", "warning")

mx_ok = False
# build mx_errors to show to user
mx_errors = [
f"{priority} {domain}" for (priority, domain) in mx_domains
]
else:
mx_validation_result = domain_validator.validate_mx_records(custom_domain)
if mx_validation_result.success:
flash(
"Your domain can start receiving emails. You can now use it to create alias",
"success",
)
custom_domain.verified = True
Session.commit()
return redirect(
url_for(
"dashboard.domain_detail_dns", custom_domain_id=custom_domain.id
)
)
else:
flash("The MX record is not correctly set", "warning")
mx_ok = False
mx_errors = mx_validation_result.errors

elif request.form.get("form-name") == "check-spf":
spf_domains = get_spf_domain(custom_domain.domain)
if EMAIL_DOMAIN in spf_domains:
custom_domain.spf_verified = True
Session.commit()
spf_validation_result = domain_validator.validate_spf_records(custom_domain)
if spf_validation_result.success:
flash("SPF is setup correctly", "success")
return redirect(
url_for(
"dashboard.domain_detail_dns", custom_domain_id=custom_domain.id
)
)
else:
custom_domain.spf_verified = False
Session.commit()
flash(
f"SPF: {EMAIL_DOMAIN} is not included in your SPF record.",
"warning",
)
spf_ok = False
spf_errors = get_txt_record(custom_domain.domain)
spf_errors = spf_validation_result.errors

elif request.form.get("form-name") == "check-dkim":
dkim_errors = domain_validator.validate_dkim_records(custom_domain)
Expand All @@ -138,30 +120,29 @@ def domain_detail_dns(custom_domain_id):
flash("DKIM: the CNAME record is not correctly set", "warning")

elif request.form.get("form-name") == "check-dmarc":
txt_records = get_txt_record("_dmarc." + custom_domain.domain)
if dmarc_record in txt_records:
custom_domain.dmarc_verified = True
Session.commit()
dmarc_validation_result = domain_validator.validate_dmarc_records(
custom_domain
)
if dmarc_validation_result.success:
flash("DMARC is setup correctly", "success")
return redirect(
url_for(
"dashboard.domain_detail_dns", custom_domain_id=custom_domain.id
)
)
else:
custom_domain.dmarc_verified = False
Session.commit()
flash(
"DMARC: The TXT record is not correctly set",
"warning",
)
dmarc_ok = False
dmarc_errors = txt_records
dmarc_errors = dmarc_validation_result.errors

return render_template(
"dashboard/domain_detail/dns.html",
EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY,
dkim_records=domain_validator.get_dkim_records(),
dmarc_record=DMARC_RECORD,
**locals(),
)

Expand Down

0 comments on commit 0d4b100

Please sign in to comment.