Commit 28a984e

Simplify and add README (deepspeedai#978)
Signed-off-by: Olatunji Ruwase <[email protected]>
1 parent b018de1 commit 28a984e

6 files changed: +63 −33 lines changed

deepnvme/model_checkpoint/README

Lines changed: 25 additions & 0 deletions (new file)

[FastPersist](https://arxiv.org/abs/2406.13768) is an optimization technique that leverages NVMe storage to accelerate model checkpointing. This folder contains micro-benchmarks and instructions for demonstrating FastPersist.

## Enabling FastPersist Optimizations ##

FastPersist is designed to integrate with torch checkpointing and has been validated with torch version 2.6.0. The integration requires slight modifications to torch serialization, so for convenience we provide [original](torch/serialization_orig_v2.6.0.py) and [patched](torch/serialization_fast_v2.6.0.py) versions of serialization.py. To demonstrate FastPersist performance, overwrite `torch/serialization.py` in your torch installation with the patched version.
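For example, here is a minimal sketch of the overwrite step, run from this folder; the backup filename and the use of `shutil` are our own illustration, not part of the repo:

```python
# Locate the installed torch/serialization.py, keep a backup, and replace it
# with the patched version shipped in this folder.
import os
import shutil
import torch

target = os.path.join(os.path.dirname(torch.__file__), 'serialization.py')
shutil.copyfile(target, target + '.orig')  # backup of the stock file
shutil.copyfile('torch/serialization_fast_v2.6.0.py', target)
```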
## Available Micro-benchmarks ##

This folder contains three different micro-benchmarks, implemented by the following scripts:

1. torch_save_tensor.py: Serialize a raw pytorch tensor to disk using the `torch.save()` integration.
2. torch_save_model.py: Serialize a HF model to disk using the `torch.save()` integration.
3. deepspeed_save_model.py: Serialize a HF model to disk using the DeepSpeed integration.

Each script provides a `--help` option for examining the available configurations. The scripts are written for single-process execution and so can be launched using `python`.
As an example, the performance of using the `torch.save()` integration for checkpointing the HF phi-3-mini model from GPU memory can be measured as follows:

```
python torch_save_model.py --model phi3 --folder /mnt/nvme0 --gpu
```

The script executes and reports the performance of the checkpointing workload using different mechanisms, including vanilla `torch.save()`, FastPersist with a CPU bounce buffer, and FastPersist with NVIDIA GDS. You can find the respective performance numbers by searching the generated log for lines similar to the following snippet. For this example, the results below, collected using eight PCIe Gen4 NVMe SSDs in RAID-0 (data-striped), show checkpointing throughputs of 0.69 GB/sec for vanilla `torch.save()` (labelled test_save) and 17.75 GB/sec for FastPersist with a CPU bounce buffer (labelled test_ds_aio_fast_save).

```bash
test_save -- 14.23 GB, 20.72 secs, 0.69 GB/s
test_ds_aio_fast_save -- 14.23 GB, 0.80 secs, 17.75 GB/s
```

deepnvme/model_checkpoint/deepspeed_save_model.py

Lines changed: 20 additions & 8 deletions

@@ -8,6 +8,7 @@
 import deepspeed
 from deepspeed.accelerator import get_accelerator
 from save_model_utils import get_model, validate_arguments, parse_arguments
+from torch_save_utils import load_io_ops

 def _get_ds_config(args, writer_type, use_gds):
     ds_config = {
@@ -26,7 +27,7 @@ def _get_ds_config(args, writer_type, use_gds):
             }
         },
         "checkpoint": {
-            "checkpoint_serialization": not args.legacy
+            "checkpoint_serialization": args.zipfile
         },
         "aio": {
             "block_size": 8 * (1024**2),
@@ -64,11 +65,9 @@ def _do_optimizer_step(ds_engine):


 def _free_ds_memory(ds_engine):
-    ds_engine.optimizer.optimizer = None
-    ds_engine.optimizer = None
-    ds_engine.module = None
-    ds_engine = None
+    ds_engine.destroy()
     del ds_engine
+    ds_engine = None
     gc.collect()
     get_accelerator().empty_cache()

@@ -80,9 +79,11 @@ def test_save(tag, folder, model, args, writer_type):
     if args.zero_stage == 0:
         _do_optimizer_step(ds_engine)

+    import pdb; pdb.set_trace()
     st = time.time()
     ds_engine.save_checkpoint(save_dir=folder, tag=tag)
     write_sec = time.time() - st
+    import pdb; pdb.set_trace()
     _free_ds_memory(ds_engine)
     return write_sec

@@ -107,8 +108,6 @@ def run(model, model_name, ckpt_name, args):
     folder = os.path.join(args.folder, ckpt_name, tag)
     if os.path.exists(folder):
         shutil.rmtree(folder, ignore_errors=True)
-    # if not os.path.exists(folder):
-    #     os.makedirs(folder, exist_ok=True)
     write_sec = test_save(tag, folder, model, args, writer_type)
     ckpt_size = _get_folder_size(folder)
     gb_size = ckpt_size / (1024**3)
@@ -118,19 +117,32 @@ def run(model, model_name, ckpt_name, args):
     )
     print(f'*********************************************')

+def init_torch_distributed():
+    import torch.distributed as dist
+    from deepspeed.constants import TORCH_DISTRIBUTED_DEFAULT_PORT, CROSS_RANK, CROSS_SIZE
+    os.environ['MASTER_PORT'] = str(TORCH_DISTRIBUTED_DEFAULT_PORT)
+    os.environ['MASTER_ADDR'] = "localhost"
+    os.environ['LOCAL_RANK'] = str(0)
+    os.environ['WORLD_SIZE'] = str(1)
+    os.environ['CROSS_RANK'] = str(0)
+    os.environ['CROSS_SIZE'] = str(1)
+    dist.init_process_group(backend='nccl', rank=0, world_size=1)
+
+

 def main():
     print(
         f'Performance test of deepspeed integration of fast model checkpointing.'
     )
     print(f'torch version = {torch.__version__}')
+    init_torch_distributed()
     torch.manual_seed(42)
     np.random.seed(0)
     random.seed(0)
     args = parse_arguments()
     if not validate_arguments(args):
         quit()
-
+    load_io_ops(args)
     model, model_name, ckpt_name = get_model(args.model)
     run(model, model_name, ckpt_name, args)
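The new `init_torch_distributed` helper hard-wires a single-rank process group so the benchmark runs under plain `python` without a distributed launcher. A minimal sketch of the same pattern in isolation (the port value here is illustrative; the script reads it from `deepspeed.constants.TORCH_DISTRIBUTED_DEFAULT_PORT`):

```python
# Stand up (and tear down) a one-rank NCCL process group, mirroring the
# init_torch_distributed helper added above.
import os
import torch.distributed as dist

os.environ['MASTER_ADDR'] = "localhost"
os.environ['MASTER_PORT'] = "29500"  # illustrative port

dist.init_process_group(backend='nccl', rank=0, world_size=1)
print(f'rank = {dist.get_rank()}, world size = {dist.get_world_size()}')
dist.destroy_process_group()
```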

deepnvme/model_checkpoint/save_model_utils.py

Lines changed: 2 additions & 2 deletions

@@ -67,9 +67,9 @@ def parse_arguments():
                         default=0,
                         help='Local rank' )

-    parser.add_argument('--legacy',
+    parser.add_argument('--zipfile',
                         action='store_true',
-                        help='Use torch legacy save format')
+                        help='Use torch zipfile save format')

     parser.add_argument('--optimizer',
                         action='store_true',
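Note that the rename flips the default: `--legacy` made the zipfile format opt-out, whereas `--zipfile` makes it opt-in, so runs without either flag now use torch's legacy serialization. A small self-contained check of that behavior, using a throwaway parser of our own:

```python
# With the renamed flag absent, args.zipfile is False, so the scripts pass
# _use_new_zipfile_serialization=False to torch.save (legacy format).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--zipfile', action='store_true',
                    help='Use torch zipfile save format')
args = parser.parse_args([])  # simulate a run with no flags
assert args.zipfile is False
```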

deepnvme/model_checkpoint/torch_save_model.py

Lines changed: 2 additions & 7 deletions

@@ -2,11 +2,10 @@
 import torch
 from torch.optim import Adam
 import os
-from torch_save_utils import test_save, test_ds_mock_save, test_ds_py_save, test_ds_aio_fast_save, test_ds_gds_fast_save
+from torch_save_utils import test_save, test_ds_mock_save, test_ds_py_save, test_ds_aio_fast_save, test_ds_gds_fast_save, load_io_ops
 from save_model_utils import get_model, validate_arguments, parse_arguments
 import deepspeed
 from deepspeed.accelerator import get_accelerator
-import deepspeed.comm as dist


 def run(model, model_name, ckpt_name, args):
@@ -23,8 +22,6 @@ def run(model, model_name, ckpt_name, args):
             continue
         file = os.path.join(args.folder, f'{tag}_{ckpt_name}.pt')
         print(f'checkpoint file = {file}')
-        if os.path.isfile(file):
-            os.remove(file)
         st = time.time()
         write_sec = fn(file, model, args)
         ckpt_size = os.path.getsize(file)
@@ -59,8 +56,7 @@ def main():
     args = parse_arguments()
     if not validate_arguments(args):
         quit()
-
-    deepspeed.init_distributed()
+    load_io_ops(args)
     model, model_name, ckpt_name = get_model(args.model)
     if args.half:
         model = model.half()
@@ -72,7 +68,6 @@ def main():
     else:
         ckpt_state = {'model': model}
     run(ckpt_state, model_name, ckpt_name, args)
-    dist.destroy_process_group()


 if __name__ == "__main__":
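With this change the `torch.save()`-based scripts no longer create or destroy a distributed process group; startup reduces to parsing arguments and loading the I/O operators. A sketch of the resulting startup path, using only names from this repo:

```python
# New single-process startup: load the DeepSpeed I/O op builders instead of
# initializing torch.distributed (sketch of main() in torch_save_model.py).
from torch_save_utils import load_io_ops
from save_model_utils import parse_arguments, validate_arguments, get_model

args = parse_arguments()
if not validate_arguments(args):
    quit()
load_io_ops(args)  # replaces the old deepspeed.init_distributed() call
model, model_name, ckpt_name = get_model(args.model)
```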

deepnvme/model_checkpoint/torch_save_tensor.py

Lines changed: 4 additions & 12 deletions

@@ -2,11 +2,10 @@
 import argparse
 import torch
 import os
-from torch_save_utils import PINNED_BUFFER_MB
+from torch_save_utils import PINNED_BUFFER_MB, load_io_ops
 from torch_save_utils import test_save, test_ds_mock_save, test_ds_py_save, test_ds_aio_fast_save, test_ds_gds_fast_save
 import deepspeed
 from deepspeed.accelerator import get_accelerator
-import deepspeed.comm as dist
 import os

 def run(args):
@@ -28,8 +27,6 @@ def run(args):
             continue
         file = os.path.join(args.folder, f'{tag}_{args.mb_size}MB.pt')
         print(f'checkpoint file = {file}')
-        if os.path.isfile(file):
-            os.remove(file)
         st = time.time()
         write_sec = fn(file, buffer, args)
         gb_per_sec = args.mb_size / (1024.0 * write_sec)
@@ -53,9 +50,9 @@ def parse_arguments():
                         default=None,
                         required=True,
                         help='Size of tensor to save in MB.')
-    parser.add_argument('--legacy',
+    parser.add_argument('--zipfile',
                         action='store_true',
-                        help='Use torch legacy save format')
+                        help='Use torch zipfile save format')

     parser.add_argument('--gpu', action='store_true', help='Use gpu tensors.')

@@ -71,10 +68,6 @@ def parse_arguments():
     parser.add_argument('--single_io_buffer',
                         action='store_true',
                         help='Disable double buffering of i/o buffer.')
-    parser.add_argument('--local_rank',
-                        type=int,
-                        default=0,
-                        help='Local rank' )

     args = parser.parse_args()
     print(f'args = {args}')
@@ -89,9 +82,8 @@ def main():
     if not os.path.exists(args.folder):
         print(f'Invalid folder: {args.folder}')
         quit()
-    deepspeed.init_distributed()
+    load_io_ops(args)
     run(args)
-    dist.destroy_process_group()


 if __name__ == "__main__":

deepnvme/model_checkpoint/torch_save_utils.py

Lines changed: 10 additions & 4 deletions

@@ -13,6 +13,12 @@
 AIO_OVERLAP_EVENTS = False
 PINNED_BUFFER_MB = 64

+def load_io_ops(args):
+    if AsyncIOBuilder().is_compatible():
+        AsyncIOBuilder().load(verbose=False)
+    if args.gpu and GDSBuilder().is_compatible():
+        GDSBuilder().load(verbose=False)
+

 def _get_aio_handle():
     h = AsyncIOBuilder().load(verbose=False).aio_handle(block_size=AIO_BLOCK_SIZE,
@@ -34,7 +40,7 @@ def test_save(file, buffer, args):
     st = time.time()
     torch.save(f=file,
                obj=buffer,
-               _use_new_zipfile_serialization=not args.legacy)
+               _use_new_zipfile_serialization=args.zipfile)
     return time.time() - st


@@ -43,7 +49,7 @@ def test_ds_mock_save(file, buffer, args):
     ds_mock_writer = MockFileWriter(file)
     torch.save(f=ds_mock_writer,
                obj=buffer,
-               _use_new_zipfile_serialization=not args.legacy)
+               _use_new_zipfile_serialization=args.zipfile)
     ds_mock_writer.close()  # Force flush to storage
     write_sec = time.time() - st
     if not args.no_statistics:
@@ -56,7 +62,7 @@ def test_ds_py_save(file, buffer, args):
     ds_py_writer = PyFileWriter(file)
     torch.save(f=ds_py_writer,
                obj=buffer,
-               _use_new_zipfile_serialization=not args.legacy)
+               _use_new_zipfile_serialization=args.zipfile)
     ds_py_writer.close()  # Force flush to storage
     write_sec = time.time() - st
     if not args.no_statistics:
@@ -96,7 +102,7 @@ def _test_ds_fast_save(file, buffer, args, use_gds):
                config=fast_writer_config)
     torch.save(f=ds_fast_writer,
                obj=buffer,
-               _use_new_zipfile_serialization=not args.legacy)
+               _use_new_zipfile_serialization=args.zipfile)
     ds_fast_writer.close()  # Force flush to storage
     write_sec = time.time() - st
     if not args.no_statistics:
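The new `load_io_ops` helper pre-loads the op builders once at startup, so the later `AsyncIOBuilder().load()` call in `_get_aio_handle` finds an already-built extension. A standalone probe of the same compatibility checks, assuming the builders are importable from `deepspeed.ops.op_builder` as in this file:

```python
# Load each DeepSpeed I/O op builder only if the current environment
# supports it, mirroring the logic of load_io_ops above.
from deepspeed.ops.op_builder import AsyncIOBuilder, GDSBuilder

if AsyncIOBuilder().is_compatible():
    aio_ops = AsyncIOBuilder().load(verbose=False)  # JIT-builds async_io if needed
if GDSBuilder().is_compatible():
    gds_ops = GDSBuilder().load(verbose=False)  # NVIDIA GDS path for GPU tensors
```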
