8 changes: 5 additions & 3 deletions deepspeed/runtime/zero/partition_parameters.py
@@ -1632,7 +1632,7 @@ def _partition(self, param_list, force=False, has_been_updated=False, free_data=
for param in param_list:
print_rank_0(f"Before Partitioning Param {param.ds_id}", force=False)
if self.zero_param_process_group is not None:
- self._partition_param_sec(param)
+ self._partition_param_sec(param, has_been_updated=has_been_updated)
self._partition_param(param, has_been_updated=has_been_updated, free_data=True)

param.ds_status = ZeroParamStatus.NOT_AVAILABLE
@@ -1805,8 +1805,10 @@ def _partition_param_sec(self, param, buffer=None, has_been_updated=False):
sec_numel = max(0, min(param.ds_numel - secondary_start, secondary_partition_size))

# copy from full tensor to secondary tensor
- param.ds_secondary_tensor.narrow(0, 0,
-                                  sec_numel).copy_(one_dim_param.narrow(0, secondary_start, sec_numel))
+ with torch.no_grad():
+     # make sure param.ds_secondary_tensor's requires_grad always stays False
+     param.ds_secondary_tensor.narrow(0, 0,
+                                      sec_numel).copy_(one_dim_param.narrow(0, secondary_start, sec_numel))

# TODO: This is a temporary fix to avoid the issue that 2nd tensor all-gather happens before 2nd tensor partition is done
if not get_accelerator().resolves_data_dependency():
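The torch.no_grad() guard matters here because an in-place copy_() from a grad-tracking source silently pulls the destination into the autograd graph. A minimal standalone sketch of that PyTorch behavior, with illustrative tensor names standing in for the gathered one-dim parameter and param.ds_secondary_tensor:

import torch

full = torch.ones(8, requires_grad=True)   # stands in for the gathered full parameter
secondary = torch.zeros(4)                 # stands in for param.ds_secondary_tensor

# Without the guard, copy_ from a grad-tracking source gives the destination a
# grad_fn and requires_grad=True, which the hpZ secondary shard must never have.
with torch.no_grad():
    secondary.copy_(full.narrow(0, 0, 4))

assert secondary.requires_grad is False
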
1 change: 1 addition & 0 deletions deepspeed/runtime/zero/stage3.py
@@ -3002,6 +3002,7 @@ def checkpoint_event_prologue(self):
self._partition_all_parameters()

def checkpoint_event_epilogue(self):
+ self.invalidate_secondary_tensor()
if len(self.persistent_parameters) > 0:
self.persistent_parameters[0].all_gather(self.persistent_parameters)

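The new invalidate_secondary_tensor() call in the epilogue ensures the cached hpZ secondary shards are rebuilt rather than served stale after a checkpoint event may have re-partitioned or reloaded the primary shards. A rough sketch of the invariant, using a hypothetical helper (not DeepSpeed's actual implementation):

# Hypothetical sketch: after a checkpoint event, any cached secondary (hpZ)
# copy may no longer match the primary partition, so drop it and let the next
# secondary all-gather rebuild it instead of reusing a stale tensor.
def drop_stale_secondary_shards(params):
    for p in params:
        if getattr(p, "ds_secondary_tensor", None) is not None:
            p.ds_secondary_tensor = None
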
97 changes: 97 additions & 0 deletions tests/unit/checkpoint/test_zero_optimizer.py
@@ -8,6 +8,7 @@
from deepspeed.ops.op_builder import CPUAdamBuilder
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save, get_model_ckpt_name_for_rank
from deepspeed.accelerator import get_accelerator
+ from deepspeed.runtime.zero import ZeroParamStatus
from deepspeed.utils.torch import required_torch_version

from unit.common import DistributedTest, DistributedFixture
@@ -212,6 +213,33 @@ def run(self, class_tmpdir, elastic_save, load_optim):
model.save_checkpoint(class_tmpdir)


class ws4_model_checkpoint_zeropp(DistributedFixture):

world_size = 4

def run(self, class_tmpdir):
config_dict = {
"train_batch_size": 4,
"optimizer": {
"type": 'Adam'
},
"zero_optimization": {
"stage": 3,
"zero_hpz_partition_size": 2,
}
}

hidden_dim = 10
model = SimpleModel(hidden_dim)
for param in model.parameters():
param.data = torch.ones_like(param.data, device=param.data.device, requires_grad=False)

# save model and zero checkpoint
torch.save(model.state_dict(), os.path.join(class_tmpdir, "model.pt"))
ds_model = create_deepspeed_model(config_dict=config_dict, model=model, base_optimizer=None)
ds_model.save_checkpoint(class_tmpdir)


@pytest.mark.parametrize("elastic_save", [True, False])
@pytest.mark.parametrize("elastic_load", [True, False])
@pytest.mark.parametrize("load_optim", [True, False])
@@ -660,3 +688,72 @@ def mock_logger_info(message, *args, **kwargs):
engine._change_recovery_script_permissions(fake_recovery_script_dst)

assert log_called, "Expected deepspeed.utils.logger.info to be called."


class TestZeROPPLoadCheckpoint(DistributedTest):

world_size = 4

def test_load_zeropp_model(self, ws4_model_checkpoint_zeropp, class_tmpdir):
config_dict = {
"train_batch_size": 4,
"optimizer": {
"type": 'Adam'
},
"zero_optimization": {
"stage": 3,
"zero_hpz_partition_size": 2,
"stage3_param_persistence_threshold": 1
}
}

# Init model and load saved model
hidden_dim = 10
with deepspeed.zero.Init(config_dict_or_path=config_dict):
model = SimpleModel(hidden_dim)
ds_model = create_deepspeed_model(config_dict=config_dict, model=model, base_optimizer=None)

with deepspeed.zero.GatheredParameters(ds_model.module.parameters(), modifier_rank=0):
if dist.get_rank() == 0:
state_dict = torch.load(os.path.join(class_tmpdir, "model.pt"))
ds_model.module.load_state_dict(state_dict)

# Check the parameters after gather
params_to_gather = [p for p in ds_model.module.parameters() if p.ds_status == ZeroParamStatus.NOT_AVAILABLE]
if len(params_to_gather) > 0:
handle = params_to_gather[0].all_gather_coalesced(params_to_gather)
handle.wait()
for ds_param in params_to_gather:
for v in ds_param.data.cpu().flatten().numpy():
assert v == 1.0

def test_load_zeropp_checkpoint(self, ws4_model_checkpoint_zeropp, class_tmpdir):
config_dict = {
"train_batch_size": 4,
"optimizer": {
"type": 'Adam'
},
"zero_optimization": {
"stage": 3,
"zero_hpz_partition_size": 2,
"stage3_param_persistence_threshold": 1
}
}

# Init model and load zero checkpoint
hidden_dim = 10
model = SimpleModel(hidden_dim)
ds_model = create_deepspeed_model(config_dict=config_dict, model=model, base_optimizer=None)
ds_model.load_checkpoint(class_tmpdir,
load_optimizer_states=True,
load_lr_scheduler_states=False,
load_module_only=False)

# Check the parameters after gather
params_to_gather = [p for p in ds_model.module.parameters() if p.ds_status == ZeroParamStatus.NOT_AVAILABLE]
if len(params_to_gather) > 0:
handle = params_to_gather[0].all_gather_coalesced(params_to_gather)
handle.wait()
for ds_param in params_to_gather:
for v in ds_param.data.cpu().flatten().numpy():
assert v == 1.0