+"""
+A low-overhead, fast, distributed checkpointing module with a unified API for saving and
+loading both local and remote checkpoints. Built on top of `safetensors <https://huggingface.co/docs/safetensors/>`_
+and inspired by :mod:`torch.distributed.checkpoint`, but better suited for handling distributed models and
+optimizer state without unnecessary distributed communication and GPU allocations.
+
+Features
+--------
+
+- Sharded distributed models, such as PyTorch's :class:`~torch.distributed.fsdp.FullyShardedDataParallel`,
+  are supported out-of-the-box.
+- Utilizes `safetensors <https://huggingface.co/docs/safetensors/>`_ under the hood for fast, efficient, and
+  safe serialization/deserialization.
+- Save with one distributed topology, seamlessly load with a different one. For example,
+  with FSDP you can save/load checkpoints with different world sizes or wrapping strategies.
+- Save/load directly to/from a remote object store like S3 or GCS. When loading from a remote object store,
+  each rank only downloads the fraction of the data it needs for its local (potentially sharded) tensors.
+- Checkpoints are always loaded in-place and one tensor at a time to avoid unnecessary allocations.
+  This results in virtually no additional memory overhead.
+
+Overview
+--------
+
+Use :func:`save_model_and_optim_state()` to write a checkpoint with your model and optimizer's state, then
+use :func:`load_model_and_optim_state()` to load the checkpoint in-place. You can also generate unsharded, full
+state dictionaries from a checkpoint with :func:`unshard_model_state()` and :func:`unshard_optim_state()`.
+
+API Reference
+-------------
+"""
+
from __future__ import annotations

import json
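
A minimal usage sketch of the save/load API summarized in the Overview, assuming the module is
importable as ``my_package.distributed.checkpoint`` (a placeholder path)::

    import torch

    from my_package.distributed.checkpoint import (  # placeholder import path
        load_model_and_optim_state,
        save_model_and_optim_state,
    )

    model = torch.nn.Linear(8, 8)
    optim = torch.optim.AdamW(model.parameters())

    # Write the checkpoint; `dir` can also be a remote URL such as "s3://bucket/run1/step100".
    save_model_and_optim_state("/tmp/run1/step100", model, optim, save_overwrite=True)

    # ...later, possibly with a different distributed topology, load it back in-place.
    load_model_and_optim_state("/tmp/run1/step100", model, optim)
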
@@ -51,6 +82,26 @@ def save_model_and_optim_state(
    a different distributed topology through :func:`load_model_and_optim_state()`.

    Returns all of the files created by the current rank.
+
+    .. seealso::
+        - :func:`load_model_and_optim_state()`
+        - :func:`unshard_model_state()`
+        - :func:`unshard_optim_state()`
+
+    .. tip::
+        With :class:`~torch.distributed.fsdp.FullyShardedDataParallel` models it's not necessary
+        to set the state dict type before calling this (or :func:`load_model_and_optim_state()`) via
+        :meth:`~torch.distributed.fsdp.FullyShardedDataParallel.state_dict_type()` or other methods.
+        In fact, those settings will always be ignored.
+
+    .. attention::
+        At the moment :class:`~torch.distributed.fsdp.FullyShardedDataParallel` models must have
+        ``use_orig_params=True``.
+
+    :param dir: Path/URL to save to.
+    :param model: The model to save state from.
+    :param optim: The optimizer to save state from.
+    :param save_overwrite: Overwrite existing files.
    """
    dir = str(dir).rstrip("/")

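
The ``use_orig_params`` requirement called out in the attention block matters in practice. A sketch
of saving from an FSDP-wrapped model, assuming a CUDA device, an initialized process group (e.g.
launched with ``torchrun``), and the same placeholder import path as above::

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    from my_package.distributed.checkpoint import save_model_and_optim_state

    # FSDP models must be constructed with use_orig_params=True (see the attention block above).
    model = FSDP(torch.nn.Linear(8, 8).cuda(), use_orig_params=True)
    optim = torch.optim.AdamW(model.parameters())

    # No need to configure FSDP's state dict type first; those settings are ignored anyway.
    # Returns the files written by this rank.
    files = save_model_and_optim_state("s3://my-bucket/run1/step100", model, optim, save_overwrite=True)
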
@@ -77,10 +128,22 @@ def load_model_and_optim_state(
    """
    Load model and optimizer state in-place from a checkpoint saved via :func:`save_model_and_optim_state()`.
    This method is agnostic to the distributed topology in that it can load checkpoints saved with a different
-    distributed topology.
+    distributed topology (e.g. FSDP vs DDP, or FSDP with a different world size).
+
+    .. seealso::
+        - :func:`save_model_and_optim_state()`
+        - :func:`unshard_model_state()`
+        - :func:`unshard_optim_state()`
+
+    .. tip::
+        Internally this function handles calling :meth:`torch.nn.Module.load_state_dict()` and
+        :meth:`torch.optim.Optimizer.load_state_dict()` for you, hence the return type is ``None``.
+
+    :param dir: Path/URL to the checkpoint saved via :func:`save_model_and_optim_state()`.
+    :param model: The model to load the state into.
+    :param optim: The optimizer to load the state into.
    """
    dir = str(dir).rstrip("/")
-
    checkpointer = Checkpointer()

    # Get model state in-place.
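
As a sketch of the topology-agnostic loading described above (same placeholder import path): a
checkpoint written from an FSDP run can be loaded back into a differently wrapped, or unwrapped,
copy of the same architecture, with the ``load_state_dict()`` calls handled internally::

    import torch

    from my_package.distributed.checkpoint import load_model_and_optim_state

    # Rebuild the same architecture, here without FSDP (or with a different world size / wrapping).
    model = torch.nn.Linear(8, 8)
    optim = torch.optim.AdamW(model.parameters())

    # Loads in-place; there is no return value and no manual load_state_dict() needed.
    load_model_and_optim_state("s3://my-bucket/run1/step100", model, optim)
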
@@ -113,6 +176,51 @@ def load_model_and_optim_state(
    del optim_state_to_load


+@torch.no_grad()
+def unshard_model_state(
+    dir: PathOrStr, device: Optional[torch.device] = None, rank0_only: bool = False, no_dist: bool = False
+) -> Dict[str, torch.Tensor]:
+    """
+    Unshard model state saved via :func:`save_model_and_optim_state()`.
+
+    .. seealso::
+        - :func:`unshard_optim_state()`
+
+    :param dir: Local or remote checkpoint directory.
+    :param device: Device to load the checkpoint onto. Defaults to CPU.
+    :param rank0_only: Set to true if you only want to load the unsharded state to rank 0 in a distributed
+        context. Other ranks will receive an empty dictionary.
+    :param no_dist: Set to true to avoid any distributed communication whatsoever.
+    """
+    dir = str(dir).rstrip("/")
+    checkpointer = Checkpointer()
+    return checkpointer.unshard(f"{dir}/model", device=device, rank0_only=rank0_only, no_dist=no_dist)
+
+
+@torch.no_grad()
+def unshard_optim_state(
+    dir: PathOrStr, device: Optional[torch.device] = None, rank0_only: bool = False, no_dist: bool = False
+) -> OptimStateDict:
+    """
+    Unshard optimizer state saved via :func:`save_model_and_optim_state()`.
+
+    .. seealso::
+        - :func:`unshard_model_state()`
+
+    :param dir: Local or remote checkpoint directory.
+    :param device: Device to load the checkpoint onto. Defaults to CPU.
+    :param rank0_only: Set to true if you only want to load the unsharded state to rank 0 in a distributed
+        context. Other ranks will receive an empty dictionary.
+    :param no_dist: Set to true to avoid any distributed communication whatsoever.
+    """
+    dir = str(dir).rstrip("/")
+    checkpointer = Checkpointer()
+    flat_optim_state = checkpointer.unshard(f"{dir}/optim", device=device, rank0_only=rank0_only, no_dist=no_dist)
+    optim_state = _unflatten_optimizer_state(flat_optim_state)
+    del flat_optim_state
+    return optim_state
+
+
class Checkpointer:
    """
    A distributed checkpointer for saving and loading *non-nested* state dictionaries,
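
For the two unshard helpers added above, a short sketch (same placeholder import path) of
materializing full state dicts from a sharded checkpoint, for example to export or inspect it::

    import torch

    from my_package.distributed.checkpoint import unshard_model_state, unshard_optim_state

    ckpt_dir = "s3://my-bucket/run1/step100"

    # In a distributed context, rank0_only=True populates the dicts on rank 0 only;
    # every other rank receives an empty dictionary.
    model_state = unshard_model_state(ckpt_dir, device=torch.device("cpu"), rank0_only=True)
    optim_state = unshard_optim_state(ckpt_dir, device=torch.device("cpu"), rank0_only=True)

    # model_state maps parameter names to full (unsharded) tensors.
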
@@ -363,6 +471,12 @@ def unshard(

        Alternatively, setting ``no_dist=True`` will return a full state dict from whatever process
        calls this.
+
+        :param dir: Local or remote checkpoint directory.
+        :param device: Device to load the checkpoint onto. Defaults to CPU.
+        :param rank0_only: Set to true if you only want to load the unsharded state to rank 0 in a distributed
+            context. Other ranks will receive an empty dictionary.
+        :param no_dist: Set to true to avoid any distributed communication whatsoever.
        """
        dir = self._normalize_dir(dir)

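
A sketch of how ``no_dist=True`` could be used with :class:`Checkpointer` directly, for example to
read a checkpoint from a plain single-process script. The ``/model`` suffix mirrors what
:func:`unshard_model_state()` passes internally, and the URL and import path are placeholders::

    import torch

    from my_package.distributed.checkpoint import Checkpointer

    checkpointer = Checkpointer()

    # no_dist=True skips all collective communication, so no process group is required.
    full_model_state = checkpointer.unshard(
        "s3://my-bucket/run1/step100/model",
        device=torch.device("cpu"),
        no_dist=True,
    )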