diff --git a/ramalama/huggingface.py b/ramalama/huggingface.py
index ff49718e..af28d1a6 100644
--- a/ramalama/huggingface.py
+++ b/ramalama/huggingface.py
@@ -1,15 +1,45 @@
 import os
-from ramalama.common import run_cmd, exec_cmd
+import shutil
+import urllib.request
+from ramalama.common import run_cmd, exec_cmd, download_file, verify_checksum
 from ramalama.model import Model
 
 missing_huggingface = """
-Huggingface models requires the huggingface-cli and tqdm modules.
-These modules can be installed via PyPi tools like pip, pip3, pipx or via
+Optional: Huggingface models require the huggingface-cli and tqdm modules.
+These modules can be installed via PyPi tools like pip, pip3, pipx, or via
 distribution package managers like dnf or apt. Example:
 pip install huggingface_hub tqdm
 """
 
 
+def is_huggingface_cli_available():
+    """Return True when the huggingface-cli executable is found on PATH.
+
+    Otherwise print the installation hint and return False.
+    """
+    # Look the executable up instead of running it, so that an installed CLI
+    # that exits non-zero is not mistaken for a missing one.
+    if shutil.which("huggingface-cli") is not None:
+        return True
+    print("huggingface-cli not found. Some features may be limited.\n" + missing_huggingface)
+    return False
+
+
+def fetch_checksum_from_api(url, timeout=30):
+    """Return the digest from the ``oid sha256:`` line of the response at url.
+
+    timeout is the number of seconds to wait on the connection.
+    Raises ValueError if the response contains no ``oid sha256:`` line.
+    """
+    with urllib.request.urlopen(url, timeout=timeout) as response:
+        data = response.read().decode()
+    # The digest is whatever follows the first colon on the `oid sha256:` line
+    for line in data.splitlines():
+        if line.startswith("oid sha256:"):
+            return line.split(":", 1)[1].strip()
+    raise ValueError("SHA-256 checksum not found in the API response.")
+
+
 class Huggingface(Model):
     def __init__(self, model):
         model = model.removeprefix("huggingface://")
@@ -17,47 +47,69 @@ def __init__(self, model):
         super().__init__(model)
         self.type = "HuggingFace"
         split = self.model.rsplit("/", 1)
-        self.directory = ""
-        if len(split) > 1:
-            self.directory = split[0]
-            self.filename = split[1]
-        else:
-            self.filename = split[0]
+        self.directory = split[0] if len(split) > 1 else ""
+        self.filename = split[1] if len(split) > 1 else split[0]
+        self.hf_cli_available = is_huggingface_cli_available()
 
     def login(self, args):
+        if not self.hf_cli_available:
+            print("huggingface-cli not available, skipping login.")
+            return
         conman_args = ["huggingface-cli", "login"]
         if args.token:
             conman_args.extend(["--token", args.token])
-        try:
-            self.exec(conman_args)
-        except FileNotFoundError as e:
-            raise NotImplementedError(
-                """\
-%s
-%s"""
-                % (str(e).strip("'"), missing_huggingface)
-            )
+        self.exec(conman_args, debug=getattr(args, "debug", False))
 
     def logout(self, args):
+        if not self.hf_cli_available:
+            print("huggingface-cli not available, skipping logout.")
+            return
         conman_args = ["huggingface-cli", "logout"]
         if args.token:
             conman_args.extend(["--token", args.token])
-        conman_args.extend(args)
-        self.exec(conman_args)
+        self.exec(conman_args, debug=getattr(args, "debug", False))
 
     def path(self, args):
         return self.symlink_path(args)
 
     def pull(self, args):
-        relative_target_path = ""
+        """Download the model into the store and return the path of its symlink.
+
+        The file is saved as ``sha256:<checksum>`` and checked with
+        verify_checksum; the download is retried once if that check fails.
+        """
         symlink_path = self.symlink_path(args)
+        directory_path = os.path.join(args.store, "repos", "huggingface", self.directory, self.filename)
+        os.makedirs(directory_path, exist_ok=True)
+
+        symlink_dir = os.path.dirname(symlink_path)
+        os.makedirs(symlink_dir, exist_ok=True)
+
+        # Fetch the SHA-256 checksum from the API
+        checksum_api_url = f"https://huggingface.co/{self.directory}/raw/main/{self.filename}"
+        sha256_checksum = fetch_checksum_from_api(checksum_api_url)
 
-        gguf_path = self.download(args.store)
-        relative_target_path = os.path.relpath(gguf_path.rstrip(), start=os.path.dirname(symlink_path))
-        directory = f"{args.store}/models/huggingface/{self.directory}"
-        os.makedirs(directory, exist_ok=True)
+        target_path = os.path.join(directory_path, f"sha256:{sha256_checksum}")
+
+        if os.path.exists(target_path) and verify_checksum(target_path):
+            relative_target_path = os.path.relpath(target_path, start=os.path.dirname(symlink_path))
+            if not self.check_valid_symlink_path(relative_target_path, symlink_path):
+                run_cmd(["ln", "-sf", relative_target_path, symlink_path], debug=args.debug)
+            return symlink_path
 
-        if os.path.exists(symlink_path) and os.readlink(symlink_path) == relative_target_path:
+        # Download the model file to the target path
+        url = f"https://huggingface.co/{self.directory}/resolve/main/{self.filename}"
+        download_file(url, target_path, headers={}, show_progress=True)
+
+        if not verify_checksum(target_path):
+            print(f"Checksum mismatch for {target_path}, retrying download...")
+            os.remove(target_path)
+            download_file(url, target_path, headers={}, show_progress=True)
+            if not verify_checksum(target_path):
+                raise ValueError(f"Checksum verification failed for {target_path}")
+
+        relative_target_path = os.path.relpath(target_path, start=os.path.dirname(symlink_path))
+        if self.check_valid_symlink_path(relative_target_path, symlink_path):
             # Symlink is already correct, no need to update it
             return symlink_path
 
@@ -66,67 +118,42 @@ def pull(self, args):
         return symlink_path
 
     def push(self, source, args):
-        try:
-            proc = run_cmd(
-                [
-                    "huggingface-cli",
-                    "upload",
-                    "--repo-type",
-                    "model",
-                    self.directory,
-                    self.filename,
-                    "--cache-dir",
-                    args.store + "/repos/huggingface/.cache",
-                    "--local-dir",
-                    args.store + "/repos/huggingface/" + self.directory,
-                ],
-                debug=args.debug,
-            )
-            return proc.stdout.decode("utf-8")
-        except FileNotFoundError as e:
-            raise NotImplementedError(
-                """\
-                %s
-                %s"""
-                % (str(e).strip("'"), missing_huggingface)
-            )
+        if not self.hf_cli_available:
+            print("huggingface-cli not available, skipping push.")
+            return
+        proc = run_cmd(
+            [
+                "huggingface-cli",
+                "upload",
+                "--repo-type",
+                "model",
+                self.directory,
+                self.filename,
+                "--cache-dir",
+                os.path.join(args.store, "repos", "huggingface", ".cache"),
+                "--local-dir",
+                os.path.join(args.store, "repos", "huggingface", self.directory),
+            ],
+            debug=args.debug,
+        )
+        return proc.stdout.decode("utf-8")
 
     def symlink_path(self, args):
-        return f"{args.store}/models/huggingface/{self.directory}/{self.filename}"
+        return os.path.join(args.store, "models", "huggingface", self.directory, self.filename)
+
+    def check_valid_symlink_path(self, relative_target_path, symlink_path):
+        """Return True if symlink_path is a symlink pointing at relative_target_path."""
+        # islink() rather than exists(): readlink() raises OSError on a regular
+        # file, and exists() reports False for a dangling link.
+        return os.path.islink(symlink_path) and os.readlink(symlink_path) == relative_target_path
 
-    def exec(self, args):
+    def exec(self, args, debug=False):
+        """Run the command list args with exec_cmd.
+
+        args is an argv list (not the parsed CLI namespace), so the debug flag
+        is taken as its own parameter instead of being read from args.
+        """
         try:
-            exec_cmd(args, args.debug)
+            exec_cmd(args, debug)
         except FileNotFoundError as e:
-            raise NotImplementedError(
-                """\
-%s
-
-%s
-"""
-                % str(e).strip("'"),
-                missing_huggingface,
-            )
-
-    def download(self, store):
-        try:
-            proc = run_cmd(
-                [
-                    "huggingface-cli",
-                    "download",
-                    self.directory,
-                    self.filename,
-                    "--cache-dir",
-                    store + "/repos/huggingface/.cache",
-                    "--local-dir",
-                    store + "/repos/huggingface/" + self.directory,
-                ]
-            )
-            return proc.stdout.decode("utf-8")
-        except FileNotFoundError as e:
-            raise NotImplementedError(
-                """\
-                %s
-                %s"""
-                % (str(e).strip("'"), missing_huggingface)
-            )
+            print(f"{str(e).strip()}\n{missing_huggingface}")
diff --git a/ramalama/model.py b/ramalama/model.py
index 1813bcf9..c80c28ed 100644
--- a/ramalama/model.py
+++ b/ramalama/model.py
@@ -63,7 +63,7 @@ def garbage_collection(self, args):
                 file_has_a_symlink = False
                 for file in files:
                     file_path = os.path.join(root, file)
-                    if (repo == "ollama" and file.startswith("sha256:")) or file.endswith(".gguf"):
+                    if file.startswith("sha256:") or file.endswith(".gguf"):
                         file_path = os.path.join(root, file)
                         for model_root, model_dirs, model_files in os.walk(model_dir):
                             for model_file in model_files: