Skip to content

Commit

Permalink
Replace huggingface-cli with a simple client to pull model over https
Browse files Browse the repository at this point in the history
Signed-off-by: swarajpande5 <[email protected]>
  • Loading branch information
swarajpande5 committed Oct 26, 2024
1 parent 4a8c454 commit 5749154
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 83 deletions.
170 changes: 88 additions & 82 deletions ramalama/huggingface.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,102 @@
import os
from ramalama.common import run_cmd, exec_cmd
import urllib.request
from ramalama.common import run_cmd, exec_cmd, download_file, verify_checksum
from ramalama.model import Model

missing_huggingface = """
Huggingface models requires the huggingface-cli and tqdm modules.
These modules can be installed via PyPi tools like pip, pip3, pipx or via
Optional: Huggingface models require the huggingface-cli and tqdm modules.
These modules can be installed via PyPi tools like pip, pip3, pipx, or via
distribution package managers like dnf or apt. Example:
pip install huggingface_hub tqdm
"""


def is_huggingface_cli_available():
"""Check if huggingface-cli is available on the system."""
try:
run_cmd(["huggingface-cli", "version"])
return True
except FileNotFoundError:
print("huggingface-cli not found. Some features may be limited.\n" + missing_huggingface)
return False


def fetch_checksum_from_api(url):
"""Fetch the SHA-256 checksum from the model's metadata API."""
with urllib.request.urlopen(url) as response:
data = response.read().decode()
# Extract the SHA-256 checksum from the `oid sha256` line
for line in data.splitlines():
if line.startswith("oid sha256:"):
return line.split(":", 1)[1].strip()
raise ValueError("SHA-256 checksum not found in the API response.")


class Huggingface(Model):
def __init__(self, model):
model = model.removeprefix("huggingface://")
model = model.removeprefix("hf://")
super().__init__(model)
self.type = "HuggingFace"
split = self.model.rsplit("/", 1)
self.directory = ""
if len(split) > 1:
self.directory = split[0]
self.filename = split[1]
else:
self.filename = split[0]
self.directory = split[0] if len(split) > 1 else ""
self.filename = split[1] if len(split) > 1 else split[0]
self.hf_cli_available = is_huggingface_cli_available()

def login(self, args):
if not self.hf_cli_available:
print("huggingface-cli not available, skipping login.")
return
conman_args = ["huggingface-cli", "login"]
if args.token:
conman_args.extend(["--token", args.token])
try:
self.exec(conman_args)
except FileNotFoundError as e:
raise NotImplementedError(
"""\
%s
%s"""
% (str(e).strip("'"), missing_huggingface)
)
self.exec(conman_args)

def logout(self, args):
if not self.hf_cli_available:
print("huggingface-cli not available, skipping logout.")
return
conman_args = ["huggingface-cli", "logout"]
if args.token:
conman_args.extend(["--token", args.token])
conman_args.extend(args)
self.exec(conman_args)

def path(self, args):
return self.symlink_path(args)

def pull(self, args):
relative_target_path = ""
symlink_path = self.symlink_path(args)
directory_path = os.path.join(args.store, "repos", "huggingface", self.directory, self.filename)
os.makedirs(directory_path, exist_ok=True)

symlink_dir = os.path.dirname(symlink_path)
os.makedirs(symlink_dir, exist_ok=True)

# Fetch the SHA-256 checksum from the API
checksum_api_url = f"https://huggingface.co/{self.directory}/raw/main/{self.filename}"
sha256_checksum = fetch_checksum_from_api(checksum_api_url)

gguf_path = self.download(args.store)
relative_target_path = os.path.relpath(gguf_path.rstrip(), start=os.path.dirname(symlink_path))
directory = f"{args.store}/models/huggingface/{self.directory}"
os.makedirs(directory, exist_ok=True)
target_path = os.path.join(directory_path, f"sha256:{sha256_checksum}")

if os.path.exists(target_path) and verify_checksum(target_path):
relative_target_path = os.path.relpath(target_path, start=os.path.dirname(symlink_path))
if not self.check_valid_symlink_path(relative_target_path, symlink_path):
run_cmd(["ln", "-sf", relative_target_path, symlink_path], debug=args.debug)
return symlink_path

if os.path.exists(symlink_path) and os.readlink(symlink_path) == relative_target_path:
# Download the model file to the target path
url = f"https://huggingface.co/{self.directory}/resolve/main/{self.filename}"
download_file(url, target_path, headers={}, show_progress=True)

if not verify_checksum(target_path):
print(f"Checksum mismatch for {target_path}, retrying download...")
os.remove(target_path)
download_file(url, target_path, headers={}, show_progress=True)
if not verify_checksum(target_path):
raise ValueError(f"Checksum verification failed for {target_path}")

relative_target_path = os.path.relpath(target_path, start=os.path.dirname(symlink_path))
if self.check_valid_symlink_path(relative_target_path, symlink_path):
# Symlink is already correct, no need to update it
return symlink_path

Expand All @@ -66,67 +105,34 @@ def pull(self, args):
return symlink_path

def push(self, source, args):
try:
proc = run_cmd(
[
"huggingface-cli",
"upload",
"--repo-type",
"model",
self.directory,
self.filename,
"--cache-dir",
args.store + "/repos/huggingface/.cache",
"--local-dir",
args.store + "/repos/huggingface/" + self.directory,
],
debug=args.debug,
)
return proc.stdout.decode("utf-8")
except FileNotFoundError as e:
raise NotImplementedError(
"""\
%s
%s"""
% (str(e).strip("'"), missing_huggingface)
)
if not self.hf_cli_available:
print("huggingface-cli not available, skipping push.")
return
proc = run_cmd(
[
"huggingface-cli",
"upload",
"--repo-type",
"model",
self.directory,
self.filename,
"--cache-dir",
os.path.join(args.store, "repos", "huggingface", ".cache"),
"--local-dir",
os.path.join(args.store, "repos", "huggingface", self.directory),
],
debug=args.debug,
)
return proc.stdout.decode("utf-8")

def symlink_path(self, args):
return f"{args.store}/models/huggingface/{self.directory}/{self.filename}"
return os.path.join(args.store, "models", "huggingface", self.directory, self.filename)

def check_valid_symlink_path(self, relative_target_path, symlink_path):
return os.path.exists(symlink_path) and os.readlink(symlink_path) == relative_target_path

def exec(self, args):
try:
exec_cmd(args, args.debug)
except FileNotFoundError as e:
raise NotImplementedError(
"""\
%s
%s
"""
% str(e).strip("'"),
missing_huggingface,
)

def download(self, store):
try:
proc = run_cmd(
[
"huggingface-cli",
"download",
self.directory,
self.filename,
"--cache-dir",
store + "/repos/huggingface/.cache",
"--local-dir",
store + "/repos/huggingface/" + self.directory,
]
)
return proc.stdout.decode("utf-8")
except FileNotFoundError as e:
raise NotImplementedError(
"""\
%s
%s"""
% (str(e).strip("'"), missing_huggingface)
)
print(f"{str(e).strip()}\n{missing_huggingface}")
2 changes: 1 addition & 1 deletion ramalama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def garbage_collection(self, args):
file_has_a_symlink = False
for file in files:
file_path = os.path.join(root, file)
if (repo == "ollama" and file.startswith("sha256:")) or file.endswith(".gguf"):
if file.startswith("sha256:") or file.endswith(".gguf"):
file_path = os.path.join(root, file)
for model_root, model_dirs, model_files in os.walk(model_dir):
for model_file in model_files:
Expand Down

0 comments on commit 5749154

Please sign in to comment.