From 0dfd1a8b7d81bbcdd616540e1e5205562d494940 Mon Sep 17 00:00:00 2001 From: Thomas De Bonnet <45205349+Thutmose3@users.noreply.github.com> Date: Mon, 21 Aug 2023 01:13:13 +0200 Subject: [PATCH] Improve readability of download-model.py (#3497) --- download-model.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/download-model.py b/download-model.py index e1afa9ef33..a65f82c710 100644 --- a/download-model.py +++ b/download-model.py @@ -24,14 +24,14 @@ class ModelDownloader: def __init__(self, max_retries=5): - self.s = requests.Session() + self.session = requests.Session() if max_retries: - self.s.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=max_retries)) - self.s.mount('https://huggingface.co', HTTPAdapter(max_retries=max_retries)) + self.session.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=max_retries)) + self.session.mount('https://huggingface.co', HTTPAdapter(max_retries=max_retries)) if os.getenv('HF_USER') is not None and os.getenv('HF_PASS') is not None: - self.s.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS')) + self.session.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS')) if os.getenv('HF_TOKEN') is not None: - self.s.headers = {'authorization': f'Bearer {os.getenv("HF_TOKEN")}'} + self.session.headers = {'authorization': f'Bearer {os.getenv("HF_TOKEN")}'} def sanitize_model_and_branch_names(self, model, branch): if model[-1] == '/': @@ -62,7 +62,7 @@ def get_download_links_from_huggingface(self, model, branch, text_only=False): is_lora = False while True: url = f"{base}{page}" + (f"?cursor={cursor.decode()}" if cursor else "") - r = self.s.get(url, timeout=10) + r = self.session.get(url, timeout=10) r.raise_for_status() content = r.content @@ -136,7 +136,7 @@ def get_single_file(self, url, output_folder, start_from_scratch=False): if output_path.exists() and not start_from_scratch: # Check if the file has already been downloaded completely - r = self.s.get(url, stream=True, timeout=10) + r = self.session.get(url, stream=True, timeout=10) total_size = int(r.headers.get('content-length', 0)) if output_path.stat().st_size >= total_size: return @@ -145,7 +145,7 @@ def get_single_file(self, url, output_folder, start_from_scratch=False): headers = {'Range': f'bytes={output_path.stat().st_size}-'} mode = 'ab' - with self.s.get(url, stream=True, headers=headers, timeout=10) as r: + with self.session.get(url, stream=True, headers=headers, timeout=10) as r: r.raise_for_status() # Do not continue the download if the request was unsuccessful total_size = int(r.headers.get('content-length', 0)) block_size = 1024 * 1024 # 1MB