diff --git a/shadowmire.py b/shadowmire.py
index ef8bf1b..843cb9b 100755
--- a/shadowmire.py
+++ b/shadowmire.py
@@ -2,12 +2,12 @@
 import sys
 from types import FrameType
-from typing import IO, Any, Callable, Generator, Optional
+from typing import IO, Any, Callable, Generator, Literal, NoReturn, Optional
 import xmlrpc.client
 from dataclasses import dataclass
 import re
 import json
-from urllib.parse import urljoin, urlparse, urlunparse
+from urllib.parse import urljoin, urlparse, urlunparse, unquote
 from pathlib import Path
 from html.parser import HTMLParser
 import logging
@@ -18,10 +18,12 @@
 )  # fast path computation, instead of accessing real files like pathlib
 from contextlib import contextmanager
 import sqlite3
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import Future, ThreadPoolExecutor, as_completed
 import signal
 import tomllib
 from copy import deepcopy
+import functools
+from http.client import HTTPConnection
 
 import requests
 import click
@@ -36,6 +38,10 @@
 # Note that it's suggested to use only 3 workers for PyPI.
 WORKERS = int(os.environ.get("SHADOWMIRE_WORKERS", "3"))
 
+# Use threads to parallelize local IO during verification
+IOWORKERS = int(os.environ.get("SHADOWMIRE_IOWORKERS", "2"))
+# A safety net -- to avoid upstream issues causing too many packages to be removed when determining the sync plan.
+MAX_DELETION = int(os.environ.get("SHADOWMIRE_MAX_DELETION", "50000"))
 # https://github.com/pypa/bandersnatch/blob/a05af547f8d1958217ef0dc0028890b1839e6116/src/bandersnatch_filter_plugins/prerelease_name.py#L18C1-L23C6
 PRERELEASE_PATTERNS = (
@@ -61,6 +67,13 @@ def exit_handler(signum: int, frame: Optional[FrameType]) -> None:
 signal.signal(signal.SIGTERM, exit_handler)
 
 
+def exit_with_futures(futures: dict[Future[Any], Any]) -> NoReturn:
+    logger.info("Exiting...")
+    for future in futures:
+        future.cancel()
+    sys.exit(1)
+
+
 class LocalVersionKV:
     """
     A key-value database wrapper over sqlite3.
@@ -146,6 +159,20 @@ def overwrite(
         raise
 
 
+def fast_readall(file_path: Path) -> bytes:
+    """
+    Save some extra read(), lseek() and ioctl() calls.
+    """
+    fd = os.open(file_path, os.O_RDONLY)
+    if fd < 0:
+        raise FileNotFoundError(file_path)
+    try:
+        contents = os.read(fd, file_path.stat().st_size)
+        return contents
+    finally:
+        os.close(fd)
+
+
 def normalize(name: str) -> str:
     """
     See https://peps.python.org/pep-0503/#normalized-names
@@ -166,6 +193,22 @@ def remove_dir_with_files(directory: Path) -> None:
     logger.info("Removed dir %s", directory)
 
 
+def fast_iterdir(
+    directory: Path | str, filter_type: Literal["dir", "file"]
+) -> Generator[os.DirEntry[str], Any, None]:
+    """
+    pathlib's iterdir() ignores the file type information from getdents64(),
+    which is not acceptable when you have millions of files in one directory
+    and need to filter them by type.
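+    os.scandir() is used here instead: its DirEntry objects carry cached file
+    type information, so is_dir()/is_file() usually need no extra stat() call.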
+ """ + assert filter_type in ["dir", "file"] + for item in os.scandir(directory): + if filter_type == "dir" and item.is_dir(): + yield item + elif filter_type == "file" and item.is_file(): + yield item + + def get_package_urls_from_index_html(html_path: Path) -> list[str]: """ Get all href (fragments removed) from given simple//index.html contents @@ -185,8 +228,8 @@ def handle_starttag( self.hrefs.append(attr[1]) p = ATagHTMLParser() - with open(html_path) as f: - p.feed(f.read()) + contents = fast_readall(html_path).decode() + p.feed(contents) ret = [] for href in p.hrefs: @@ -201,8 +244,8 @@ def get_package_urls_from_index_json(json_path: Path) -> list[str]: """ Get all urls from given simple//index.v1_json contents """ - with open(json_path) as f: - contents_dict = json.load(f) + contents = fast_readall(json_path) + contents_dict = json.loads(contents) urls = [i["url"] for i in contents_dict["files"]] return urls @@ -213,8 +256,8 @@ def get_package_urls_size_from_index_json(json_path: Path) -> list[tuple[str, in If size is not available, returns size as -1 """ - with open(json_path) as f: - contents_dict = json.load(f) + contents = fast_readall(json_path) + contents_dict = json.loads(contents) ret = [(i["url"], i.get("size", -1)) for i in contents_dict["files"]] return ret @@ -226,15 +269,15 @@ def get_existing_hrefs(package_simple_path: Path) -> Optional[list[str]]: Priority: index.v1_json -> index.html """ - if not package_simple_path.exists(): - return None json_file = package_simple_path / "index.v1_json" html_file = package_simple_path / "index.html" - if json_file.exists(): + try: return get_package_urls_from_index_json(json_file) - if html_file.exists(): - return get_package_urls_from_index_html(html_file) - return None + except FileNotFoundError: + try: + return get_package_urls_from_index_html(html_file) + except FileNotFoundError: + return None class CustomXMLRPCTransport(xmlrpc.client.Transport): @@ -244,9 +287,20 @@ class CustomXMLRPCTransport(xmlrpc.client.Transport): user_agent = USER_AGENT + def make_connection(self, host: tuple[str, dict[str, str]] | str) -> HTTPConnection: + conn = super().make_connection(host) + if conn.timeout is None: + # 2 min timeout + conn.timeout = 120 + return conn + def create_requests_session() -> requests.Session: s = requests.Session() + # hardcode 1min timeout for connect & read for now + # https://requests.readthedocs.io/en/latest/user/advanced/#timeouts + # A hack to overwrite get() method + s.get_orig, s.get = s.get, functools.partial(s.get, timeout=(60, 60)) # type: ignore retries = Retry(total=3, backoff_factor=0.1) s.mount("http://", HTTPAdapter(max_retries=retries)) s.mount("https://", HTTPAdapter(max_retries=retries)) @@ -299,11 +353,25 @@ def get_release_files_from_meta(package_meta: dict) -> list[dict]: @staticmethod def file_url_to_local_url(url: str) -> str: + """ + This function should NOT be used to construct a local Path! + """ parsed = urlparse(url) assert parsed.path.startswith("/packages") prefix = "../.." 
+        """
         parsed = urlparse(url)
         assert parsed.path.startswith("/packages")
         prefix = "../.."
         return prefix + parsed.path
 
+    @staticmethod
+    def file_url_to_local_path(url: str) -> Path:
+        """
+        Unquote the URL path and return it as a relative Path.
+        """
+        path = urlparse(url).path
+        path = unquote(path)
+        assert path.startswith("/packages")
+        path = path[1:]
+        return Path("../..") / path
+
     # Func modified from bandersnatch
     @classmethod
     def generate_html_simple_page(cls, package_meta: dict) -> str:
@@ -459,6 +527,21 @@ def determine_sync_plan(
         for i in local_keys - remote_keys:
             to_remove.append(i)
             local_keys.remove(i)
+        # There are always some packages that show up in PyPI's list_packages_with_serial() but do not actually exist
+        # Don't count them when comparing len(to_remove) with MAX_DELETION
+        if len(to_remove) > MAX_DELETION:
+            logger.error(
+                "Too many packages to remove (%d > %d)", len(to_remove), MAX_DELETION
+            )
+            logger.info("Some packages that would be removed:")
+            for p in to_remove[:100]:
+                logger.info("- %s", p)
+            for p in to_remove[100:]:
+                logger.debug("- %s", p)
+            logger.error(
+                "Use SHADOWMIRE_MAX_DELETION env to adjust the threshold if you really want to proceed"
+            )
+            sys.exit(2)
         for i in remote_keys - local_keys:
             to_update.append(i)
         for i in local_keys:
@@ -480,33 +563,32 @@ def check_and_update(
         self,
         package_names: list[str],
         prerelease_excludes: list[re.Pattern[str]],
+        json_files: set[str],
+        packages_pathcache: set[str],
         compare_size: bool,
     ) -> bool:
-        to_update = []
-        for package_name in tqdm(package_names, desc="Checking consistency"):
-            package_jsonmeta_path = self.jsonmeta_dir / package_name
-            if not package_jsonmeta_path.exists():
+        def is_consistent(package_name: str) -> bool:
+            if package_name not in json_files:
+                # checking the precomputed json_files set saves a newfstatat() call
                 logger.info("add %s as it does not have json API file", package_name)
-                to_update.append(package_name)
-                continue
+                return False
             package_simple_path = self.simple_dir / package_name
             html_simple = package_simple_path / "index.html"
             htmlv1_simple = package_simple_path / "index.v1_html"
             json_simple = package_simple_path / "index.v1_json"
-            if not (
-                html_simple.exists() and json_simple.exists() and htmlv1_simple.exists()
-            ):
+            try:
+                # always create index.html symlink, if not exists or not a symlink
+                if not html_simple.is_symlink():
+                    html_simple.unlink(missing_ok=True)
+                    html_simple.symlink_to("index.v1_html")
+                hrefs_html = get_package_urls_from_index_html(htmlv1_simple)
+                hrefsize_json = get_package_urls_size_from_index_json(json_simple)
+            except FileNotFoundError:
                 logger.info(
-                    "add %s as it does not have index.html, index.v1_html or index.v1_json",
+                    "add %s as it does not have index.v1_html or index.v1_json",
                     package_name,
                 )
-                to_update.append(package_name)
-                continue
-            if not html_simple.is_symlink():
-                html_simple.unlink()
-                html_simple.symlink_to("index.v1_html")
-            hrefs_html = get_package_urls_from_index_html(html_simple)
-            hrefsize_json = get_package_urls_size_from_index_json(json_simple)
+                return False
             if (
                 hrefs_html is None
                 or hrefsize_json is None
@@ -514,36 +596,67 @@
             ):
                 # something unexpected happens...
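+                # e.g. an index parsed to nothing, or the two indexes disagree with each other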
logger.info("add %s as its indexes are not consistent", package_name) - to_update.append(package_name) - continue + return False # OK, check if all hrefs have corresponding files if self.sync_packages: - should_update = False for href, size in hrefsize_json: - dest = Path(normpath(package_simple_path / href)) - if not dest.exists(): + relative_path = unquote(href) + dest_pathstr = normpath(package_simple_path / relative_path) + try: + # Fast shortcut to avoid stat() it + if dest_pathstr not in packages_pathcache: + raise FileNotFoundError + if compare_size and size != -1: + dest = Path(dest_pathstr) + # So, do stat() for real only when we need to do so, + # have a size, and it really exists in pathcache. + dest_stat = dest.stat() + dest_size = dest_stat.st_size + if dest_size != size: + logger.info( + "add %s as its local size %s != %s", + package_name, + dest_size, + size, + ) + return False + except FileNotFoundError: logger.info("add %s as it's missing packages", package_name) - should_update = True - break - if compare_size and size != -1: - dest_size = dest.stat().st_size - if dest_size != size: - logger.info( - "add %s as its local size %s != %s", - package_name, - dest_size, - size, - ) - should_update = True - break - if should_update: - to_update.append(package_name) + return False + + return True + + to_update = [] + with ThreadPoolExecutor(max_workers=IOWORKERS) as executor: + futures = { + executor.submit(is_consistent, package_name): package_name + for package_name in package_names + } + try: + for future in tqdm( + as_completed(futures), + total=len(package_names), + desc="Checking consistency", + ): + package_name = futures[future] + try: + consistent = future.result() + if not consistent: + to_update.append(package_name) + except Exception: + logger.warning( + "%s generated an exception", package_name, exc_info=True + ) + raise + except: + exit_with_futures(futures) + logger.info("%s packages to update in check_and_update()", len(to_update)) return self.parallel_update(to_update, prerelease_excludes) def parallel_update( - self, package_names: list, prerelease_excludes: list[re.Pattern[str]] + self, package_names: list[str], prerelease_excludes: list[re.Pattern[str]] ) -> bool: success = True with ThreadPoolExecutor(max_workers=WORKERS) as executor: @@ -566,7 +679,7 @@ def parallel_update( if serial: self.local_db.set(package_name, serial) except Exception as e: - if isinstance(e, (ExitProgramException, KeyboardInterrupt)): + if isinstance(e, (KeyboardInterrupt)): raise logger.warning( "%s generated an exception", package_name, exc_info=True @@ -576,10 +689,7 @@ def parallel_update( logger.info("dumping local db...") self.local_db.dump_json() except (ExitProgramException, KeyboardInterrupt): - logger.info("Get ExitProgramException or KeyboardInterrupt, exiting...") - for future in futures: - future.cancel() - sys.exit(1) + exit_with_futures(futures) return success def do_sync_plan( @@ -662,6 +772,27 @@ def finalize(self) -> None: f.write(" \n") self.local_db.dump_json() + def skip_this_package(self, i: dict, dest: Path) -> bool: + """ + A helper function for subclasses implementing do_update(). + As existence check is also done with stat(), this would not bring extra I/O overhead. + Returns if skip this package or not. 
+ """ + try: + dest_size = dest.stat().st_size + i_size = i.get("size", -1) + if i_size == -1: + return True + if dest_size == i_size: + return True + logger.warning( + "file %s exists locally, but size does not match with upstream, so it would still be downloaded.", + dest, + ) + return False + except FileNotFoundError: + return False + def download( session: requests.Session, url: str, dest: Path @@ -746,7 +877,8 @@ def do_update( self.pypi.file_url_to_local_url(i["url"]) for i in release_files ] should_remove = list(set(existing_hrefs) - set(remote_hrefs)) - for p in should_remove: + for href in should_remove: + p = unquote(href) logger.info("removing file %s (if exists)", p) package_path = Path(normpath(package_simple_path / p)) package_path.unlink(missing_ok=True) @@ -754,12 +886,13 @@ def do_update( url = i["url"] dest = Path( normpath( - package_simple_path / self.pypi.file_url_to_local_url(i["url"]) + package_simple_path / self.pypi.file_url_to_local_path(i["url"]) ) ) logger.info("downloading file %s -> %s", url, dest) - if dest.exists(): + if self.skip_this_package(i, dest): continue + dest.parent.mkdir(parents=True, exist_ok=True) success, _resp = download(self.session, url, dest) if not success: @@ -847,16 +980,19 @@ def do_update( release_files = PyPI.get_release_files_from_meta(meta) remote_hrefs = [PyPI.file_url_to_local_url(i["url"]) for i in release_files] should_remove = list(set(existing_hrefs) - set(remote_hrefs)) - for p in should_remove: + for href in should_remove: + p = unquote(href) logger.info("removing file %s (if exists)", p) package_path = Path(normpath(package_simple_path / p)) package_path.unlink(missing_ok=True) package_simple_url = urljoin(self.upstream, f"simple/{package_name}/") - for href in remote_hrefs: + for i in release_files: + href = PyPI.file_url_to_local_url(i["url"]) + path = PyPI.file_url_to_local_path(i["url"]) url = urljoin(package_simple_url, href) - dest = Path(normpath(package_simple_path / href)) + dest = Path(normpath(package_simple_path / path)) logger.info("downloading file %s -> %s", url, dest) - if dest.exists(): + if self.skip_this_package(i, dest): continue dest.parent.mkdir(parents=True, exist_ok=True) success, resp = download(self.session, url, dest) @@ -878,14 +1014,13 @@ def do_update( return last_serial -def get_local_serial(package_meta_path: Path) -> Optional[int]: +def get_local_serial(package_meta_direntry: os.DirEntry[str]) -> Optional[int]: """ Accepts /json/ as package_meta_path """ - package_name = package_meta_path.name + package_name = package_meta_direntry.name try: - with open(package_meta_path) as f: - contents = f.read() + contents = fast_readall(Path(package_meta_direntry.path)) except FileNotFoundError: logger.warning("%s does not have JSON metadata, skipping", package_name) return None @@ -1052,13 +1187,32 @@ def genlocal(ctx: click.Context) -> None: local = {} json_dir = basedir / "json" logger.info("Iterating all items under %s", json_dir) - dir_items = [d for d in json_dir.iterdir() if d.is_file()] + dir_items = [d for d in fast_iterdir(json_dir, "file")] logger.info("Detected %s packages in %s in total", len(dir_items), json_dir) - for package_metapath in tqdm(dir_items, desc="Reading packages from json/"): - package_name = package_metapath.name - serial = get_local_serial(package_metapath) - if serial: - local[package_name] = serial + with ThreadPoolExecutor(max_workers=IOWORKERS) as executor: + futures = { + executor.submit(get_local_serial, package_metapath): package_metapath + for package_metapath in 
+        futures = {
+            executor.submit(get_local_serial, package_metapath): package_metapath
+            for package_metapath in dir_items
+        }
+        try:
+            for future in tqdm(
+                as_completed(futures),
+                total=len(dir_items),
+                desc="Reading packages from json/",
+            ):
+                package_name = futures[future].name
+                try:
+                    serial = future.result()
+                    if serial:
+                        local[package_name] = serial
+                except Exception as e:
+                    if isinstance(e, (KeyboardInterrupt)):
+                        raise
+                    logger.warning(
+                        "%s generated an exception", package_name, exc_info=True
+                    )
+        except (ExitProgramException, KeyboardInterrupt):
+            exit_with_futures(futures)
     logger.info(
         "%d out of %d packages have valid serial number", len(local), len(dir_items)
     )
@@ -1100,8 +1254,8 @@ def verify(
     logger.info("====== Step 1. Remove packages NOT in local db ======")
     local_names = set(local_db.keys())
-    simple_dirs = {i.name for i in (basedir / "simple").iterdir() if i.is_dir()}
-    json_files = {i.name for i in (basedir / "json").iterdir() if i.is_file()}
+    simple_dirs = {i.name for i in fast_iterdir((basedir / "simple"), "dir")}
+    json_files = {i.name for i in fast_iterdir((basedir / "json"), "file")}
     not_in_local = (simple_dirs | json_files) - local_names
     logger.info(
         "%d out of %d local packages NOT in local db",
@@ -1133,37 +1287,101 @@
     # After some removal, local_names is changed.
     local_names = set(local_db.keys())
 
+    logger.info("====== Step 3. Caching packages/ dirtree in memory for Step 4 & 5.")
+    packages_pathcache: set[str] = set()
+    with ThreadPoolExecutor(max_workers=IOWORKERS) as executor:
+
+        def packages_iterate(first_dirname: str, position: int) -> list[str]:
+            with tqdm(
+                desc=f"Iterating packages/{first_dirname}/*/*/*", position=position
+            ) as pb:
+                res = []
+                for d1 in fast_iterdir(basedir / "packages" / first_dirname, "dir"):
+                    for d2 in fast_iterdir(d1.path, "dir"):
+                        for file in fast_iterdir(d2.path, "file"):
+                            pb.update(1)
+                            res.append(file.path)
+                return res
+
+        futures = {
+            executor.submit(packages_iterate, first_dir.name, idx % IOWORKERS): first_dir.name  # type: ignore
+            for idx, first_dir in enumerate(fast_iterdir((basedir / "packages"), "dir"))
+        }
+        try:
+            for future in as_completed(futures):
+                sname = futures[future]
+                try:
+                    for p in future.result():
+                        packages_pathcache.add(p)
+                except Exception as e:
+                    if isinstance(e, (KeyboardInterrupt)):
+                        raise
+                    logger.warning("%s generated an exception", sname, exc_info=True)
+                    success = False
+        except (ExitProgramException, KeyboardInterrupt):
+            exit_with_futures(futures)
+
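+    # packages_pathcache now holds every file path under packages/; it is reused by
+    # Step 4 (consistency check) and Step 5 (unreferenced file cleanup) below.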
     logger.info(
-        "====== Step 3. Make sure all local indexes are valid, and (if --sync-packages) have valid local package files ======"
+        "====== Step 4. Make sure all local indexes are valid, and (if --sync-packages) have valid local package files ======"
     )
     success = syncer.check_and_update(
-        list(local_names), prerelease_excludes, compare_size
+        list(local_names),
+        prerelease_excludes,
+        json_files,
+        packages_pathcache,
+        compare_size,
     )
     syncer.finalize()
 
     logger.info(
-        "====== Step 4. Remove any unreferenced files in `packages` folder ======"
+        "====== Step 5. Remove any unreferenced files in `packages` folder ======"
     )
-    ref_set = set()
-    for sname in tqdm(simple_dirs, desc="Iterating simple/ directory"):
-        sd = basedir / "simple" / sname
-        hrefs = get_existing_hrefs(sd)
-        hrefs = [] if hrefs is None else hrefs
-        for i in hrefs:
-            # use normpath, which is much faster than pathlib resolve(), as it does not need to access fs
-            # we could make sure no symlinks could affect this here
-            np = normpath(sd / i)
-            logger.debug("add to ref_set: %s", np)
-            ref_set.add(np)
-    for file in tqdm(
-        (basedir / "packages").glob("*/*/*/*"), desc="Iterating packages/*/*/*/*"
-    ):
-        # basedir is absolute, so file is also absolute
-        # just convert to str to match normpath result
-        logger.debug("find file %s", file)
-        if str(file) not in ref_set:
-            logger.info("removing unreferenced file %s", file)
-            file.unlink()
+    ref_set: set[str] = set()
+    with ThreadPoolExecutor(max_workers=IOWORKERS) as executor:
+        # Part 1: iterate simple/
+        def iterate_simple(sname: str) -> list[str]:
+            sd = basedir / "simple" / sname
+            hrefs = get_existing_hrefs(sd)
+            hrefs = [] if hrefs is None else hrefs
+            nps = []
+            for href in hrefs:
+                i = unquote(href)
+                # use normpath, which is much faster than pathlib resolve(), as it does not need to access fs
+                # we could make sure no symlinks could affect this here
+                np = normpath(sd / i)
+                logger.debug("add to ref_set: %s", np)
+                nps.append(np)
+            return nps
+
+        # MyPy does not enjoy same variable name with different types, even when --allow-redefinition
+        # Ignore here to make mypy happy
+        futures = {
+            executor.submit(iterate_simple, sname): sname for sname in simple_dirs  # type: ignore
+        }
+        try:
+            for future in tqdm(
+                as_completed(futures),
+                total=len(simple_dirs),
+                desc="Iterating simple/ directory",
+            ):
+                sname = futures[future]
+                try:
+                    nps = future.result()
+                    for np in nps:
+                        ref_set.add(np)
+                except Exception as e:
+                    if isinstance(e, (KeyboardInterrupt)):
+                        raise
+                    logger.warning("%s generated an exception", sname, exc_info=True)
+                    success = False
+        except (ExitProgramException, KeyboardInterrupt):
+            exit_with_futures(futures)
+
+    # Part 2: handling packages
+    for path in tqdm(packages_pathcache, desc="Iterating path cache"):
+        if path not in ref_set:
+            logger.info("removing unreferenced file %s", path)
+            Path(path).unlink()
 
     logger.info("Verification finished. Success: %s", success)