diff --git a/python/tvm/contrib/pickle_memoize.py b/python/tvm/contrib/pickle_memoize.py index f9981a5ea4ce..64a88068ecfb 100644 --- a/python/tvm/contrib/pickle_memoize.py +++ b/python/tvm/contrib/pickle_memoize.py @@ -23,6 +23,7 @@ import os import pathlib import sys +import warnings try: import cPickle as pickle @@ -71,6 +72,13 @@ def cache(self): if self.path.exists(): with self.path.open("rb") as cache_file: try: + warnings.warn( + f"Loading cached pickle file from {self.path}. " + "Pickle files can execute arbitrary code. " + "Only load cache files you trust.", + UserWarning, + stacklevel=2, + ) cache = pickle.load(cache_file) except pickle.UnpicklingError: cache = {} diff --git a/tests/python/contrib/test_pickle_memoize_warning.py b/tests/python/contrib/test_pickle_memoize_warning.py new file mode 100644 index 000000000000..3f26961d204e --- /dev/null +++ b/tests/python/contrib/test_pickle_memoize_warning.py @@ -0,0 +1,26 @@ +import pytest +import pickle +import tempfile +import os + + +def test_pickle_memoize_warns_on_cache_load(): + """Test that loading a cached pickle file emits a UserWarning.""" + from tvm.contrib.pickle_memoize import memoize + + # Create a cache file + with tempfile.TemporaryDirectory() as tmpdir: + cache_path = os.path.join(tmpdir, "test_cache") + + @memoize("test_warning_cache") + def dummy_func(): + return 42 + + # First call creates cache + result = dummy_func() + assert result == 42 + + # Second call loads from cache — should warn + with pytest.warns(UserWarning, match="Pickle files can execute arbitrary code"): + result2 = dummy_func() + assert result2 == 42