Issue #83: Use Asyncio.
Nekmo committed Aug 11, 2023
1 parent 138bcb4 commit 9c656b7
Showing 5 changed files with 49 additions and 36 deletions.
16 changes: 15 additions & 1 deletion dirhunt/crawler.py
@@ -5,10 +5,11 @@
import json
import os
from asyncio import Semaphore, Task
from collections import defaultdict
from concurrent.futures.thread import _python_exit
from hashlib import sha256
from threading import Lock, ThreadError
from typing import Optional, Set, Coroutine, Any
from typing import Optional, Set, Coroutine, Any, Dict

import humanize as humanize
from click import get_terminal_size
@@ -23,6 +24,7 @@
IncompatibleVersionError,
)
from dirhunt.json_report import JsonReportEncoder
from dirhunt.processors import ProcessBase
from dirhunt.sessions import Session
from dirhunt.sources import Sources
from dirhunt.url_info import UrlsInfo
@@ -75,13 +77,15 @@ def __init__(self, configuration: Configuration, loop: asyncio.AbstractEventLoop
self.start_dt = datetime.datetime.now()
self.current_processed_count: int = 0
self.sources = Sources(self)
self.domain_protocols: Dict[str, set] = defaultdict(set)

async def start(self):
"""Add urls to process."""
for url in self.configuration.urls:
crawler_url = CrawlerUrl(self, url, depth=self.configuration.max_depth)
await self.add_domain(crawler_url.url.domain)
await self.add_crawler_url(crawler_url)
self.add_domain_protocol(crawler_url)

while self.tasks:
await asyncio.wait(self.tasks)
@@ -125,6 +129,16 @@ def print_error(self, message: str):
text.append(message)
self.console.print(text)

def print_processor(self, processor: ProcessBase):
"""Print processor to console."""
if 300 > processor.status >= 200:
self.add_domain_protocol(processor.crawler_url)
self.console.print(processor.get_text())

def add_domain_protocol(self, crawler_url: "CrawlerUrl"):
"""Add domain protocol"""
self.domain_protocols[crawler_url.url.domain].add(crawler_url.url.protocol)

def add_init_urls(self, *urls):
"""Add urls to queue."""
self.initial_urls.extend(urls)
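Note on the crawler.py change above: the crawler now keeps a domain_protocols mapping that records which protocols (http/https) have produced successful responses for each domain. It is populated from the initial URLs in start() and from 2xx processor results in print_processor(); the crawler_url.py change below routes processor output through print_processor() so this bookkeeping happens on every crawled page. A minimal, self-contained sketch of that pattern (ProtocolTracker and on_response are hypothetical stand-ins, not part of the commit):

```python
from collections import defaultdict
from typing import Dict, Set


class ProtocolTracker:
    """Hypothetical stand-in for the new Crawler bookkeeping: maps each
    domain to the set of protocols that have returned 2xx responses."""

    def __init__(self) -> None:
        self.domain_protocols: Dict[str, Set[str]] = defaultdict(set)

    def add_domain_protocol(self, domain: str, protocol: str) -> None:
        self.domain_protocols[domain].add(protocol)

    def on_response(self, domain: str, protocol: str, status: int) -> None:
        # Mirrors print_processor(): only successful (2xx) responses
        # register the protocol for the domain.
        if 300 > status >= 200:
            self.add_domain_protocol(domain, protocol)


tracker = ProtocolTracker()
tracker.on_response("example.com", "https", 200)
tracker.on_response("example.com", "http", 404)  # not 2xx, ignored
print(dict(tracker.domain_protocols))  # {'example.com': {'https'}}
```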
2 changes: 1 addition & 1 deletion dirhunt/crawler_url.py
@@ -157,7 +157,7 @@ async def retrieve(self):
and not isinstance(processor, GenericProcessor)
and self.url_type not in {"asset", "index_file"}
):
self.crawler.console.print(processor.get_text())
self.crawler.print_processor(processor)
# if self.must_be_downloaded(response):
# processor = get_processor(response, text, self, soup) or GenericProcessor(
# response, self
7 changes: 6 additions & 1 deletion dirhunt/processors.py
@@ -70,6 +70,7 @@ class ProcessBase:
name = ""
key_name = ""
index_file = None
status: int = 0
status_code = 0 # TODO: rename to status
requires_content = False
# If the processor has descendants, use get_processor after retrieve the content
@@ -81,6 +82,7 @@ def __init__(self, crawler_url_request: "CrawlerUrlRequest"):
:type crawler_url_request: CrawlerUrlRequest
"""
if crawler_url_request.response is not None:
self.status = crawler_url_request.response.status
self.status_code = crawler_url_request.response.status
# The crawler_url_request takes a lot of memory, so we don't save it
self.crawler_url = crawler_url_request.crawler_url
@@ -246,7 +248,10 @@ def __init__(self, crawler_url_request: "CrawlerUrlRequest"):

async def process(self, crawler_url_request: "CrawlerUrlRequest") -> None:
"""Process the request. This method will add the redirector url to the crawler."""
if not self.crawler_url.crawler.configuration.not_allow_redirects:
if (
not self.crawler_url.crawler.configuration.not_allow_redirects
and self.redirector
):
await self.add_url(self.redirector)

@classmethod
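Note on the processors.py change above: processors now expose the response status as `status` (alongside the legacy `status_code`), and the redirect processor only enqueues a target when redirects are allowed and a redirector URL was actually found. A hedged sketch of that guard, with `Config`, `follow_redirect`, and the plain list queue as hypothetical stand-ins for dirhunt's Configuration and add_url():

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Config:
    """Hypothetical stand-in for dirhunt's Configuration."""

    not_allow_redirects: bool = False


def follow_redirect(config: Config, redirector: Optional[str], queue: List[str]) -> None:
    # Mirrors the guarded process(): skip when redirects are disabled
    # or when no redirector URL could be extracted from the response.
    if not config.not_allow_redirects and redirector:
        queue.append(redirector)


queue: List[str] = []
follow_redirect(Config(), "https://example.com/admin/", queue)
follow_redirect(Config(), None, queue)  # no redirector: nothing added
follow_redirect(Config(not_allow_redirects=True), "https://example.com/x", queue)  # disabled
print(queue)  # ['https://example.com/admin/']
```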
2 changes: 1 addition & 1 deletion dirhunt/sources/__init__.py
@@ -17,7 +17,7 @@


SOURCE_CLASSES: List[Type["SourceBase"]] = [
# Robots,
Robots,
# VirusTotal,
Google,
CommonCrawl,
58 changes: 26 additions & 32 deletions dirhunt/sources/robots.py
@@ -1,46 +1,40 @@
from itertools import chain
from typing import Iterable

import requests
from requests import RequestException

from dirhunt.exceptions import SourceError
from dirhunt.sources.base import SourceBase
from dirhunt._compat import RobotFileParser, URLError
from dirhunt._compat import RobotFileParser


PROTOCOLS = ["https", "http"]


def get_url(protocol, domain, path):
path = path.lstrip("/")
return "{protocol}://{domain}/{path}".format(**locals())


class DirhuntRobotFileParser(RobotFileParser):
def read(self):
"""Reads the robots.txt URL and feeds it to the parser."""
try:
with requests.get(self.url) as response:
status_code = response.status_code
text = response.text
except RequestException:
pass
else:
if status_code in (401, 403):
self.disallow_all = True
elif status_code >= 400 and status_code < 500:
self.allow_all = True
self.parse(text.splitlines())


class Robots(SourceBase):
def callback(self, domain, protocol="http"):
rp = DirhuntRobotFileParser()
rp.set_url(get_url(protocol, domain, "robots.txt"))
try:
rp.read()
except (IOError, URLError):
if protocol == "http":
self.callback(domain, "https")
return
async def search_by_domain(self, domain: str) -> Iterable[str]:
if domain not in self.sources.crawler.domain_protocols:
raise SourceError(f"Protocol not available for domain: {domain}")
protocols = self.sources.crawler.domain_protocols[domain]
protocols = filter(lambda x: x in PROTOCOLS, protocols)
protocols = sorted(protocols, key=lambda x: PROTOCOLS.index(x))
protocol = protocols[0]
rp = RobotFileParser()
async with self.sources.crawler.session.get(
get_url(protocol, domain, "robots.txt")
) as response:
if response.status == 404:
return []
response.raise_for_status()
lines = (await response.text()).splitlines()
rp.parse(lines)
entries = list(rp.entries)
if rp.default_entry:
entries.append(rp.default_entry)
for ruleline in chain(*[entry.rulelines for entry in entries]):
self.add_result(get_url(protocol, domain, ruleline.path))
return [
get_url(protocol, domain, ruleline.path)
for ruleline in chain(*[entry.rulelines for entry in entries])
]
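
Note on the robots.py rewrite above: with Robots re-enabled in SOURCE_CLASSES, the source is now async. It requires that a working protocol has already been recorded for the domain (preferring https over http), fetches robots.txt through the crawler's aiohttp session, and returns a candidate URL for every rule line in the parsed entries. A standalone sketch of the same flow, using its own aiohttp.ClientSession instead of the crawler's shared session (simplified and hypothetical, not the commit's exact code):

```python
import asyncio
from itertools import chain
from typing import List, Set
from urllib.robotparser import RobotFileParser

import aiohttp

PROTOCOLS = ["https", "http"]


def get_url(protocol: str, domain: str, path: str) -> str:
    return "{}://{}/{}".format(protocol, domain, path.lstrip("/"))


async def robots_paths(domain: str, known_protocols: Set[str]) -> List[str]:
    # Prefer https over http among the protocols already seen for the domain.
    protocols = sorted(
        (p for p in known_protocols if p in PROTOCOLS), key=PROTOCOLS.index
    )
    if not protocols:
        raise ValueError("Protocol not available for domain: {}".format(domain))
    protocol = protocols[0]
    async with aiohttp.ClientSession() as session:
        async with session.get(get_url(protocol, domain, "robots.txt")) as response:
            if response.status == 404:
                return []
            response.raise_for_status()
            lines = (await response.text()).splitlines()
    rp = RobotFileParser()
    rp.parse(lines)
    entries = list(rp.entries)
    if rp.default_entry:
        entries.append(rp.default_entry)
    return [
        get_url(protocol, domain, ruleline.path)
        for ruleline in chain(*[entry.rulelines for entry in entries])
    ]


if __name__ == "__main__":
    # Requires network access; prints every path listed in python.org's robots.txt.
    print(asyncio.run(robots_paths("python.org", {"https"})))
```

As in the commit, this leans on the non-public entries/default_entry/rulelines attributes of urllib.robotparser, which may change between Python versions.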
