12 changes: 12 additions & 0 deletions deepspeed/datastates/README.md
@@ -0,0 +1,12 @@
# DataStates-LLM Checkpointing Engine

This feature is not enabled by default. To enable it, install the [DataStates-LLM checkpointing library](https://github.com/DataStates/datastates-llm/) and set the following options in `ds_config.json`. A detailed tutorial is available [here](../../docs/_tutorials/datastates-async-checkpointing.md).

```
{
  ... other deepspeed config options,
  "datastates_ckpt": {
    "host_cache_size": 16
  }
}
```
6 changes: 6 additions & 0 deletions deepspeed/datastates/__init__.py
@@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# Apache-2.0 License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory.

# DeepSpeed Team
21 changes: 21 additions & 0 deletions deepspeed/datastates/config.py
@@ -0,0 +1,21 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# Apache-2.0 License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory.

# DeepSpeed Team

from deepspeed.runtime.config_utils import DeepSpeedConfigObject
import copy

DATASTATES_CHECKPOINTING = "datastates_ckpt"
DATASTATES_CHECKPOINTING_ENABLED = False


class DeepSpeedDataStatesConfig(DeepSpeedConfigObject):

    def __init__(self, param_dict):
        super(DeepSpeedDataStatesConfig, self).__init__()

        # Enabled whenever the "datastates_ckpt" section is present and not
        # explicitly set to false; the raw section is preserved for the engine.
        self.enabled = param_dict.get(DATASTATES_CHECKPOINTING, DATASTATES_CHECKPOINTING_ENABLED) is not False
        self.config = copy.deepcopy(param_dict.get(DATASTATES_CHECKPOINTING, None))
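
For illustration, a minimal sketch of how this class parses a user config (the dictionaries below are examples, not required settings):

```python
# Section present -> enabled, with the section contents preserved.
cfg = DeepSpeedDataStatesConfig({"datastates_ckpt": {"host_cache_size": 16}})
assert cfg.enabled and cfg.config == {"host_cache_size": 16}

# Section absent -> disabled, config is None.
cfg = DeepSpeedDataStatesConfig({})
assert not cfg.enabled and cfg.config is None
```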
3 changes: 3 additions & 0 deletions deepspeed/runtime/checkpoint_engine/checkpoint_engine.py
@@ -58,3 +58,6 @@ def get_commit_info(self):

    def cleanup(self):
        pass

    def preserves_storage_sharing(self):
        # When True (as with torch.save), call sites clone tensors via
        # clone_tensors_for_torch_save() before saving so that shared
        # storages are not serialized in full.
        return True
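
For reference, the call sites added later in this PR (in `engine.py` and `pipe/module.py`) consume this flag with the following pattern (a condensed sketch of those diffs):

```python
# Clone only for engines that preserve storage sharing (e.g., torch.save),
# so that tensor views do not serialize their entire underlying storage.
saveable_state_dict = state_dict
if checkpoint_engine.preserves_storage_sharing():
    saveable_state_dict = clone_tensors_for_torch_save(state_dict)
checkpoint_engine.save(saveable_state_dict, path)
```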
51 changes: 51 additions & 0 deletions deepspeed/runtime/checkpoint_engine/datastates_checkpoint_engine.py
@@ -0,0 +1,51 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# Apache-2.0 License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory.

# DeepSpeed Team

from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \
    CheckpointEngine, CheckpointCommitInfo
from datastates import CheckpointEngine as DataStatesEngine

ENGINE_NAME = "DataStatesCheckpointEngine"


class DataStatesCheckpointEngine(CheckpointEngine):

    def __init__(self, deepspeed_config, rank):
        super().__init__(deepspeed_config)
        self.commit_info = None
        self.ckpt_engine = DataStatesEngine(deepspeed_config, rank)

    def __del__(self):
        self.cleanup()

    def create(self, info: CheckpointCommitInfo):
        self.commit_info = info
        return None

    def save(self, state_dict, path: str):
        return self.ckpt_engine.save(state_dict, path)

    def load(self, path: str, map_location=None):
        return self.ckpt_engine.load(path, map_location)

    def commit(self, info: CheckpointCommitInfo):
        if info is None:
            return
        assert info == self.commit_info
        self.ckpt_engine.wait(persist=True)
        return self.ckpt_engine.commit(info.tag)

    def cleanup(self):
        self.commit(self.commit_info)
        self.ckpt_engine.wait(persist=True)
        del self.ckpt_engine

    def is_decoupled(self):
        return True

    def preserves_storage_sharing(self):
        return False
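
Putting the pieces together, a hedged usage sketch of this engine's decoupled lifecycle (the tag, paths, and `CheckpointCommitInfo` constructor arguments are assumptions for illustration):

```python
# create() registers the pending commit, save() captures asynchronously, and
# commit() blocks via wait(persist=True) until the tag is durable on storage.
info = CheckpointCommitInfo(tag="global_step100")  # assumed constructor
engine = DataStatesCheckpointEngine(deepspeed_config, rank=0)
engine.create(info)
engine.save(model_state, "ckpt/global_step100/model.pt")
engine.save(optimizer_state, "ckpt/global_step100/optim.pt")
engine.commit(info)
```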
12 changes: 11 additions & 1 deletion deepspeed/runtime/checkpoint_engine/utils.py
@@ -6,7 +6,7 @@
from deepspeed.runtime.model_checkpointing.constants import *
from deepspeed.runtime.model_checkpointing.utils import create_data_parallel_writer_config
from deepspeed.utils import logger

from deepspeed import comm as dist
from .decoupled_checkpoint_engine import DecoupledCheckpointEngine
from .fast_checkpoint_engine import FastCheckpointEngine
from .torch_checkpoint_engine import TorchCheckpointEngine
@@ -35,4 +35,14 @@ def create_checkpoint_engine(config_params, groups, zero_stage, has_moe_layers,
    else:
        return NebulaCheckpointEngine(config_params=config_params.nebula_config)

    if config_params.datastates_config.enabled:
        try:
            from deepspeed.runtime.checkpoint_engine.datastates_checkpoint_engine import DataStatesCheckpointEngine
            return DataStatesCheckpointEngine(deepspeed_config=config_params, rank=dist.get_rank())
        except ImportError as err:
            logger.error(
                f"No datastates engine found! Install from https://github.com/DataStates/datastates-llm. Will fall back to torch.save. Details: {err}"
            )
            return TorchCheckpointEngine(config_params)

    return TorchCheckpointEngine(config_params)
2 changes: 2 additions & 0 deletions deepspeed/runtime/config.py
@@ -52,6 +52,7 @@
from ..profiling.config import DeepSpeedFlopsProfilerConfig
from ..autotuning.config import DeepSpeedAutotuningConfig
from ..nebula.config import DeepSpeedNebulaConfig
from ..datastates.config import DeepSpeedDataStatesConfig

from ..compression.config import get_compression_config, get_quantize_enabled
from ..compression.constants import *
@@ -859,6 +860,7 @@ def _initialize_params(self, param_dict):
        self.dataloader_drop_last = get_dataloader_drop_last(param_dict)

        self.nebula_config = DeepSpeedNebulaConfig(param_dict)
        self.datastates_config = DeepSpeedDataStatesConfig(param_dict)
        self.checkpoint_config = get_checkpoint_config(param_dict)

        self.weight_quantization_config = WeightQuantConfig(
12 changes: 9 additions & 3 deletions deepspeed/runtime/engine.py
@@ -3594,7 +3594,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa
                    moe_save_path = self._get_expert_ckpt_name(save_dir, moe_layer_id, global_expert_id, tag, self.mpu)
                    if self.random_ltd_enabled():
                        expert_state_dict = remove_random_ltd_state_dict(expert_state_dict)
                    saveable_state_dict = clone_tensors_for_torch_save(expert_state_dict)
                    saveable_state_dict = expert_state_dict
                    if self.checkpoint_engine.preserves_storage_sharing():
                        saveable_state_dict = clone_tensors_for_torch_save(expert_state_dict)
                    self.checkpoint_engine.save(saveable_state_dict, moe_save_path)
                moe_layer_id += 1

@@ -3616,7 +3618,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa
                }
                # TODO: why use BufferedWriter not the path
                file_path = self._get_optimizer_ckpt_name(save_dir, tag, expp_rank)
                saveable_state_dict = clone_tensors_for_torch_save(optimizer_state)
                saveable_state_dict = optimizer_state
                if self.checkpoint_engine.preserves_storage_sharing():
                    saveable_state_dict = clone_tensors_for_torch_save(optimizer_state)
                self.checkpoint_engine.save(saveable_state_dict, file_path)
# Load flow uses below saved file for model parameters, RNG and more
@@ -3656,7 +3660,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa
        }
        state.update(client_state)
        logger.info(f'Saving model checkpoint: {save_path}')
        saveable_state_dict = clone_tensors_for_torch_save(state)
        saveable_state_dict = state
        if self.checkpoint_engine.preserves_storage_sharing():
            saveable_state_dict = clone_tensors_for_torch_save(state)
        self.checkpoint_engine.save(saveable_state_dict, save_path)

def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint):
5 changes: 4 additions & 1 deletion deepspeed/runtime/pipe/module.py
@@ -621,6 +621,7 @@ def save_state_dict(self, save_dir, checkpoint_engine, exclude_frozen_params=Fal
        layer_list = self.forward_funcs[start:end]

        checkpoint_engine.makedirs(save_dir, exist_ok=True)
        should_clone = checkpoint_engine.preserves_storage_sharing()
        for idx, layer in enumerate(layer_list):
            model_ckpt_path = self.ckpt_layer_path(save_dir, start + idx)
            if not hasattr(layer, 'state_dict'):
@@ -630,7 +631,9 @@
            if exclude_frozen_params:
                for n in self._get_frozen_parameter_names(layer):
                    del orig_state_dict[n]
            final_state_dict = clone_tensors_for_torch_save(orig_state_dict)
            final_state_dict = orig_state_dict
            if should_clone:
                final_state_dict = clone_tensors_for_torch_save(orig_state_dict)
            checkpoint_engine.save(state_dict=final_state_dict, path=model_ckpt_path)

    def load_state_dir(self, load_dir, checkpoint_engine, strict=True):
55 changes: 55 additions & 0 deletions docs/_tutorials/datastates-async-checkpointing.md
@@ -0,0 +1,55 @@
---
title: "DataStates-LLM Checkpointing Engine"
tags: asynchronous checkpointing for minimizing I/O overheads.
---
This tutorial shows how to use [DataStates-LLM](https://github.com/DataStates/datastates-llm) for asynchronous checkpointing. DataStates-LLM introduces a lazy asynchronous checkpointing mechanism tailored to LLMs that minimizes I/O overhead and improves training efficiency; the sections below walk through integrating it with the DeepSpeed framework.

## Overview of DataStates-LLM

DataStates-LLM addresses the challenges of frequent checkpointing in LLM training with a lazy asynchronous multi-level approach. It leverages the immutability of model parameters and optimizer states during the forward and backward passes to perform non-blocking data transfers, thereby reducing interference with the training process. As reported in [DataStates-LLM: Lazy Asynchronous Checkpointing for Large Language Models](https://arxiv.org/abs/2406.10707), this method has demonstrated up to 48x faster checkpointing and 2.2x faster end-to-end training compared to traditional approaches.
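
As an illustration of the general idea (a simplified sketch only, not the DataStates implementation), asynchronous checkpointing splits a checkpoint into a fast, blocking snapshot taken while the state is immutable, and a slow flush to persistent storage that overlaps with subsequent training iterations:

```python
import copy
import pickle
import threading

def checkpoint_async(state_dict, path):
    """Two-phase checkpoint: fast in-memory snapshot, background flush."""
    snapshot = copy.deepcopy(state_dict)  # stand-in for a device-to-pinned-host copy

    def flush():
        with open(path, "wb") as f:
            pickle.dump(snapshot, f)  # slow write overlaps with training

    thread = threading.Thread(target=flush)
    thread.start()
    return thread  # join() before the checkpoint can be declared durable
```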

## Prerequisites

Before integrating DataStates-LLM with DeepSpeed, ensure the following:

- **DeepSpeed Installation**: DeepSpeed should be installed in your environment. If not, refer to the [DeepSpeed Getting Started Guide](https://github.com/microsoft/DeepSpeed/blob/master/docs/_tutorials/getting-started.md) for installation instructions.

- **DataStates-LLM Repository**: Access the DataStates-LLM source code from its [GitHub repository](https://github.com/DataStates/datastates-llm) and follow the installation instructions provided therein.

## Configuring DeepSpeed for DataStates-LLM

To enable DataStates-LLM's asynchronous checkpointing in DeepSpeed, add a `datastates_ckpt` section to your `deepspeed_config.json` file. Below is an example configuration:

```json
{
  // ... other DeepSpeed configuration options
  "datastates_ckpt": {
    "host_cache_size": 16
  }
}
```

### Configuration Parameters

- **`host_cache_size`**: Specifies the amount of pinned host memory (in gigabytes) reserved for asynchronous checkpoint flushing. Adjust this value based on your system's memory capacity and the size of your model checkpoints.

## Implementing DataStates-LLM in Your Training Script

After enabling DataStates checkpointing in `deepspeed_config.json`, the checkpointing frequency can be configured through the `--save-interval` command-line parameter, which specifies the number of iterations between checkpoint captures.
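
Since checkpointing is driven by the DeepSpeed configuration, the training loop itself uses the standard DeepSpeed checkpoint API. Below is a minimal sketch; the model, data loader, directory, and interval are illustrative assumptions, not part of the DataStates API:

```python
import deepspeed

# model and data_loader are assumed to be defined elsewhere; the config file
# contains the "datastates_ckpt" section shown above.
model_engine, _, _, _ = deepspeed.initialize(model=model,
                                             model_parameters=model.parameters(),
                                             config="deepspeed_config.json")

save_interval = 100  # e.g., the value passed via --save-interval
for step, batch in enumerate(data_loader):
    loss = model_engine(batch)
    model_engine.backward(loss)
    model_engine.step()
    if (step + 1) % save_interval == 0:
        # With the DataStates engine enabled, this call returns quickly and
        # the flush to persistent storage proceeds in the background.
        model_engine.save_checkpoint("checkpoints", tag=f"global_step{step + 1}")
```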

## Limitations and Ongoing Work

1. DataStates-LLM currently supports only the CUDA runtime on NVIDIA GPUs.

2. DataStates-LLM has only been tested with ZeRO stage 1, without offloading to any other tiers.

3. While the checkpoint layout of DataStates matches Hugging Face's [safetensors](https://huggingface.co/docs/safetensors/) format, it is not yet fully compatible with the safetensors library because DeepSpeed requires pickled objects during restart.

4. DataStates-LLM does not yet support universal or elastic checkpointing.


## Questions and Support

Please use the [DataStates-LLM GitHub repository](https://github.com/DataStates/datastates-llm) for any questions, issues, or feature requests.