Skip to content

Commit

Permalink
Run formatter on some Python scripts
Browse files Browse the repository at this point in the history
Signed-off-by: Harry Chen <[email protected]>
  • Loading branch information
Harry-Chen committed Aug 17, 2024
1 parent f8afa1f commit 5f4bc1c
Show file tree
Hide file tree
Showing 6 changed files with 449 additions and 258 deletions.
8 changes: 1 addition & 7 deletions adoptium.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
#!/usr/bin/env python3
import hashlib
import traceback
import json
import os
import re
import shutil
import subprocess as sp
import tempfile
import argparse
import time
from email.utils import parsedate_to_datetime
from pathlib import Path
from typing import List, Set, Tuple, IO
from typing import Set
import requests

DOWNLOAD_TIMEOUT = int(os.getenv('DOWNLOAD_TIMEOUT', '1800'))
Expand Down
156 changes: 96 additions & 60 deletions anaconda.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

WORKING_DIR = os.getenv("TUNASYNC_WORKING_DIR")

# fmt: off
CONDA_REPOS = ("main", "free", "r", "msys2")
CONDA_ARCHES = (
"noarch", "linux-64", "linux-32", "linux-aarch64", "linux-armv6l", "linux-armv7l",
Expand Down Expand Up @@ -72,6 +73,7 @@
EXCLUDED_PACKAGES = (
"pytorch-nightly", "pytorch-nightly-cpu", "ignite-nightly",
)
# fmt: on

# connect and read timeout value
TIMEOUT_OPTION = (7, 10)
Expand All @@ -84,63 +86,74 @@
format="[%(asctime)s] [%(levelname)s] %(message)s",
)

def sizeof_fmt(num, suffix="iB"):
    """Format a byte count as a human-readable string with binary prefixes.

    The value is divided by 1024 until its magnitude drops below 1024, and
    is then rendered with two decimals followed by the matching prefix
    (none, K, M, G, T, P, E, Z, Y) and *suffix*.

    Args:
        num: Size to format (int or float); negative values keep their sign.
        suffix: Text appended after the unit prefix.

    Returns:
        The formatted string, e.g. ``"1.50KiB"`` for ``1536``.
    """
    for unit in ("", "K", "M", "G", "T", "P", "E", "Z"):
        if abs(num) < 1024.0:
            return "%3.2f%s%s" % (num, unit, suffix)
        num /= 1024.0
    # Anything that survived eight divisions is expressed with the last prefix.
    return "%.2f%s%s" % (num, "Y", suffix)


def md5_check(file: Path, md5: str = None):
    """Return True when the MD5 digest of *file* matches *md5*.

    The file is hashed in 1 MiB chunks so that large packages are never
    loaded into memory at once.

    Args:
        file: Path of the file to hash.
        md5: Expected hexadecimal digest. Compared case-insensitively and
            with surrounding whitespace ignored. ``None`` never matches.

    Returns:
        ``True`` if the digests are equal, ``False`` otherwise.

    Raises:
        OSError: If *file* cannot be opened or read.
    """
    m = hashlib.md5()
    with file.open("rb") as f:
        # iter() with a sentinel stops at the first empty read (EOF).
        for buf in iter(lambda: f.read(1 * 1024 * 1024), b""):
            m.update(buf)
    # hexdigest() is always lowercase; normalise the expected value so an
    # uppercase digest in the metadata is not reported as a mismatch.
    return isinstance(md5, str) and m.hexdigest() == md5.strip().lower()


def sha256_check(file: Path, sha256: str = None):
    """Return True when the SHA-256 digest of *file* matches *sha256*.

    The file is hashed in 1 MiB chunks so that large packages are never
    loaded into memory at once.

    Args:
        file: Path of the file to hash.
        sha256: Expected hexadecimal digest. Compared case-insensitively and
            with surrounding whitespace ignored. ``None`` never matches.

    Returns:
        ``True`` if the digests are equal, ``False`` otherwise.

    Raises:
        OSError: If *file* cannot be opened or read.
    """
    m = hashlib.sha256()
    with file.open("rb") as f:
        # iter() with a sentinel stops at the first empty read (EOF).
        for buf in iter(lambda: f.read(1 * 1024 * 1024), b""):
            m.update(buf)
    # hexdigest() is always lowercase; normalise the expected value so an
    # uppercase digest in the metadata is not reported as a mismatch.
    return isinstance(sha256, str) and m.hexdigest() == sha256.strip().lower()


def curl_download(remote_url: str, dst_file: Path, sha256: str = None, md5: str = None):
    """Download *remote_url* to *dst_file* with curl and optionally verify it.

    Args:
        remote_url: URL to fetch.
        dst_file: Destination path handed to ``curl -o``.
        sha256: Expected SHA-256 hex digest; checked first when given.
        md5: Expected MD5 hex digest; checked when given and the SHA-256
            check (if any) passed.

    Returns:
        ``None`` on success, or the string ``"SHA256 mismatch"`` /
        ``"MD5 mismatch"`` when the downloaded file fails verification.

    Raises:
        subprocess.CalledProcessError: If curl exits with a non-zero status.
    """
    # curl is invoked exactly once. Flags:
    #   -s / --show-error  quiet, but still print errors
    #   -L                 follow redirects
    #   --remote-time      give the local file the remote modification time
    #   --fail             exit non-zero on HTTP errors instead of saving the error body
    #   --retry 10         retry transient failures
    #   --speed-time 15 / --speed-limit 5000
    #                      abort when slower than 5000 bytes/s for 15 seconds
    # fmt: off
    sp.check_call(
        [
            "curl", "-o", str(dst_file),
            "-sL", "--remote-time", "--show-error",
            "--fail", "--retry", "10",
            "--speed-time", "15",
            "--speed-limit", "5000",
            remote_url,
        ]
    )
    # fmt: on
    if sha256 and (not sha256_check(dst_file, sha256)):
        return "SHA256 mismatch"
    if md5 and (not md5_check(dst_file, md5)):
        return "MD5 mismatch"
    return None


def sync_repo(repo_url: str, local_dir: Path, tmpdir: Path, delete: bool, remove_legacy: bool):
def sync_repo(
repo_url: str, local_dir: Path, tmpdir: Path, delete: bool, remove_legacy: bool
):
logging.info("Start syncing {}".format(repo_url))
local_dir.mkdir(parents=True, exist_ok=True)

repodata_url = repo_url + '/repodata.json'
bz2_repodata_url = repo_url + '/repodata.json.bz2'
repodata_url = repo_url + "/repodata.json"
bz2_repodata_url = repo_url + "/repodata.json.bz2"
# https://github.com/conda/conda/issues/13256, from conda 24.1.x
zst_repodata_url = repo_url + '/repodata.json.zst'
zst_repodata_url = repo_url + "/repodata.json.zst"
# https://docs.conda.io/projects/conda-build/en/latest/release-notes.html
# "current_repodata.json" - like repodata.json, but only has the newest version of each file
current_repodata_url = repo_url + '/current_repodata.json'
current_repodata_url = repo_url + "/current_repodata.json"

tmp_repodata = tmpdir / "repodata.json"
tmp_bz2_repodata = tmpdir / "repodata.json.bz2"
tmp_zst_repodata = tmpdir / "repodata.json.zst"
tmp_current_repodata = tmpdir / 'current_repodata.json'
tmp_current_repodata = tmpdir / "current_repodata.json"

curl_download(repodata_url, tmp_repodata)
curl_download(bz2_repodata_url, tmp_bz2_repodata)
Expand All @@ -158,31 +171,33 @@ def sync_repo(repo_url: str, local_dir: Path, tmpdir: Path, delete: bool, remove

remote_filelist = []
total_size = 0
legacy_packages = repodata['packages']
legacy_packages = repodata["packages"]
conda_packages = repodata.get("packages.conda", {})
if remove_legacy:
# https://github.com/anaconda/conda/blob/0dbf85e0546e0b0dc060c8265ec936591ccbe980/conda/core/subdir_data.py#L440-L442
use_legacy_packages = set(legacy_packages.keys()) - set(k[:-6] + ".tar.bz2" for k in conda_packages.keys())
use_legacy_packages = set(legacy_packages.keys()) - set(
k[:-6] + ".tar.bz2" for k in conda_packages.keys()
)
legacy_packages = {k: legacy_packages[k] for k in use_legacy_packages}
packages = {**legacy_packages, **conda_packages}

for filename, meta in packages.items():
if meta['name'] in EXCLUDED_PACKAGES:
if meta["name"] in EXCLUDED_PACKAGES:
continue

file_size = meta['size']
file_size = meta["size"]
# prefer sha256 over md5
sha256 = None
md5 = None
if 'sha256' in meta:
sha256 = meta['sha256']
elif 'md5' in meta:
md5 = meta['md5']
if "sha256" in meta:
sha256 = meta["sha256"]
elif "md5" in meta:
md5 = meta["md5"]
total_size += file_size

pkg_url = '/'.join([repo_url, filename])
pkg_url = "/".join([repo_url, filename])
dst_file = local_dir / filename
dst_file_wip = local_dir / ('.downloading.' + filename)
dst_file_wip = local_dir / (".downloading." + filename)
remote_filelist.append(dst_file)

if dst_file.is_file():
Expand All @@ -202,7 +217,7 @@ def sync_repo(repo_url: str, local_dir: Path, tmpdir: Path, delete: bool, remove
if err is None:
dst_file_wip.rename(dst_file)
except sp.CalledProcessError:
err = 'CalledProcessError'
err = "CalledProcessError"
if err is None:
break
logging.error("Failed to download {}: {}".format(filename, err))
Expand All @@ -223,68 +238,79 @@ def sync_repo(repo_url: str, local_dir: Path, tmpdir: Path, delete: bool, remove
tmp_current_repodata_gz_gened = False
if tmp_current_repodata.is_file():
if os.path.getsize(tmp_current_repodata) > GEN_METADATA_JSON_GZIP_THRESHOLD:
sp.check_call(["gzip", "--no-name", "--keep", "--", str(tmp_current_repodata)])
shutil.move(str(tmp_current_repodata) + ".gz", str(local_dir / "current_repodata.json.gz"))
sp.check_call(
["gzip", "--no-name", "--keep", "--", str(tmp_current_repodata)]
)
shutil.move(
str(tmp_current_repodata) + ".gz",
str(local_dir / "current_repodata.json.gz"),
)
tmp_current_repodata_gz_gened = True
shutil.move(str(tmp_current_repodata), str(
local_dir / "current_repodata.json"))
shutil.move(str(tmp_current_repodata), str(local_dir / "current_repodata.json"))
if not tmp_current_repodata_gz_gened:
# If the gzip file is not generated, remove the dangling gzip archive
Path(local_dir / "current_repodata.json.gz").unlink(missing_ok=True)

if delete:
local_filelist = []
delete_count = 0
for i in local_dir.glob('*.tar.bz2'):
for i in local_dir.glob("*.tar.bz2"):
local_filelist.append(i)
for i in local_dir.glob('*.conda'):
for i in local_dir.glob("*.conda"):
local_filelist.append(i)
for i in set(local_filelist) - set(remote_filelist):
logging.info("Deleting {}".format(i))
i.unlink()
delete_count += 1
logging.info("{} files deleted".format(delete_count))

logging.info("{}: {} files, {} in total".format(
repodata_url, len(remote_filelist), sizeof_fmt(total_size)))
logging.info(
"{}: {} files, {} in total".format(
repodata_url, len(remote_filelist), sizeof_fmt(total_size)
)
)
return total_size


def sync_installer(repo_url, local_dir: Path):
logging.info("Start syncing {}".format(repo_url))
local_dir.mkdir(parents=True, exist_ok=True)
full_scan = random.random() < 0.1 # Do full version check less frequently
full_scan = random.random() < 0.1 # Do full version check less frequently

def remote_list():
    """Yield ``(filename, sha256)`` pairs parsed from the page at ``repo_url``.

    Only table rows with exactly four cells are considered: the file name
    is the text of the link in the first cell and the checksum is the text
    of the fourth cell. Rows marked ``<directory>``, rows whose checksum is
    not 64 characters long, and rows with a missing link or checksum are
    skipped.

    Raises:
        requests.RequestException: If the page cannot be fetched or the
            server answers with an HTTP error status.
    """
    r = requests.get(repo_url, timeout=TIMEOUT_OPTION)
    # Without this an HTTP error page would parse as an empty listing and
    # be indistinguishable from "nothing to sync".
    r.raise_for_status()
    d = pq(r.content)
    for tr in d("table").find("tr"):
        tds = pq(tr).find("td")
        if len(tds) != 4:
            continue
        link = tds[0].find("a")
        fname = link.text if link is not None else None
        sha256 = tds[3].text
        # A cell without a link or without text must not abort the whole scan.
        if not fname or not sha256:
            continue
        if sha256 == "<directory>" or len(sha256) != 64:
            continue
        yield (fname, sha256)

for filename, sha256 in remote_list():
pkg_url = "/".join([repo_url, filename])
dst_file = local_dir / filename
dst_file_wip = local_dir / ('.downloading.' + filename)
dst_file_wip = local_dir / (".downloading." + filename)

if dst_file.is_file():
r = requests.head(pkg_url, allow_redirects=True, timeout=TIMEOUT_OPTION)
len_avail = 'content-length' in r.headers
len_avail = "content-length" in r.headers
if len_avail:
remote_filesize = int(r.headers['content-length'])
remote_date = parsedate_to_datetime(r.headers['last-modified'])
remote_filesize = int(r.headers["content-length"])
remote_date = parsedate_to_datetime(r.headers["last-modified"])
stat = dst_file.stat()
local_filesize = stat.st_size
local_mtime = stat.st_mtime

# Do content verification on ~5% of files (see issue #25)
if (not len_avail or remote_filesize == local_filesize) and remote_date.timestamp() == local_mtime and \
(random.random() < 0.95 or sha256_check(dst_file, sha256)):
if (
(not len_avail or remote_filesize == local_filesize)
and remote_date.timestamp() == local_mtime
and (random.random() < 0.95 or sha256_check(dst_file, sha256))
):
logging.info("Skipping {}".format(filename))

# Stop the scanning if the most recent version is present
Expand All @@ -299,25 +325,31 @@ def remote_list():

for retry in range(3):
logging.info("Downloading {}".format(filename))
err = ''
err = ""
try:
err = curl_download(pkg_url, dst_file_wip, sha256=sha256)
if err is None:
dst_file_wip.rename(dst_file)
except sp.CalledProcessError:
err = 'CalledProcessError'
err = "CalledProcessError"
if err is None:
break
logging.error("Failed to download {}: {}".format(filename, err))


def main():
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--working-dir", default=WORKING_DIR)
parser.add_argument("--delete", action='store_true',
help='delete unreferenced package files')
parser.add_argument("--remove-legacy", action='store_true',
help='delete legacy packages which have conda counterpart. Requires client conda >= 4.7.0')
parser.add_argument(
"--delete", action="store_true", help="delete unreferenced package files"
)
parser.add_argument(
"--remove-legacy",
action="store_true",
help="delete legacy packages which have conda counterpart. Requires client conda >= 4.7.0",
)
args = parser.parse_args()

if args.working_dir is None:
Expand All @@ -336,7 +368,8 @@ def main():
try:
sync_installer(remote_url, local_dir)
size_statistics += sum(
f.stat().st_size for f in local_dir.glob('*') if f.is_file())
f.stat().st_size for f in local_dir.glob("*") if f.is_file()
)
except Exception:
logging.exception("Failed to sync installers of {}".format(dist))
success = False
Expand All @@ -348,8 +381,9 @@ def main():

tmpdir = tempfile.mkdtemp()
try:
size_statistics += sync_repo(remote_url,
local_dir, Path(tmpdir), args.delete, args.remove_legacy)
size_statistics += sync_repo(
remote_url, local_dir, Path(tmpdir), args.delete, args.remove_legacy
)
except Exception:
logging.exception("Failed to sync repo: {}/{}".format(repo, arch))
success = False
Expand All @@ -362,8 +396,9 @@ def main():

tmpdir = tempfile.mkdtemp()
try:
size_statistics += sync_repo(remote_url,
local_dir, Path(tmpdir), args.delete, args.remove_legacy)
size_statistics += sync_repo(
remote_url, local_dir, Path(tmpdir), args.delete, args.remove_legacy
)
except Exception:
logging.exception("Failed to sync repo: {}".format(repo))
success = False
Expand All @@ -374,6 +409,7 @@ def main():
if not success:
sys.exit(1)


if __name__ == "__main__":
main()

Expand Down
Loading

0 comments on commit 5f4bc1c

Please sign in to comment.