|
24 | 24 | from colossalai.utils import get_current_device, get_non_persistent_buffers_set |
25 | 25 | from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat |
26 | 26 |
|
| 27 | +from .distributed_checkpoint_utils import ( |
| 28 | + create_model_metadata, |
| 29 | + is_pytorch_model_meta_dist_file, |
| 30 | + load_dist_model, |
| 31 | + save_dist_sharded_model, |
| 32 | + save_dist_unshard_model, |
| 33 | +) |
27 | 34 | from .general_checkpoint_io import GeneralCheckpointIO |
28 | 35 | from .index_file import CheckpointIndexFile |
29 | 36 | from .utils import ( |
|
47 | 54 | sharded_optimizer_loading_epilogue, |
48 | 55 | ) |
49 | 56 |
|
50 | | -from .distributed_checkpoint_utils import ( |
51 | | - save_dist_sharded_model, |
52 | | - save_dist_unshard_model, |
53 | | - load_dist_model, |
54 | | - is_pytorch_model_meta_dist_file, |
55 | | - create_model_metadata |
56 | | -) |
57 | | - |
58 | 57 | try: |
59 | 58 | from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX |
60 | 59 | except ImportError: |
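For context, the `try`/`except` guard above is the usual pattern for importing a private torch symbol: fall back to the documented string value when the import fails on a torch version that lacks it. A minimal sketch, assuming the standard `"_extra_state"` suffix (the actual `except` branch sits outside this hunk):

```python
# Guarded import of a private torch constant with an explicit fallback.
# "_extra_state" is the suffix torch appends for module extra state in
# state dicts; using it as the fallback here is an assumption, since the
# except branch is not shown in this diff.
try:
    from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
    _EXTRA_STATE_KEY_SUFFIX = "_extra_state"
```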
@@ -244,9 +243,19 @@ def save_sharded_model( |
244 | 243 | return |
245 | 244 | dist_id = self.tp_size * self.pp_rank + self.tp_rank |
246 | 245 | model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank) |
247 | | - save_dist_sharded_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, prefix=prefix, size_per_shard=size_per_shard, use_safetensors=use_safetensors, use_async=use_async, dist_id = dist_id, pinned_state_dicts = self.pinned_state_dicts) |
| 246 | + save_dist_sharded_model( |
| 247 | + model=model, |
| 248 | + model_metadata=model_metadata, |
| 249 | + checkpoint=checkpoint, |
| 250 | + prefix=prefix, |
| 251 | + size_per_shard=size_per_shard, |
| 252 | + use_safetensors=use_safetensors, |
| 253 | + use_async=use_async, |
| 254 | + dist_id=dist_id, |
| 255 | + pinned_state_dicts=self.pinned_state_dicts, |
| 256 | + ) |
248 | 257 | return |
249 | | - |
| 258 | + |
250 | 259 | model = model.unwrap() |
251 | 260 |
|
252 | 261 | if os.path.isfile(checkpoint): |
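As a quick sanity check of the `dist_id` scheme above: with `tp_size` tensor-parallel ranks per pipeline stage, `tp_size * pp_rank + tp_rank` gives each (pipeline, tensor) rank pair a unique shard id, so no two writers collide on a file. A toy enumeration (rank counts chosen for illustration):

```python
# Enumerate dist_id for a 2-stage pipeline with 4-way tensor parallelism.
tp_size, pp_size = 4, 2
for pp_rank in range(pp_size):
    for tp_rank in range(tp_size):
        dist_id = tp_size * pp_rank + tp_rank
        print(f"pp_rank={pp_rank} tp_rank={tp_rank} -> dist_id={dist_id}")
# Prints ids 0..7, one per (pp_rank, tp_rank) pair, with no repeats.
```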
@@ -394,9 +403,15 @@ def load_sharded_model( |
394 | 403 |
|
395 | 404 | if is_pytorch_model_meta_dist_file(checkpoint_index_file): |
396 | 405 | model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank) |
397 | | - load_dist_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint_index_file, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads) |
| 406 | + load_dist_model( |
| 407 | + model=model, |
| 408 | + model_metadata=model_metadata, |
| 409 | + checkpoint=checkpoint_index_file, |
| 410 | + low_cpu_mem_mode=low_cpu_mem_mode, |
| 411 | + num_threads=num_threads, |
| 412 | + ) |
398 | 413 | return |
399 | | - |
| 414 | + |
400 | 415 | model_before_wrapping = model # backup for model before wrapping |
401 | 416 | model = model.unwrap() |
402 | 417 |
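A hedged sketch of what a `low_cpu_mem_mode` load can look like (assumed semantics, not ColossalAI's implementation): stream tensors one at a time from a safetensors file instead of materializing the whole state dict, keeping peak host memory near a single tensor:

```python
import torch
from safetensors import safe_open

def stream_load(model: torch.nn.Module, path: str) -> None:
    """Copy tensors from `path` into `model` one at a time (illustrative)."""
    params = dict(model.named_parameters())
    with safe_open(path, framework="pt", device="cpu") as f:
        for name in f.keys():
            if name in params:
                with torch.no_grad():
                    # Only one checkpoint tensor is resident at a time,
                    # instead of a fully materialized state dict.
                    params[name].copy_(f.get_tensor(name))
```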
|
@@ -792,9 +807,17 @@ def save_unsharded_model( |
792 | 807 | if self.dp_rank != 0 and self.sp_rank != 0: |
793 | 808 | return |
794 | 809 | dist_id = self.tp_size * self.pp_rank + self.tp_rank |
795 | | - save_dist_unshard_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, use_safetensors=use_safetensors, use_async=use_async, dist_id = dist_id, pinned_state_dicts = self.pinned_state_dicts) |
| 810 | + save_dist_unshard_model( |
| 811 | + model=model, |
| 812 | + model_metadata=model_metadata, |
| 813 | + checkpoint=checkpoint, |
| 814 | + use_safetensors=use_safetensors, |
| 815 | + use_async=use_async, |
| 816 | + dist_id=dist_id, |
| 817 | + pinned_state_dicts=self.pinned_state_dicts, |
| 818 | + ) |
796 | 819 | return |
797 | | - |
| 820 | + |
798 | 821 | model = model.unwrap() |
799 | 822 | if self.dp_rank != 0: |
800 | 823 | return |
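The `pinned_state_dicts` cache threaded through both save paths exists because async checkpointing stages GPU tensors into page-locked host buffers, letting the device-to-host copy overlap with training. A minimal sketch of that idea, with an assumed cache layout (not ColossalAI's actual buffer management):

```python
import torch

def copy_to_pinned(state_dict: dict, cache: dict) -> dict:
    """Stage CUDA tensors into reusable pinned host buffers (illustrative)."""
    staged = {}
    for name, tensor in state_dict.items():
        buf = cache.get(name)
        if buf is None or buf.shape != tensor.shape:
            # Page-locked memory enables an asynchronous D2H copy.
            buf = torch.empty(tensor.shape, dtype=tensor.dtype, pin_memory=True)
            cache[name] = buf
        buf.copy_(tensor, non_blocking=True)
        staged[name] = buf
    return staged
```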
@@ -867,7 +890,13 @@ def load_unsharded_model( |
867 | 890 | for filename in os.listdir(checkpoint): |
868 | 891 | if is_pytorch_model_meta_dist_file(filename): |
869 | 892 | model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank) |
870 | | - load_dist_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads) |
| 893 | + load_dist_model( |
| 894 | + model=model, |
| 895 | + model_metadata=model_metadata, |
| 896 | + checkpoint=checkpoint, |
| 897 | + low_cpu_mem_mode=low_cpu_mem_mode, |
| 898 | + num_threads=num_threads, |
| 899 | + ) |
871 | 900 | return |
872 | 901 |
|
873 | 902 | strict = False |
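Both load paths now share the same dispatch shape: probe the checkpoint directory for a distributed-metadata marker and branch to the distributed loader, else fall through to the legacy path. A condensed sketch with assumed callables standing in for the real loaders:

```python
import os
from typing import Callable

def dispatch_load(
    checkpoint: str,
    is_dist_meta: Callable[[str], bool],
    load_dist: Callable[[str], None],
    load_legacy: Callable[[str], None],
) -> None:
    """Pick the distributed loader iff a metadata marker file is present."""
    if os.path.isdir(checkpoint):
        for filename in os.listdir(checkpoint):
            if is_dist_meta(filename):
                load_dist(checkpoint)
                return
    load_legacy(checkpoint)
```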
@@ -1099,7 +1128,6 @@ def gather_from_sharded_optimizer_state( |
1099 | 1128 | dist.all_gather(gather_tensor, v, group=dp_group) |
1100 | 1129 | v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param) |
1101 | 1130 |
|
1102 | | - |
1103 | 1131 | # Then gather TP shards. |
1104 | 1132 | partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size) |
1105 | 1133 | if partition_dim is not None: |
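The DP gather above relies on ZeRO shards being flattened and padded to equal length: stacking the gathered shards, flattening, truncating to `param.numel()`, and reshaping recovers the full optimizer state. A single-process sanity check of that slicing (shapes are illustrative):

```python
import torch

param = torch.zeros(3, 5)  # 15 elements
# Two equal-length shards, as dist.all_gather would return them; the last
# element is padding added so both ranks hold 8 values each.
shards = [torch.arange(0.0, 8.0), torch.arange(8.0, 16.0)]
full = torch.stack(shards).view(-1)[: param.numel()].reshape_as(param)
assert full.shape == param.shape and full[-1, -1] == 14.0  # padding dropped
```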
|