Disable torch.load in TorchModuleWrapper when in safe mode. #21575

Merged · 1 commit, merged Aug 12, 2025
12 changes: 12 additions & 0 deletions keras/src/utils/torch_utils.py
@@ -8,6 +8,7 @@
 from keras.src.layers import Layer
 from keras.src.ops import convert_to_numpy
 from keras.src.ops import convert_to_tensor
+from keras.src.saving.serialization_lib import in_safe_mode


 @keras_export("keras.layers.TorchModuleWrapper")
@@ -166,6 +167,17 @@ def from_config(cls, config):
         import torch

         if "module" in config:
+            if in_safe_mode():
+                raise ValueError(
+                    "Requested the deserialization of a `torch.nn.Module` "
+                    "object via `torch.load()`. This carries a potential risk "
+                    "of arbitrary code execution and thus it is disallowed by "
+                    "default. If you trust the source of the saved model, you "
+                    "can pass `safe_mode=False` to the loading function in "
+                    "order to allow `torch.nn.Module` loading, or call "
+                    "`keras.config.enable_unsafe_deserialization()`."
+                )
+
             # Decode the base64 string back to bytes
             buffer_bytes = base64.b64decode(config["module"].encode("ascii"))
             buffer = io.BytesIO(buffer_bytes)
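
The error message names two escape hatches for saved files that come from a trusted source. A minimal sketch of both (the file path is hypothetical; `safe_mode` on `keras.saving.load_model` and `keras.config.enable_unsafe_deserialization()` are the public APIs the message refers to):

import keras

# Option 1: allow `torch.nn.Module` deserialization for this call only.
model = keras.saving.load_model("trusted_model.keras", safe_mode=False)

# Option 2: disable safe-mode deserialization globally for the process.
keras.config.enable_unsafe_deserialization()
model = keras.saving.load_model("trusted_model.keras")
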
42 changes: 30 additions & 12 deletions keras/src/utils/torch_utils_test.py
@@ -248,26 +248,44 @@ def test_build_model(self):
         self.assertEqual(model.predict(np.zeros([5, 4])).shape, (5, 16))
         self.assertEqual(model(np.zeros([5, 4])).shape, (5, 16))

-    def test_save_load(self):
+    @parameterized.named_parameters(
+        ("safe_mode", True),
+        ("unsafe_mode", False),
+    )
+    def test_save_load(self, safe_mode):
         @keras.saving.register_keras_serializable()
         class M(keras.Model):
-            def __init__(self, channels=10, **kwargs):
-                super().__init__()
-                self.sequence = torch.nn.Sequential(
-                    torch.nn.Conv2d(1, channels, kernel_size=(3, 3)),
-                )
+            def __init__(self, module, **kwargs):
+                super().__init__(**kwargs)
+                self.module = module

             def call(self, x):
-                return self.sequence(x)
+                return self.module(x)

-        m = M()
+            def get_config(self):
+                base_config = super().get_config()
+                config = {"module": self.module}
+                return {**base_config, **config}
+
+            @classmethod
+            def from_config(cls, config):
+                config["module"] = saving.deserialize_keras_object(
+                    config["module"]
+                )
+                return cls(**config)
+
+        m = M(torch.nn.Conv2d(1, 10, kernel_size=(3, 3)))
         device = get_device()  # Get the current device (e.g., "cuda" or "cpu")
         x = torch.ones(
             (10, 1, 28, 28), device=device
         )  # Place input on the correct device
-        m(x)
+        ref_output = m(x)
         temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.keras")
         m.save(temp_filepath)
-        new_model = saving.load_model(temp_filepath)
-        for ref_w, new_w in zip(m.get_weights(), new_model.get_weights()):
-            self.assertAllClose(ref_w, new_w, atol=1e-5)
+
+        if safe_mode:
+            with self.assertRaisesRegex(ValueError, "arbitrary code execution"):
+                saving.load_model(temp_filepath, safe_mode=safe_mode)
+        else:
+            new_model = saving.load_model(temp_filepath, safe_mode=safe_mode)
+            self.assertAllClose(new_model(x), ref_output)
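
For reference, the round trip the new test exercises looks roughly like this as user code. This is a minimal sketch assuming the torch backend; `WrappedConv` and the file path are hypothetical stand-ins for the registered `M` class and temp file in the test above.

import torch

import keras
from keras import saving


@saving.register_keras_serializable()
class WrappedConv(keras.Model):
    def __init__(self, module, **kwargs):
        super().__init__(**kwargs)
        # On the torch backend, a torch.nn.Module attribute is wrapped in
        # keras.layers.TorchModuleWrapper automatically.
        self.module = module

    def call(self, x):
        return self.module(x)

    def get_config(self):
        return {**super().get_config(), "module": self.module}

    @classmethod
    def from_config(cls, config):
        config["module"] = saving.deserialize_keras_object(config["module"])
        return cls(**config)


model = WrappedConv(torch.nn.Conv2d(1, 10, kernel_size=(3, 3)))
model(keras.ops.ones((2, 1, 28, 28)))  # build the model before saving
model.save("wrapped_conv.keras")

# Default load (safe_mode=True) now refuses to run torch.load() on the
# pickled module and raises the ValueError added in this PR.
try:
    saving.load_model("wrapped_conv.keras")
except ValueError as err:
    print(err)

# Trusted file: opt out for this call and get the model back.
restored = saving.load_model("wrapped_conv.keras", safe_mode=False)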