Skip to content

Commit

Permalink
Improve readability of download-model.py (#3497)
Browse files Browse the repository at this point in the history
  • Loading branch information
Thutmose3 committed Aug 20, 2023
1 parent 457fedf commit 0dfd1a8
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions download-model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@

class ModelDownloader:
def __init__(self, max_retries=5):
self.s = requests.Session()
self.session = requests.Session()
if max_retries:
self.s.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=max_retries))
self.s.mount('https://huggingface.co', HTTPAdapter(max_retries=max_retries))
self.session.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=max_retries))
self.session.mount('https://huggingface.co', HTTPAdapter(max_retries=max_retries))
if os.getenv('HF_USER') is not None and os.getenv('HF_PASS') is not None:
self.s.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS'))
self.session.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS'))
if os.getenv('HF_TOKEN') is not None:
self.s.headers = {'authorization': f'Bearer {os.getenv("HF_TOKEN")}'}
self.session.headers = {'authorization': f'Bearer {os.getenv("HF_TOKEN")}'}

def sanitize_model_and_branch_names(self, model, branch):
if model[-1] == '/':
Expand Down Expand Up @@ -62,7 +62,7 @@ def get_download_links_from_huggingface(self, model, branch, text_only=False):
is_lora = False
while True:
url = f"{base}{page}" + (f"?cursor={cursor.decode()}" if cursor else "")
r = self.s.get(url, timeout=10)
r = self.session.get(url, timeout=10)
r.raise_for_status()
content = r.content

Expand Down Expand Up @@ -136,7 +136,7 @@ def get_single_file(self, url, output_folder, start_from_scratch=False):
if output_path.exists() and not start_from_scratch:

# Check if the file has already been downloaded completely
r = self.s.get(url, stream=True, timeout=10)
r = self.session.get(url, stream=True, timeout=10)
total_size = int(r.headers.get('content-length', 0))
if output_path.stat().st_size >= total_size:
return
Expand All @@ -145,7 +145,7 @@ def get_single_file(self, url, output_folder, start_from_scratch=False):
headers = {'Range': f'bytes={output_path.stat().st_size}-'}
mode = 'ab'

with self.s.get(url, stream=True, headers=headers, timeout=10) as r:
with self.session.get(url, stream=True, headers=headers, timeout=10) as r:
r.raise_for_status() # Do not continue the download if the request was unsuccessful
total_size = int(r.headers.get('content-length', 0))
block_size = 1024 * 1024 # 1MB
Expand Down

0 comments on commit 0dfd1a8

Please sign in to comment.