diff --git a/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py index 5cd44864bb2e..6476863037f2 100644 --- a/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py +++ b/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py @@ -19,7 +19,10 @@ def create(self, tag): def save(self, state_dict, path: str): logger.info(f"[Torch] Saving {path}...") - torch.save(state_dict, path) + if path.contains("s3://"): + pass + else: + torch.save(state_dict, path) logger.info(f"[Torch] Saved {path}.") return None