Skip to content

Commit 8927a43

Browse files
author
amaurya
committed
Update to Apache-2.0 License, move debloating to checkpointing engine
Signed-off-by: amaurya <[email protected]>
1 parent bb16750 commit 8927a43

File tree

7 files changed

+10
-20
lines changed

7 files changed

+10
-20
lines changed

deepspeed/datastates/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Microsoft Corporation.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
# MIT License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory.
4+
# Apache-2.0 License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory.
55

66
# DeepSpeed Team

deepspeed/datastates/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Copyright (c) Microsoft Corporation.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
# MIT License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory.
4+
# Apache-2.0 License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory.
55

66
# DeepSpeed Team
77

deepspeed/runtime/checkpoint_engine/checkpoint_engine.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,3 @@ def commit(self, tag):
3232
def wait(self):
3333
# To wait in asynchronous checkpoint engines (e.g. DataStates-LLM) for the previous snapshot to finish
3434
pass
35-
36-
def preserves_storage_sharing(self):
37-
# Check if the checkpoint engine preserves storage sharing
38-
# (set to false if cloning is required to get actual tensor sizes)
39-
return False

deepspeed/runtime/checkpoint_engine/datastates_checkpoint_engine.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Copyright (c) Microsoft Corporation.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
# MIT License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory.
4+
# Apache-2.0 License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory.
55

66
# DeepSpeed Team
77

@@ -32,6 +32,3 @@ def commit(self, tag):
3232

3333
def wait(self):
3434
return self.ckpt_engine.wait()
35-
36-
def preserves_storage_sharing(self):
37-
return True

deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
CheckpointEngine
1212
from deepspeed.utils import logger, log_dist
1313
from deepspeed.nebula.constants import *
14+
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
1415

1516

1617
def _get_tag_from_path(path):
@@ -48,7 +49,8 @@ def save(self, state_dict, path: str):
4849
tag = _get_tag_from_path(path)
4950
partition_name = os.path.basename(path)
5051
logger.info(f"[Nebula] Saving {partition_name} under tag {tag}...")
51-
self.checkpoint.save(partition_name, state_dict)
52+
debloated_state_dict = clone_tensors_for_torch_save(state_dict)
53+
self.checkpoint.save(partition_name, debloated_state_dict)
5254
logger.info(f"[Nebula] Saved {partition_name} under tag {tag}.")
5355
return None
5456

deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from deepspeed.utils import logger, log_dist
88
from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \
99
CheckpointEngine
10+
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
1011

1112

1213
class TorchCheckpointEngine(CheckpointEngine):
@@ -19,7 +20,8 @@ def create(self, tag):
1920

2021
def save(self, state_dict, path: str):
2122
logger.info(f"[Torch] Saving {path}...")
22-
torch.save(state_dict, path)
23+
debloated_state_dict = clone_tensors_for_torch_save(state_dict)
24+
torch.save(debloated_state_dict, path)
2325
logger.info(f"[Torch] Saved {path}.")
2426
return None
2527

deepspeed/runtime/pipe/module.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from .topology import PipeDataParallelTopology, PipelineParallelGrid
2121
from deepspeed.runtime.state_dict_factory import SDLoaderFactory
2222
from deepspeed.accelerator import get_accelerator
23-
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
2423

2524

2625
class PipelineError(Exception):
@@ -629,12 +628,7 @@ def save_state_dict(self, save_dir, checkpoint_engine, exclude_frozen_params=Fal
629628
for n in self._get_frozen_parameter_names(layer):
630629
del orig_state_dict[n]
631630

632-
if checkpoint_engine.preserves_storage_sharing():
633-
final_state_dict = orig_state_dict
634-
else:
635-
final_state_dict = clone_tensors_for_torch_save(orig_state_dict)
636-
637-
checkpoint_engine.save(final_state_dict, model_ckpt_path)
631+
checkpoint_engine.save(orig_state_dict, model_ckpt_path)
638632

639633
def load_state_dir(self, load_dir, checkpoint_engine, strict=True):
640634
for idx, layer in enumerate(self.forward_funcs):

0 commit comments

Comments
 (0)