diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index 9d168a67a1f9..5c54446111c2 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -1632,7 +1632,7 @@ def _partition(self, param_list, force=False, has_been_updated=False, free_data=
         for param in param_list:
             print_rank_0(f"Before Partitioning Param {param.ds_id}", force=False)
             if self.zero_param_process_group is not None:
-                self._partition_param_sec(param)
+                self._partition_param_sec(param, has_been_updated=has_been_updated)
             self._partition_param(param, has_been_updated=has_been_updated, free_data=True)
 
             param.ds_status = ZeroParamStatus.NOT_AVAILABLE
@@ -1805,8 +1805,10 @@ def _partition_param_sec(self, param, buffer=None, has_been_updated=False):
             sec_numel = max(0, min(param.ds_numel - secondary_start, secondary_partition_size))
 
             # copy from full tensor to secondary tensor
-            param.ds_secondary_tensor.narrow(0, 0,
-                                             sec_numel).copy_(one_dim_param.narrow(0, secondary_start, sec_numel))
+            with torch.no_grad():
+                # make sure param.ds_secondary_tensor.requires_grad is always False
+                param.ds_secondary_tensor.narrow(0, 0,
+                                                 sec_numel).copy_(one_dim_param.narrow(0, secondary_start, sec_numel))
 
             # TODO: This is a temporary fix to avoid the issue that 2nd tensor all-gather happens before 2nd tensor partition is done
             if not get_accelerator().resolves_data_dependency():
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 50c081b3dc68..c28bfb8da291 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -3002,6 +3002,7 @@ def checkpoint_event_prologue(self):
         self._partition_all_parameters()
 
     def checkpoint_event_epilogue(self):
+        self.invalidate_secondary_tensor()
         if len(self.persistent_parameters) > 0:
             self.persistent_parameters[0].all_gather(self.persistent_parameters)
 
diff --git a/tests/unit/checkpoint/test_zero_optimizer.py b/tests/unit/checkpoint/test_zero_optimizer.py
index 85c38d7f5ffd..cca92348fc62 100644
--- a/tests/unit/checkpoint/test_zero_optimizer.py
+++ b/tests/unit/checkpoint/test_zero_optimizer.py
@@ -8,6 +8,7 @@
 from deepspeed.ops.op_builder import CPUAdamBuilder
 from deepspeed.checkpoint.utils import clone_tensors_for_torch_save, get_model_ckpt_name_for_rank
 from deepspeed.accelerator import get_accelerator
+from deepspeed.runtime.zero import ZeroParamStatus
 from deepspeed.utils.torch import required_torch_version
 
 from unit.common import DistributedTest, DistributedFixture
@@ -212,6 +213,33 @@ def run(self, class_tmpdir, elastic_save, load_optim):
         model.save_checkpoint(class_tmpdir)
 
 
+class ws4_model_checkpoint_zeropp(DistributedFixture):
+
+    world_size = 4
+
+    def run(self, class_tmpdir):
+        config_dict = {
+            "train_batch_size": 4,
+            "optimizer": {
+                "type": 'Adam'
+            },
+            "zero_optimization": {
+                "stage": 3,
+                "zero_hpz_partition_size": 2,
+            }
+        }
+
+        hidden_dim = 10
+        model = SimpleModel(hidden_dim)
+        for param in model.parameters():
+            param.data = torch.ones_like(param.data, device=param.data.device, requires_grad=False)
+
+        # save model and zero checkpoint
+        torch.save(model.state_dict(), os.path.join(class_tmpdir, "model.pt"))
+        ds_model = create_deepspeed_model(config_dict=config_dict, model=model, base_optimizer=None)
+        ds_model.save_checkpoint(class_tmpdir)
+
+
 @pytest.mark.parametrize("elastic_save", [True, False])
 @pytest.mark.parametrize("elastic_load", [True, False])
 @pytest.mark.parametrize("load_optim", [True, False])
@@ -660,3 +688,72 @@ def mock_logger_info(message, *args, **kwargs):
         engine._change_recovery_script_permissions(fake_recovery_script_dst)
 
         assert log_called, "Expected deepspeed.utils.logger.info to be called."
+
+
+class TestZeROPPLoadCheckpoint(DistributedTest):
+
+    world_size = 4
+
+    def test_load_zeropp_model(self, ws4_model_checkpoint_zeropp, class_tmpdir):
+        config_dict = {
+            "train_batch_size": 4,
+            "optimizer": {
+                "type": 'Adam'
+            },
+            "zero_optimization": {
+                "stage": 3,
+                "zero_hpz_partition_size": 2,
+                "stage3_param_persistence_threshold": 1
+            }
+        }
+
+        # Init model and load saved model
+        hidden_dim = 10
+        with deepspeed.zero.Init(config_dict_or_path=config_dict):
+            model = SimpleModel(hidden_dim)
+        ds_model = create_deepspeed_model(config_dict=config_dict, model=model, base_optimizer=None)
+
+        with deepspeed.zero.GatheredParameters(ds_model.module.parameters(), modifier_rank=0):
+            if dist.get_rank() == 0:
+                state_dict = torch.load(os.path.join(class_tmpdir, "model.pt"))
+                ds_model.module.load_state_dict(state_dict)
+
+        # Check the parameters after gather
+        params_to_gather = [p for p in ds_model.module.parameters() if p.ds_status == ZeroParamStatus.NOT_AVAILABLE]
+        if len(params_to_gather) > 0:
+            handle = params_to_gather[0].all_gather_coalesced(params_to_gather)
+            handle.wait()
+        for ds_param in params_to_gather:
+            for v in ds_param.data.cpu().flatten().numpy():
+                assert v == 1.0
+
+    def test_load_zeropp_checkpoint(self, ws4_model_checkpoint_zeropp, class_tmpdir):
+        config_dict = {
+            "train_batch_size": 4,
+            "optimizer": {
+                "type": 'Adam'
+            },
+            "zero_optimization": {
+                "stage": 3,
+                "zero_hpz_partition_size": 2,
+                "stage3_param_persistence_threshold": 1
+            }
+        }
+
+        # Init model and load zero checkpoint
+        hidden_dim = 10
+        model = SimpleModel(hidden_dim)
+        ds_model = create_deepspeed_model(config_dict=config_dict, model=model, base_optimizer=None)
+        ds_model.load_checkpoint(class_tmpdir,
+                                 load_optimizer_states=True,
+                                 load_lr_scheduler_states=False,
+                                 load_module_only=False)
+
+        # Check the parameters after gather
+        params_to_gather = [p for p in ds_model.module.parameters() if p.ds_status == ZeroParamStatus.NOT_AVAILABLE]
+        if len(params_to_gather) > 0:
+            handle = params_to_gather[0].all_gather_coalesced(params_to_gather)
+            handle.wait()
+        for ds_param in params_to_gather:
+            for v in ds_param.data.cpu().flatten().numpy():
+                assert v == 1.0