diff --git a/app/custom_domain_validation.py b/app/custom_domain_validation.py index f8bc90d9f..ce4b844e1 100644 --- a/app/custom_domain_validation.py +++ b/app/custom_domain_validation.py @@ -186,22 +186,9 @@ def validate_dmarc_records( Session.commit() return DomainValidationResult(success=True, errors=[]) else: - cleaned_records = self.__clean_dmarc_records(txt_records, custom_domain) custom_domain.dmarc_verified = False Session.commit() - return DomainValidationResult(success=False, errors=cleaned_records) - - def __clean_dmarc_records( - self, txt_records: List[str], custom_domain: CustomDomain - ) -> List[str]: - final_records = [] - verification_record = self.get_ownership_verification_record(custom_domain) - spf_record = self.get_expected_spf_record(custom_domain) - for record in txt_records: - if record != verification_record and record != spf_record: - final_records.append(record) - - return final_records + return DomainValidationResult(success=False, errors=txt_records) def __clean_spf_records( self, txt_records: List[str], custom_domain: CustomDomain diff --git a/tests/test_custom_domain_validation.py b/tests/test_custom_domain_validation.py index ccbd40db3..d0de3dbe9 100644 --- a/tests/test_custom_domain_validation.py +++ b/tests/test_custom_domain_validation.py @@ -543,26 +543,3 @@ def test_custom_domain_validation_validate_dmarc_records_success(): db_domain = CustomDomain.get_by(id=domain.id) assert db_domain.dmarc_verified is True - - -def test_custom_domain_validation_validate_dmarc_cleans_verification_and_spf_records(): - dns_client = InMemoryDNSClient() - validator = CustomDomainValidation(random_domain(), dns_client) - - domain = create_custom_domain(random_domain()) - - wrong_record = random_string() - dns_client.set_txt_record( - hostname=f"_dmarc.{domain.domain}", - txt_list=[ - wrong_record, - validator.get_expected_spf_record(domain), - validator.get_ownership_verification_record(domain), - ], - ) - - res = validator.validate_dmarc_records(domain) - - assert res.success is False - assert len(res.errors) == 1 - assert res.errors[0] == wrong_record