25 changes: 25 additions & 0 deletions deepnvme/model_checkpoint/README
@@ -0,0 +1,25 @@
[FastPersist](https://arxiv.org/abs/2406.13768) is an optimization technique that leverages NVMe storage to accelerate model checkpointing. This folder contains micro-benchmarks and instructions for demonstrating FastPersist.

## Enabling FastPersist Optimizations ##
FastPersist is designed to integrate with torch checkpointing and has been validated with torch version 2.6.0. The integration requires slight modifications to torch serialization; for convenience, we provide [original](torch/serialization_orig_v2.6.0.py) and [patched](torch/serialization_fast_v2.6.0.py) versions of `serialization.py`. To demonstrate FastPersist performance, you need to overwrite `torch/serialization.py` in your torch installation with the patched version.
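
As a minimal sketch of this step (assuming a standard pip/conda torch install, and that you want to keep a backup of the original file):

```bash
# Locate the active torch installation (path varies by environment).
TORCH_DIR=$(python -c "import torch, os; print(os.path.dirname(torch.__file__))")
# Keep a backup of the stock serialization.py before overwriting it.
cp ${TORCH_DIR}/serialization.py ${TORCH_DIR}/serialization_orig_backup.py
# Install the patched version provided in this folder.
cp torch/serialization_fast_v2.6.0.py ${TORCH_DIR}/serialization.py
```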

## Available Micro-benchmarks ##
This folder contains three micro-benchmarks, implemented by the following scripts:
1. `torch_save_tensor.py`: Serialize a raw PyTorch tensor to disk using the `torch.save()` integration.
2. `torch_save_model.py`: Serialize a HF model to disk using the `torch.save()` integration.
3. `deepspeed_save_model.py`: Serialize a HF model to disk using the DeepSpeed integration.

Each script provides a `--help` option for examining the available configurations. The scripts are written for single-process execution and so can be launched directly with `python` rather than a distributed launcher.
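
For instance, the configuration options of the tensor micro-benchmark can be listed with:

```bash
python torch_save_tensor.py --help
```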

As an example, the performance of checkpointing the HF phi-3-mini model from GPU memory using the `torch.save()` integration can be measured as follows:
```
python torch_save_model.py --model phi3 --folder /mnt/nvme0 --gpu
```
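
A similar measurement can be taken through the DeepSpeed integration; the flags below are assumed to carry over from `torch_save_model.py`, since both scripts share `save_model_utils.parse_arguments`:

```bash
python deepspeed_save_model.py --model phi3 --folder /mnt/nvme0 --gpu
```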

The script executes the checkpointing workload using several mechanisms, including vanilla `torch.save()`, FastPersist with a CPU bounce buffer, and FastPersist with NVIDIA GDS, and reports the performance of each. You can find the respective results by searching the generated log for lines similar to the following snippet. For this example, the results below, collected using eight PCIe Gen4 NVMe SSDs in RAID-0 (data-striped), show checkpointing throughputs of 0.69GB/sec for vanilla `torch.save()` (labelled `test_save`) and 17.75GB/sec for FastPersist with a CPU bounce buffer (labelled `test_ds_aio_fast_save`).

```bash
test_save -- 14.23 GB, 20.72 secs, 0.69 GB/s
test_ds_aio_fast_save -- 14.23 GB, 0.80 secs, 17.75 GB/s
```
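
If the run output is captured to a file, the summary lines can be pulled out with a simple search (the log file name below is only an assumption):

```bash
# Capture the benchmark output, then extract the per-mechanism throughput lines.
python torch_save_model.py --model phi3 --folder /mnt/nvme0 --gpu | tee phi3_ckpt.log
grep "GB/s" phi3_ckpt.log
```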

28 changes: 20 additions & 8 deletions deepnvme/model_checkpoint/deepspeed_save_model.py
@@ -8,6 +8,7 @@
import deepspeed
from deepspeed.accelerator import get_accelerator
from save_model_utils import get_model, validate_arguments, parse_arguments
from torch_save_utils import load_io_ops

def _get_ds_config(args, writer_type, use_gds):
ds_config = {
@@ -26,7 +27,7 @@ def _get_ds_config(args, writer_type, use_gds):
}
},
"checkpoint": {
"checkpoint_serialization": not args.legacy
"checkpoint_serialization": args.zipfile
},
"aio": {
"block_size": 8 * (1024**2),
@@ -64,11 +65,9 @@ def _do_optimizer_step(ds_engine):


def _free_ds_memory(ds_engine):
ds_engine.optimizer.optimizer = None
ds_engine.optimizer = None
ds_engine.module = None
ds_engine = None
ds_engine.destroy()
del ds_engine
ds_engine = None
gc.collect()
get_accelerator().empty_cache()

@@ -80,9 +79,11 @@ def test_save(tag, folder, model, args, writer_type):
if args.zero_stage == 0:
_do_optimizer_step(ds_engine)

st = time.time()
ds_engine.save_checkpoint(save_dir=folder, tag=tag)
write_sec = time.time() - st
_free_ds_memory(ds_engine)
return write_sec

@@ -107,8 +108,6 @@ def run(model, model_name, ckpt_name, args):
folder = os.path.join(args.folder, ckpt_name, tag)
if os.path.exists(folder):
shutil.rmtree(folder, ignore_errors=True)
# if not os.path.exists(folder):
# os.makedirs(folder, exist_ok=True)
write_sec = test_save(tag, folder, model, args, writer_type)
ckpt_size = _get_folder_size(folder)
gb_size = ckpt_size / (1024**3)
@@ -118,19 +117,32 @@ def run(model, model_name, ckpt_name, args):
)
print(f'*********************************************')

def init_torch_distributed():
import torch.distributed as dist
from deepspeed.constants import TORCH_DISTRIBUTED_DEFAULT_PORT, CROSS_RANK, CROSS_SIZE
os.environ['MASTER_PORT'] = str(TORCH_DISTRIBUTED_DEFAULT_PORT)
os.environ['MASTER_ADDR'] = "localhost"
os.environ['LOCAL_RANK'] = str(0)
os.environ['WORLD_SIZE'] = str(1)
os.environ['CROSS_RANK'] = str(0)
os.environ['CROSS_SIZE'] = str(1)
dist.init_process_group(backend='nccl', rank=0, world_size=1)



def main():
print(
f'Performance test of deepspeed integration of fast model checkpointing.'
)
print(f'torch version = {torch.__version__}')
init_torch_distributed()
torch.manual_seed(42)
np.random.seed(0)
random.seed(0)
args = parse_arguments()
if not validate_arguments(args):
quit()

load_io_ops(args)
model, model_name, ckpt_name = get_model(args.model)
run(model, model_name, ckpt_name, args)

4 changes: 2 additions & 2 deletions deepnvme/model_checkpoint/save_model_utils.py
@@ -67,9 +67,9 @@ def parse_arguments():
default=0,
help='Local rank' )

parser.add_argument('--legacy',
parser.add_argument('--zipfile',
action='store_true',
help='Use torch legacy save format')
help='Use torch zipfile save format')

parser.add_argument('--optimizer',
action='store_true',
9 changes: 2 additions & 7 deletions deepnvme/model_checkpoint/torch_save_model.py
@@ -2,11 +2,10 @@
import torch
from torch.optim import Adam
import os
from torch_save_utils import test_save, test_ds_mock_save, test_ds_py_save, test_ds_aio_fast_save, test_ds_gds_fast_save
from torch_save_utils import test_save, test_ds_mock_save, test_ds_py_save, test_ds_aio_fast_save, test_ds_gds_fast_save, load_io_ops
from save_model_utils import get_model, validate_arguments, parse_arguments
import deepspeed
from deepspeed.accelerator import get_accelerator
import deepspeed.comm as dist


def run(model, model_name, ckpt_name, args):
@@ -23,8 +22,6 @@ def run(model, model_name, ckpt_name, args):
continue
file = os.path.join(args.folder, f'{tag}_{ckpt_name}.pt')
print(f'checkpoint file = {file}')
if os.path.isfile(file):
os.remove(file)
st = time.time()
write_sec = fn(file, model, args)
ckpt_size = os.path.getsize(file)
@@ -59,8 +56,7 @@ def main():
args = parse_arguments()
if not validate_arguments(args):
quit()

deepspeed.init_distributed()
load_io_ops(args)
model, model_name, ckpt_name = get_model(args.model)
if args.half:
model = model.half()
@@ -72,7 +68,6 @@ def main():
else:
ckpt_state = {'model': model}
run(ckpt_state, model_name, ckpt_name, args)
dist.destroy_process_group()


if __name__ == "__main__":
16 changes: 4 additions & 12 deletions deepnvme/model_checkpoint/torch_save_tensor.py
@@ -2,11 +2,10 @@
import argparse
import torch
import os
from torch_save_utils import PINNED_BUFFER_MB
from torch_save_utils import PINNED_BUFFER_MB, load_io_ops
from torch_save_utils import test_save, test_ds_mock_save, test_ds_py_save, test_ds_aio_fast_save, test_ds_gds_fast_save
import deepspeed
from deepspeed.accelerator import get_accelerator
import deepspeed.comm as dist
import os

def run(args):
@@ -28,8 +27,6 @@ def run(args):
continue
file = os.path.join(args.folder, f'{tag}_{args.mb_size}MB.pt')
print(f'checkpoint file = {file}')
if os.path.isfile(file):
os.remove(file)
st = time.time()
write_sec = fn(file, buffer, args)
gb_per_sec = args.mb_size / (1024.0 * write_sec)
@@ -53,9 +50,9 @@ def parse_arguments():
default=None,
required=True,
help='Size of tensor to save in MB.')
parser.add_argument('--legacy',
parser.add_argument('--zipfile',
action='store_true',
help='Use torch legacy save format')
help='Use torch zipfile save format')

parser.add_argument('--gpu', action='store_true', help='Use gpu tensors.')

@@ -71,10 +68,6 @@ def parse_arguments():
parser.add_argument('--single_io_buffer',
action='store_true',
help='Disable double buffering of i/o buffer.')
parser.add_argument('--local_rank',
type=int,
default=0,
help='Local rank' )

args = parser.parse_args()
print(f'args = {args}')
@@ -89,9 +82,8 @@ def main():
if not os.path.exists(args.folder):
print(f'Invalid folder: {args.folder}')
quit()
deepspeed.init_distributed()
load_io_ops(args)
run(args)
dist.destroy_process_group()


if __name__ == "__main__":
14 changes: 10 additions & 4 deletions deepnvme/model_checkpoint/torch_save_utils.py
@@ -13,6 +13,12 @@
AIO_OVERLAP_EVENTS = False
PINNED_BUFFER_MB = 64

def load_io_ops(args):
if AsyncIOBuilder().is_compatible():
AsyncIOBuilder().load(verbose=False)
if args.gpu and GDSBuilder().is_compatible():
GDSBuilder().load(verbose=False)


def _get_aio_handle():
h = AsyncIOBuilder().load(verbose=False).aio_handle(block_size=AIO_BLOCK_SIZE,
@@ -34,7 +40,7 @@ def test_save(file, buffer, args):
st = time.time()
torch.save(f=file,
obj=buffer,
_use_new_zipfile_serialization=not args.legacy)
_use_new_zipfile_serialization=args.zipfile)
return time.time() - st


@@ -43,7 +49,7 @@ def test_ds_mock_save(file, buffer, args):
ds_mock_writer = MockFileWriter(file)
torch.save(f=ds_mock_writer,
obj=buffer,
_use_new_zipfile_serialization=not args.legacy)
_use_new_zipfile_serialization=args.zipfile)
ds_mock_writer.close() # Force flush to storage
write_sec = time.time() - st
if not args.no_statistics:
@@ -56,7 +62,7 @@ def test_ds_py_save(file, buffer, args):
ds_py_writer = PyFileWriter(file)
torch.save(f=ds_py_writer,
obj=buffer,
_use_new_zipfile_serialization=not args.legacy)
_use_new_zipfile_serialization=args.zipfile)
ds_py_writer.close() # Force flush to storage
write_sec = time.time() - st
if not args.no_statistics:
@@ -96,7 +102,7 @@ def _test_ds_fast_save(file, buffer, args, use_gds):
config=fast_writer_config)
torch.save(f=ds_fast_writer,
obj=buffer,
_use_new_zipfile_serialization=not args.legacy)
_use_new_zipfile_serialization=args.zipfile)
ds_fast_writer.close() # Force flush to storage
write_sec = time.time() - st
if not args.no_statistics: