From aa13836d0e5a188395d51a8fa9b9d99bf1e97972 Mon Sep 17 00:00:00 2001 From: Matt Keeley Date: Sat, 10 Aug 2024 23:08:05 -0700 Subject: [PATCH] fix redirect bug --- list.txt | 1 + modules/spf.py | 29 +++++++++++++++++++++++------ 2 files changed, 24 insertions(+), 6 deletions(-) create mode 100644 list.txt diff --git a/list.txt b/list.txt new file mode 100644 index 0000000..140dbdc --- /dev/null +++ b/list.txt @@ -0,0 +1 @@ +reddit.com diff --git a/modules/spf.py b/modules/spf.py index 1c07650..ea9e155 100644 --- a/modules/spf.py +++ b/modules/spf.py @@ -17,12 +17,14 @@ def __init__(self, domain, dns_server=None): self.num_includes = self.get_spf_includes() self.too_many_includes = self.num_includes > 10 - def get_spf_record(self): - """Returns the SPF record for a given domain.""" + def get_spf_record(self, domain=None): + """Fetches the SPF record for the specified domain.""" try: + if not domain: + domain = self.domain resolver = dns.resolver.Resolver() resolver.nameservers = [self.dns_server, '1.1.1.1', '8.8.8.8'] - query_result = resolver.resolve(self.domain, 'TXT') + query_result = resolver.resolve(domain, 'TXT') for record in query_result: if 'spf1' in str(record): spf_record = str(record).replace('"', '') @@ -33,12 +35,27 @@ def get_spf_record(self): def get_spf_all_string(self): """Returns the string value of the 'all' mechanism in the SPF record.""" - if self.spf_record: - all_matches = re.findall(r'[-~?+]all', self.spf_record) + + spf_record = self.spf_record + visited_domains = set() + + while spf_record: + all_matches = re.findall(r'[-~?+]all', spf_record) if len(all_matches) == 1: return all_matches[0] elif len(all_matches) > 1: return '2many' + + redirect_match = re.search(r'redirect=([\w.-]+)', spf_record) + if redirect_match: + redirect_domain = redirect_match.group(1) + if redirect_domain in visited_domains: + break # Prevent infinite loops in case of circular redirects + visited_domains.add(redirect_domain) + spf_record = self.get_spf_record(redirect_domain) + else: + break + return None def get_spf_includes(self): @@ -57,7 +74,7 @@ def count_includes(spf_record): txt_record = txt_string.decode('utf-8') if txt_record.startswith('v=spf1'): count += count_includes(txt_record) - except Exception as e: + except Exception: pass # Count occurrences of 'a', 'mx', 'ptr', and 'exists' mechanisms