Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/tvm/contrib/pickle_memoize.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import os
import pathlib
import sys
import warnings

try:
import cPickle as pickle
Expand Down Expand Up @@ -71,6 +72,13 @@ def cache(self):
if self.path.exists():
with self.path.open("rb") as cache_file:
try:
warnings.warn(
f"Loading cached pickle file from {self.path}. "
"Pickle files can execute arbitrary code. "
"Only load cache files you trust.",
UserWarning,
stacklevel=2,
)
Comment on lines +75 to +81
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This change introduces a user warning, which is great for security awareness. However, there is no test to verify that this warning is actually triggered when loading from a cached pickle file. Please consider adding a test case to tests/python/contrib/test_memoize.py that uses pytest.warns to assert that the UserWarning is raised.

cache = pickle.load(cache_file)
except pickle.UnpicklingError:
cache = {}
Expand Down
26 changes: 26 additions & 0 deletions tests/python/contrib/test_pickle_memoize_warning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pytest
import pickle
import tempfile
import os


def test_pickle_memoize_warns_on_cache_load():
"""Test that loading a cached pickle file emits a UserWarning."""
from tvm.contrib.pickle_memoize import memoize

# Create a cache file
with tempfile.TemporaryDirectory() as tmpdir:
cache_path = os.path.join(tmpdir, "test_cache")

@memoize("test_warning_cache")
def dummy_func():
return 42

# First call creates cache
result = dummy_func()
assert result == 42

# Second call loads from cache — should warn
with pytest.warns(UserWarning, match="Pickle files can execute arbitrary code"):
result2 = dummy_func()
assert result2 == 42