diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py
index 542f1175d8f0..dbb4e72a9556 100644
--- a/src/transformers/utils/hub.py
+++ b/src/transformers/utils/hub.py
@@ -880,6 +880,26 @@ def get_checkpoint_shard_files(
         index = json.loads(f.read())
 
     shard_filenames = sorted(set(index["weight_map"].values()))
+    # The shard filename comes from an attacker-controlled `index.json` `weight_map` value
+    # when the repo is untrusted. Each entry must be a single relative basename (no path
+    # separators, no absolute path, no `..`) so `os.path.join` below cannot be tricked into
+    # loading a file outside the checkpoint directory (which would otherwise let a malicious
+    # repo silently load an arbitrary local/Hub-cache `.safetensors`/`.bin` file as model
+    # weights, bypassing the `use_safetensors`, `weights_only`, and `trust_remote_code`
+    # gates that the rest of the loader relies on). A non-string value that reaches this loop
+    # is rejected by the same check, so it raises this `ValueError` rather than a `TypeError`.
+    for shard_filename in shard_filenames:
+        if (
+            not isinstance(shard_filename, str)
+            or shard_filename in ("", ".", "..")
+            or os.path.isabs(shard_filename)
+            or shard_filename != os.path.basename(shard_filename)
+        ):
+            raise ValueError(
+                f"Invalid shard filename in checkpoint index {index_filename!r}: {shard_filename!r}. "
+                "Shard filenames must be a relative basename (no path separators, no absolute path, "
+                "no '..')."
+            )
     sharded_metadata = index["metadata"]
     sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
     sharded_metadata["weight_map"] = index["weight_map"].copy()
diff --git a/tests/utils/test_hub_utils.py b/tests/utils/test_hub_utils.py
index 3edf27ca6312..b3936b9b6a16 100644
--- a/tests/utils/test_hub_utils.py
+++ b/tests/utils/test_hub_utils.py
@@ -205,3 +205,81 @@ def test_list_repo_templates_w_offline(self):
                 side_effect=LocalEntryNotFoundError("no snapshot found"),
             ):
                 self.assertEqual(list_repo_templates(RANDOM_BERT, local_files_only=False), [])
+
+
+class GetCheckpointShardFilesSecurityTests(unittest.TestCase):
+    """Regression tests for path traversal via `model.safetensors.index.json` `weight_map` values.
+
+    A malicious Hub repo can ship an `index.json` whose `weight_map` values are absolute paths
+    or contain `..` components. Because the loader builds shard paths with `os.path.join(repo, subfolder, f)`,
+    an attacker-controlled value could previously escape the checkpoint directory and load an
+    arbitrary `.safetensors`/`.bin` file as model weights, bypassing `use_safetensors`,
+    `weights_only`, and `trust_remote_code`. The loader now rejects any non-basename shard filename.
+    """
+
+    def _write_index(self, dir_path: str, shard_filename: object) -> str:
+        """Write a one-entry index file mapping `a.weight` to `shard_filename` and return its path."""
+        index_path = os.path.join(dir_path, "model.safetensors.index.json")
+        with open(index_path, "w") as f:
+            json.dump(
+                {"metadata": {"total_size": 0}, "weight_map": {"a.weight": shard_filename}},
+                f,
+            )
+        return index_path
+
+    def _call(self, d: str, index_path: str):
+        """Run `get_checkpoint_shard_files` for `d` and `index_path` with `local_files_only=True`."""
+        from transformers.utils.hub import get_checkpoint_shard_files
+
+        # NOTE(review): `local_files_only_enabled` and `tqdm_class` are not visible in the part of
+        # `get_checkpoint_shard_files` shown by this patch; confirm the signature accepts them.
+        return get_checkpoint_shard_files(
+            d,
+            index_path,
+            local_files_only=True,
+            cache_dir=None,
+            force_download=False,
+            proxies=None,
+            local_files_only_enabled=None,
+            token=None,
+            revision="main",
+            subfolder="",
+            user_agent=None,
+            _commit_hash=None,
+            tqdm_class=None,
+        )
+
+    def test_rejects_absolute_path_in_weight_map(self):
+        with tempfile.TemporaryDirectory() as d:
+            index_path = self._write_index(d, os.path.abspath(os.path.join(d, "model.safetensors")))
+            with self.assertRaisesRegex(ValueError, "Invalid shard filename"):
+                self._call(d, index_path)
+
+    def test_rejects_parent_traversal_in_weight_map(self):
+        with tempfile.TemporaryDirectory() as d:
+            index_path = self._write_index(d, "../../etc/passwd")
+            with self.assertRaisesRegex(ValueError, "Invalid shard filename"):
+                self._call(d, index_path)
+
+    def test_rejects_empty_and_dot_entries_in_weight_map(self):
+        """The empty string, `.` and `..` equal their own basename, so only the explicit check rejects them."""
+        for shard_filename in ("", ".", ".."):
+            with self.subTest(shard_filename=shard_filename), tempfile.TemporaryDirectory() as d:
+                index_path = self._write_index(d, shard_filename)
+                with self.assertRaisesRegex(ValueError, "Invalid shard filename"):
+                    self._call(d, index_path)
+
+    def test_rejects_non_string_entry_in_weight_map(self):
+        with tempfile.TemporaryDirectory() as d:
+            index_path = self._write_index(d, 123)
+            with self.assertRaisesRegex(ValueError, "Invalid shard filename"):
+                self._call(d, index_path)
+
+    def test_accepts_benign_relative_basename(self):
+        """Sanity check: a normal relative `model-00001-of-00001.safetensors` value still loads."""
+        with tempfile.TemporaryDirectory() as d:
+            shard_name = "model-00001-of-00001.safetensors"
+            index_path = self._write_index(d, shard_name)
+            open(os.path.join(d, shard_name), "wb").close()
+            shards, _ = self._call(d, index_path)
+            self.assertEqual(shards, [os.path.join(d, shard_name)])