Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/transformers/utils/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,24 @@ def get_checkpoint_shard_files(
index = json.loads(f.read())

shard_filenames = sorted(set(index["weight_map"].values()))
# The shard filename comes from an attacker-controlled `index.json` `weight_map` value
# when the repo is untrusted. Each entry must be a single relative basename (no path
# separators, no absolute path, no `..`) so `os.path.join` below cannot be tricked into
# loading a file outside the checkpoint directory (which would otherwise let a malicious
# repo silently load an arbitrary local/Hub-cache `.safetensors`/`.bin` file as model
# weights, bypassing the `use_safetensors`, `weights_only`, and `trust_remote_code`
# gates that the rest of the loader relies on).
for shard_filename in shard_filenames:
if (
shard_filename in ("", ".", "..")
or os.path.isabs(shard_filename)
or shard_filename != os.path.basename(shard_filename)
):
raise ValueError(
f"Invalid shard filename in checkpoint index {index_filename!r}: {shard_filename!r}. "
"Shard filenames must be a relative basename (no path separators, no absolute path, "
"no '..')."
)
sharded_metadata = index["metadata"]
sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
sharded_metadata["weight_map"] = index["weight_map"].copy()
Expand Down
60 changes: 60 additions & 0 deletions tests/utils/test_hub_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,63 @@ def test_list_repo_templates_w_offline(self):
side_effect=LocalEntryNotFoundError("no snapshot found"),
):
self.assertEqual(list_repo_templates(RANDOM_BERT, local_files_only=False), [])


class GetCheckpointShardFilesSecurityTests(unittest.TestCase):
"""Regression tests for path traversal via `model.safetensors.index.json` `weight_map` values.

A malicious Hub repo can ship an `index.json` whose `weight_map` values are absolute paths
or contain `..` components. Because the loader builds shard paths with `os.path.join(repo, subfolder, f)`,
an attacker-controlled malicious value used to escape the checkpoint directory and load an
arbitrary `.safetensors`/`.bin` file as model weights, bypassing `use_safetensors`,
`weights_only`, and `trust_remote_code`. The loader now rejects any non-basename shard filename.
"""

def _write_index(self, dir_path: str, shard_filename: str) -> str:
index_path = os.path.join(dir_path, "model.safetensors.index.json")
with open(index_path, "w") as f:
json.dump(
{"metadata": {"total_size": 0}, "weight_map": {"a.weight": shard_filename}},
f,
)
return index_path

def _call(self, d: str, index_path: str):
from transformers.utils.hub import get_checkpoint_shard_files

return get_checkpoint_shard_files(
d,
index_path,
local_files_only=True,
cache_dir=None,
force_download=False,
proxies=None,
local_files_only_enabled=None,
token=None,
revision="main",
subfolder="",
user_agent=None,
_commit_hash=None,
tqdm_class=None,
)

def test_rejects_absolute_path_in_weight_map(self):
with tempfile.TemporaryDirectory() as d:
index_path = self._write_index(d, os.path.abspath(os.path.join(d, "model.safetensors")))
with self.assertRaisesRegex(ValueError, "Invalid shard filename"):
self._call(d, index_path)

def test_rejects_parent_traversal_in_weight_map(self):
with tempfile.TemporaryDirectory() as d:
index_path = self._write_index(d, "../../etc/passwd")
with self.assertRaisesRegex(ValueError, "Invalid shard filename"):
self._call(d, index_path)

def test_accepts_benign_relative_basename(self):
"""Sanity check: a normal relative `model-00001-of-00001.safetensors` value still loads."""
with tempfile.TemporaryDirectory() as d:
shard_name = "model-00001-of-00001.safetensors"
index_path = self._write_index(d, shard_name)
open(os.path.join(d, shard_name), "wb").close()
shards, _ = self._call(d, index_path)
self.assertEqual(shards, [os.path.join(d, shard_name)])