From 0d8b8a1828eb00d40bb38018cca017ba3b13bc8b Mon Sep 17 00:00:00 2001 From: vfdev Date: Thu, 7 Nov 2024 15:03:40 +0100 Subject: [PATCH] Fixed failing tests on pytorch nightly using torch.load (#3299) --- tests/ignite/engine/test_deterministic.py | 7 ++++++- tests/ignite/handlers/test_state_param_scheduler.py | 6 +++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/ignite/engine/test_deterministic.py b/tests/ignite/engine/test_deterministic.py index b2f62dfa111..9af11695622 100644 --- a/tests/ignite/engine/test_deterministic.py +++ b/tests/ignite/engine/test_deterministic.py @@ -8,6 +8,7 @@ import pytest import torch import torch.nn as nn +from packaging.version import Version from torch.optim import SGD from torch.utils.data import BatchSampler, DataLoader, RandomSampler @@ -737,7 +738,11 @@ def write_data_grads_weights(e): grad_norms.append([i, total[1]] + out2) if sd is not None: - sd = torch.load(sd) + if Version(torch.__version__) >= Version("1.13.0"): + kwargs = {"weights_only": False} + else: + kwargs = {} + sd = torch.load(sd, **kwargs) model.load_state_dict(sd[0]) opt.load_state_dict(sd[1]) from ignite.engine.deterministic import _repr_rng_state diff --git a/tests/ignite/handlers/test_state_param_scheduler.py b/tests/ignite/handlers/test_state_param_scheduler.py index b907683d7e0..ad79eda51b0 100644 --- a/tests/ignite/handlers/test_state_param_scheduler.py +++ b/tests/ignite/handlers/test_state_param_scheduler.py @@ -295,7 +295,11 @@ def test_torch_save_load(dirname): filepath = Path(dirname) / "dummy_lambda_state_parameter_scheduler.pt" torch.save(lambda_state_parameter_scheduler, filepath) - loaded_lambda_state_parameter_scheduler = torch.load(filepath) + if Version(torch.__version__) >= Version("1.13.0"): + kwargs = {"weights_only": False} + else: + kwargs = {} + loaded_lambda_state_parameter_scheduler = torch.load(filepath, **kwargs) engine1 = Engine(lambda e, b: None) lambda_state_parameter_scheduler.attach(engine1, Events.EPOCH_COMPLETED)