From eb21d9dd349a6ae1a28c440b30d306eafba65097 Mon Sep 17 00:00:00 2001 From: AznamirWoW <101997116+AznamirWoW@users.noreply.github.com> Date: Sat, 1 Feb 2025 08:30:17 -0500 Subject: [PATCH] more torch.load fixes --- rvc/train/train.py | 2 +- tabs/inference/inference.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/rvc/train/train.py b/rvc/train/train.py index 37297c62..9dd58fa6 100644 --- a/rvc/train/train.py +++ b/rvc/train/train.py @@ -134,7 +134,7 @@ def record(self): def verify_checkpoint_shapes(checkpoint_path, model): - checkpoint = torch.load(checkpoint_path, map_location="cpu") + checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True) checkpoint_state_dict = checkpoint["model"] try: if hasattr(model, "module"): diff --git a/tabs/inference/inference.py b/tabs/inference/inference.py index f6fac000..589503f3 100644 --- a/tabs/inference/inference.py +++ b/tabs/inference/inference.py @@ -334,7 +334,7 @@ def refresh_embedders_folders(): def get_speakers_id(model): if model: try: - model_data = torch.load(os.path.join(now_dir, model), map_location="cpu") + model_data = torch.load(os.path.join(now_dir, model), map_location="cpu", weights_only=True) speakers_id = model_data.get("speakers_id") if speakers_id: return list(range(speakers_id))