diff --git a/.gitignore b/.gitignore index 532357fe..6da94a72 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ build ramalama/*.patch dist .#* +venv/ diff --git a/ramalama/common.py b/ramalama/common.py index 182b0559..f9933ab1 100644 --- a/ramalama/common.py +++ b/ramalama/common.py @@ -7,6 +7,7 @@ import string import subprocess import sys +import urllib.request x = False @@ -154,3 +155,67 @@ def default_image(): def genname(): return "ramalama_" + "".join(random.choices(string.ascii_letters + string.digits, k=10)) + + +def download_file(url, dest_path, headers=None, show_progress=True): + try: + from tqdm import tqdm + except FileNotFoundError: + raise NotImplementedError( + """\ +Ollama models requires the tqdm modules. +This module can be installed via PyPi tools like pip, pip3, pipx or via +distribution package managers like dnf or apt. Example: +pip install tqdm +""" + ) + + # Check if partially downloaded file exists + if os.path.exists(dest_path): + downloaded_size = os.path.getsize(dest_path) + else: + downloaded_size = 0 + + request = urllib.request.Request(url, headers=headers or {}) + request.headers["Range"] = f"bytes={downloaded_size}-" # Set range header + + filename = dest_path.split('/')[-1] + + bar_format = "Pulling {desc}: {percentage:3.0f}% ▕{bar:20}▏ {n_fmt}/{total_fmt} {rate_fmt} {remaining}" + try: + with urllib.request.urlopen(request) as response: + total_size = int(response.headers.get("Content-Length", 0)) + downloaded_size + chunk_size = 8192 # 8 KB chunks + + with open(dest_path, "ab") as file: + if show_progress: + with tqdm( + desc=filename, + total=total_size, + initial=downloaded_size, + unit="B", + unit_scale=True, + unit_divisor=1024, + bar_format=bar_format, + ascii=True, + ) as progress_bar: + while True: + chunk = response.read(chunk_size) + if not chunk: + break + file.write(chunk) + progress_bar.update(len(chunk)) + else: + # Download file without showing progress + while True: + chunk = response.read(chunk_size) + if not chunk: + break + file.write(chunk) + except urllib.error.HTTPError as e: + if e.code == 416: + if show_progress: + # If we get a 416 error, it means the file is fully downloaded + print(f"File {url} already fully downloaded.") + else: + raise e diff --git a/ramalama/huggingface.py b/ramalama/huggingface.py index ff49718e..af28d1a6 100644 --- a/ramalama/huggingface.py +++ b/ramalama/huggingface.py @@ -1,15 +1,37 @@ import os -from ramalama.common import run_cmd, exec_cmd +import urllib.request +from ramalama.common import run_cmd, exec_cmd, download_file, verify_checksum from ramalama.model import Model missing_huggingface = """ -Huggingface models requires the huggingface-cli and tqdm modules. -These modules can be installed via PyPi tools like pip, pip3, pipx or via +Optional: Huggingface models require the huggingface-cli and tqdm modules. +These modules can be installed via PyPi tools like pip, pip3, pipx, or via distribution package managers like dnf or apt. Example: pip install huggingface_hub tqdm """ +def is_huggingface_cli_available(): + """Check if huggingface-cli is available on the system.""" + try: + run_cmd(["huggingface-cli", "version"]) + return True + except FileNotFoundError: + print("huggingface-cli not found. Some features may be limited.\n" + missing_huggingface) + return False + + +def fetch_checksum_from_api(url): + """Fetch the SHA-256 checksum from the model's metadata API.""" + with urllib.request.urlopen(url) as response: + data = response.read().decode() + # Extract the SHA-256 checksum from the `oid sha256` line + for line in data.splitlines(): + if line.startswith("oid sha256:"): + return line.split(":", 1)[1].strip() + raise ValueError("SHA-256 checksum not found in the API response.") + + class Huggingface(Model): def __init__(self, model): model = model.removeprefix("huggingface://") @@ -17,47 +39,64 @@ def __init__(self, model): super().__init__(model) self.type = "HuggingFace" split = self.model.rsplit("/", 1) - self.directory = "" - if len(split) > 1: - self.directory = split[0] - self.filename = split[1] - else: - self.filename = split[0] + self.directory = split[0] if len(split) > 1 else "" + self.filename = split[1] if len(split) > 1 else split[0] + self.hf_cli_available = is_huggingface_cli_available() def login(self, args): + if not self.hf_cli_available: + print("huggingface-cli not available, skipping login.") + return conman_args = ["huggingface-cli", "login"] if args.token: conman_args.extend(["--token", args.token]) - try: - self.exec(conman_args) - except FileNotFoundError as e: - raise NotImplementedError( - """\ -%s -%s""" - % (str(e).strip("'"), missing_huggingface) - ) + self.exec(conman_args) def logout(self, args): + if not self.hf_cli_available: + print("huggingface-cli not available, skipping logout.") + return conman_args = ["huggingface-cli", "logout"] if args.token: conman_args.extend(["--token", args.token]) - conman_args.extend(args) self.exec(conman_args) def path(self, args): return self.symlink_path(args) def pull(self, args): - relative_target_path = "" symlink_path = self.symlink_path(args) + directory_path = os.path.join(args.store, "repos", "huggingface", self.directory, self.filename) + os.makedirs(directory_path, exist_ok=True) + + symlink_dir = os.path.dirname(symlink_path) + os.makedirs(symlink_dir, exist_ok=True) + + # Fetch the SHA-256 checksum from the API + checksum_api_url = f"https://huggingface.co/{self.directory}/raw/main/{self.filename}" + sha256_checksum = fetch_checksum_from_api(checksum_api_url) - gguf_path = self.download(args.store) - relative_target_path = os.path.relpath(gguf_path.rstrip(), start=os.path.dirname(symlink_path)) - directory = f"{args.store}/models/huggingface/{self.directory}" - os.makedirs(directory, exist_ok=True) + target_path = os.path.join(directory_path, f"sha256:{sha256_checksum}") + + if os.path.exists(target_path) and verify_checksum(target_path): + relative_target_path = os.path.relpath(target_path, start=os.path.dirname(symlink_path)) + if not self.check_valid_symlink_path(relative_target_path, symlink_path): + run_cmd(["ln", "-sf", relative_target_path, symlink_path], debug=args.debug) + return symlink_path - if os.path.exists(symlink_path) and os.readlink(symlink_path) == relative_target_path: + # Download the model file to the target path + url = f"https://huggingface.co/{self.directory}/resolve/main/{self.filename}" + download_file(url, target_path, headers={}, show_progress=True) + + if not verify_checksum(target_path): + print(f"Checksum mismatch for {target_path}, retrying download...") + os.remove(target_path) + download_file(url, target_path, headers={}, show_progress=True) + if not verify_checksum(target_path): + raise ValueError(f"Checksum verification failed for {target_path}") + + relative_target_path = os.path.relpath(target_path, start=os.path.dirname(symlink_path)) + if self.check_valid_symlink_path(relative_target_path, symlink_path): # Symlink is already correct, no need to update it return symlink_path @@ -66,67 +105,34 @@ def pull(self, args): return symlink_path def push(self, source, args): - try: - proc = run_cmd( - [ - "huggingface-cli", - "upload", - "--repo-type", - "model", - self.directory, - self.filename, - "--cache-dir", - args.store + "/repos/huggingface/.cache", - "--local-dir", - args.store + "/repos/huggingface/" + self.directory, - ], - debug=args.debug, - ) - return proc.stdout.decode("utf-8") - except FileNotFoundError as e: - raise NotImplementedError( - """\ - %s - %s""" - % (str(e).strip("'"), missing_huggingface) - ) + if not self.hf_cli_available: + print("huggingface-cli not available, skipping push.") + return + proc = run_cmd( + [ + "huggingface-cli", + "upload", + "--repo-type", + "model", + self.directory, + self.filename, + "--cache-dir", + os.path.join(args.store, "repos", "huggingface", ".cache"), + "--local-dir", + os.path.join(args.store, "repos", "huggingface", self.directory), + ], + debug=args.debug, + ) + return proc.stdout.decode("utf-8") def symlink_path(self, args): - return f"{args.store}/models/huggingface/{self.directory}/{self.filename}" + return os.path.join(args.store, "models", "huggingface", self.directory, self.filename) + + def check_valid_symlink_path(self, relative_target_path, symlink_path): + return os.path.exists(symlink_path) and os.readlink(symlink_path) == relative_target_path def exec(self, args): try: exec_cmd(args, args.debug) except FileNotFoundError as e: - raise NotImplementedError( - """\ -%s - -%s -""" - % str(e).strip("'"), - missing_huggingface, - ) - - def download(self, store): - try: - proc = run_cmd( - [ - "huggingface-cli", - "download", - self.directory, - self.filename, - "--cache-dir", - store + "/repos/huggingface/.cache", - "--local-dir", - store + "/repos/huggingface/" + self.directory, - ] - ) - return proc.stdout.decode("utf-8") - except FileNotFoundError as e: - raise NotImplementedError( - """\ - %s - %s""" - % (str(e).strip("'"), missing_huggingface) - ) + print(f"{str(e).strip()}\n{missing_huggingface}") diff --git a/ramalama/model.py b/ramalama/model.py index 1813bcf9..c80c28ed 100644 --- a/ramalama/model.py +++ b/ramalama/model.py @@ -63,7 +63,7 @@ def garbage_collection(self, args): file_has_a_symlink = False for file in files: file_path = os.path.join(root, file) - if (repo == "ollama" and file.startswith("sha256:")) or file.endswith(".gguf"): + if file.startswith("sha256:") or file.endswith(".gguf"): file_path = os.path.join(root, file) for model_root, model_dirs, model_files in os.walk(model_dir): for model_file in model_files: diff --git a/ramalama/ollama.py b/ramalama/ollama.py index 0cdc704b..c977e4ff 100644 --- a/ramalama/ollama.py +++ b/ramalama/ollama.py @@ -1,72 +1,9 @@ import os import urllib.request import json -from ramalama.common import run_cmd, verify_checksum +from ramalama.common import run_cmd, verify_checksum, download_file from ramalama.model import Model -bar_format = "Pulling {desc}: {percentage:3.0f}% ▕{bar:20}▏ {n_fmt}/{total_fmt} {rate_fmt} {remaining}" - - -def download_file(url, dest_path, headers=None, show_progress=True): - try: - from tqdm import tqdm - except FileNotFoundError: - raise NotImplementedError( - """\ -Ollama models requires the tqdm modules. -This module can be installed via PyPi tools like pip, pip3, pipx or via -distribution package managers like dnf or apt. Example: -pip install tqdm -""" - ) - - # Check if partially downloaded file exists - if os.path.exists(dest_path): - downloaded_size = os.path.getsize(dest_path) - else: - downloaded_size = 0 - - request = urllib.request.Request(url, headers=headers or {}) - request.headers["Range"] = f"bytes={downloaded_size}-" # Set range header - - try: - with urllib.request.urlopen(request) as response: - total_size = int(response.headers.get("Content-Length", 0)) + downloaded_size - chunk_size = 8192 # 8 KB chunks - - with open(dest_path, "ab") as file: - if show_progress: - with tqdm( - desc=dest_path[-16:], - total=total_size, - initial=downloaded_size, - unit="B", - unit_scale=True, - unit_divisor=1024, - bar_format=bar_format, - ascii=True, - ) as progress_bar: - while True: - chunk = response.read(chunk_size) - if not chunk: - break - file.write(chunk) - progress_bar.update(len(chunk)) - else: - # Download file without showing progress - while True: - chunk = response.read(chunk_size) - if not chunk: - break - file.write(chunk) - except urllib.error.HTTPError as e: - if e.code == 416: - if show_progress: - # If we get a 416 error, it means the file is fully downloaded - print(f"File {url} already fully downloaded.") - else: - raise e - def fetch_manifest_data(registry_head, model_tag, accept): url = f"{registry_head}/manifests/{model_tag}"