Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: extract custom domain utils to a service #2215

Merged
merged 1 commit into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,33 @@ def sl_getenv(env_var: str, default_factory: Callable = None):
return literal_eval(value)


def get_env_dict(env_var: str) -> dict[str, str]:
"""
Get an env variable and convert it into a python dictionary with keys and values as strings.
Args:
env_var (str): env var, example: SL_DB

Syntax is: key1=value1;key2=value2
Components separated by ;
key and value separated by =
"""
value = os.getenv(env_var)
if not value:
return {}

components = value.split(";")
result = {}
for component in components:
if component == "":
continue
parts = component.split("=")
if len(parts) != 2:
raise Exception(f"Invalid config for env var {env_var}")
result[parts[0].strip()] = parts[1].strip()

return result


config_file = os.environ.get("CONFIG")
if config_file:
config_file = get_abs_path(config_file)
Expand Down Expand Up @@ -609,3 +636,21 @@ def read_webhook_enabled_user_ids() -> Optional[List[int]]:
# Allow to define a different DB_URI for the event listener, in case we want to skip the connection pool
# It defaults to the regular DB_URI in case it's needed
EVENT_LISTENER_DB_URI = os.environ.get("EVENT_LISTENER_DB_URI", DB_URI)


def read_partner_domains() -> dict[int, str]:
partner_domains_dict = get_env_dict("PARTNER_DOMAINS")
if len(partner_domains_dict) == 0:
return {}

res: dict[int, str] = {}
for partner_id in partner_domains_dict.keys():
try:
partner_id_int = int(partner_id.strip())
res[partner_id_int] = partner_domains_dict[partner_id]
except ValueError:
pass
return res


PARTNER_DOMAINS: dict[int, str] = read_partner_domains()
1 change: 1 addition & 0 deletions app/constants.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
HEADER_ALLOW_API_COOKIES = "X-Sl-Allowcookies"
DMARC_RECORD = "v=DMARC1; p=quarantine; pct=100; adkim=s; aspf=s"
Empty file added app/custom_domain_utils.py
Empty file.
75 changes: 74 additions & 1 deletion app/custom_domain_validation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,21 @@
from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN
from app.constants import DMARC_RECORD
from app.db import Session
from app.dns_utils import get_cname_record
from app.dns_utils import (
get_cname_record,
get_mx_domains,
get_txt_record,
is_mx_equivalent,
get_spf_domain,
)
from app.models import CustomDomain
from dataclasses import dataclass


@dataclass
class DomainValidationResult:
success: bool
errors: [str]


class CustomDomainValidation:
Expand Down Expand Up @@ -35,3 +50,61 @@ def validate_dkim_records(self, custom_domain: CustomDomain) -> dict[str, str]:
custom_domain.dkim_verified = len(invalid_records) == 0
Session.commit()
return invalid_records

def validate_domain_ownership(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
"""
Check if the custom_domain has added the ownership verification records
"""
txt_records = get_txt_record(custom_domain.domain)

if custom_domain.get_ownership_dns_txt_value() in txt_records:
custom_domain.ownership_verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])
else:
return DomainValidationResult(success=False, errors=txt_records)

def validate_mx_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
mx_domains = get_mx_domains(custom_domain.domain)

if not is_mx_equivalent(mx_domains, EMAIL_SERVERS_WITH_PRIORITY):
return DomainValidationResult(
success=False,
errors=[f"{priority} {domain}" for (priority, domain) in mx_domains],
)
else:
custom_domain.verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])

def validate_spf_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
spf_domains = get_spf_domain(custom_domain.domain)
if EMAIL_DOMAIN in spf_domains:
custom_domain.spf_verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])
else:
custom_domain.spf_verified = False
Session.commit()
return DomainValidationResult(
success=False, errors=get_txt_record(custom_domain.domain)
)

def validate_dmarc_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
txt_records = get_txt_record("_dmarc." + custom_domain.domain)
if DMARC_RECORD in txt_records:
custom_domain.dmarc_verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])
else:
custom_domain.dmarc_verified = False
Session.commit()
return DomainValidationResult(success=False, errors=txt_records)
63 changes: 22 additions & 41 deletions app/dashboard/views/domain_detail.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,11 @@
from flask_wtf import FlaskForm
from wtforms import StringField, validators, IntegerField

from app.constants import DMARC_RECORD
from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN, JOB_DELETE_DOMAIN
from app.custom_domain_validation import CustomDomainValidation
from app.dashboard.base import dashboard_bp
from app.db import Session
from app.dns_utils import (
get_mx_domains,
get_spf_domain,
get_txt_record,
is_mx_equivalent,
)
from app.log import LOG
from app.models import (
CustomDomain,
Expand Down Expand Up @@ -49,8 +44,6 @@ def domain_detail_dns(custom_domain_id):
domain_validator = CustomDomainValidation(EMAIL_DOMAIN)
csrf_form = CSRFValidationForm()

dmarc_record = "v=DMARC1; p=quarantine; pct=100; adkim=s; aspf=s"

mx_ok = spf_ok = dkim_ok = dmarc_ok = ownership_ok = True
mx_errors = spf_errors = dkim_errors = dmarc_errors = ownership_errors = []

Expand All @@ -59,15 +52,14 @@ def domain_detail_dns(custom_domain_id):
flash("Invalid request", "warning")
return redirect(request.url)
if request.form.get("form-name") == "check-ownership":
txt_records = get_txt_record(custom_domain.domain)

if custom_domain.get_ownership_dns_txt_value() in txt_records:
ownership_validation_result = domain_validator.validate_domain_ownership(
custom_domain
)
if ownership_validation_result.success:
flash(
"Domain ownership is verified. Please proceed to the other records setup",
"success",
)
custom_domain.ownership_verified = True
Session.commit()
return redirect(
url_for(
"dashboard.domain_detail_dns",
Expand All @@ -78,51 +70,41 @@ def domain_detail_dns(custom_domain_id):
else:
flash("We can't find the needed TXT record", "error")
ownership_ok = False
ownership_errors = txt_records
ownership_errors = ownership_validation_result.errors

elif request.form.get("form-name") == "check-mx":
mx_domains = get_mx_domains(custom_domain.domain)

if not is_mx_equivalent(mx_domains, EMAIL_SERVERS_WITH_PRIORITY):
flash("The MX record is not correctly set", "warning")

mx_ok = False
# build mx_errors to show to user
mx_errors = [
f"{priority} {domain}" for (priority, domain) in mx_domains
]
else:
mx_validation_result = domain_validator.validate_mx_records(custom_domain)
if mx_validation_result.success:
flash(
"Your domain can start receiving emails. You can now use it to create alias",
"success",
)
custom_domain.verified = True
Session.commit()
return redirect(
url_for(
"dashboard.domain_detail_dns", custom_domain_id=custom_domain.id
)
)
else:
flash("The MX record is not correctly set", "warning")
mx_ok = False
mx_errors = mx_validation_result.errors

elif request.form.get("form-name") == "check-spf":
spf_domains = get_spf_domain(custom_domain.domain)
if EMAIL_DOMAIN in spf_domains:
custom_domain.spf_verified = True
Session.commit()
spf_validation_result = domain_validator.validate_spf_records(custom_domain)
if spf_validation_result.success:
flash("SPF is setup correctly", "success")
return redirect(
url_for(
"dashboard.domain_detail_dns", custom_domain_id=custom_domain.id
)
)
else:
custom_domain.spf_verified = False
Session.commit()
flash(
f"SPF: {EMAIL_DOMAIN} is not included in your SPF record.",
"warning",
)
spf_ok = False
spf_errors = get_txt_record(custom_domain.domain)
spf_errors = spf_validation_result.errors

elif request.form.get("form-name") == "check-dkim":
dkim_errors = domain_validator.validate_dkim_records(custom_domain)
Expand All @@ -138,30 +120,29 @@ def domain_detail_dns(custom_domain_id):
flash("DKIM: the CNAME record is not correctly set", "warning")

elif request.form.get("form-name") == "check-dmarc":
txt_records = get_txt_record("_dmarc." + custom_domain.domain)
if dmarc_record in txt_records:
custom_domain.dmarc_verified = True
Session.commit()
dmarc_validation_result = domain_validator.validate_dmarc_records(
custom_domain
)
if dmarc_validation_result.success:
flash("DMARC is setup correctly", "success")
return redirect(
url_for(
"dashboard.domain_detail_dns", custom_domain_id=custom_domain.id
)
)
else:
custom_domain.dmarc_verified = False
Session.commit()
flash(
"DMARC: The TXT record is not correctly set",
"warning",
)
dmarc_ok = False
dmarc_errors = txt_records
dmarc_errors = dmarc_validation_result.errors

return render_template(
"dashboard/domain_detail/dns.html",
EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY,
dkim_records=domain_validator.get_dkim_records(),
dmarc_record=DMARC_RECORD,
**locals(),
)

Expand Down
Loading