Skip to content

Commit 1ea4d0b

Browse files
committed
improve docs
1 parent 98ca9bf commit 1ea4d0b

File tree

10 files changed

+217
-13
lines changed

10 files changed

+217
-13
lines changed

docs/source/conf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@
6262
"python": ("https://docs.python.org/3", None),
6363
"docker": ("https://docker-py.readthedocs.io/en/stable/", None),
6464
"requests": ("https://requests.readthedocs.io/en/stable/", None),
65+
"torch": ("https://pytorch.org/docs/stable/", None),
66+
"safetensors": ("https://huggingface.co/docs/safetensors/main/en/", None),
6567
}
6668

6769
# By default, sort documented members by type within classes and modules.
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
distributed.checkpoint
2-
======================
1+
``distributed.checkpoint``
2+
==========================
33

44
.. automodule:: olmo_core.distributed.checkpoint
5-
:members: save_model_and_optim_state, load_model_and_optim_state, Checkpointer, StorageMetadata, TensorStorageMetadata
5+
:members: save_model_and_optim_state, load_model_and_optim_state, unshard_model_state, unshard_optim_state, Checkpointer, StorageMetadata, TensorStorageMetadata
66
:member-order: bysource

docs/source/exceptions.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
``exceptions``
2+
==============
3+
4+
.. automodule:: olmo_core.exceptions
5+
:members:
6+
:member-order: bysource

docs/source/index.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
:maxdepth: 2
1414
:caption: API Reference
1515

16+
exceptions.rst
17+
utils.rst
18+
io.rst
1619
distributed/checkpoint.rst
1720

1821
.. toctree::

docs/source/io.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
``io``
2+
======
3+
4+
.. automodule:: olmo_core.io
5+
:members:
6+
:member-order: bysource

docs/source/utils.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
``utils``
2+
=========
3+
4+
.. automodule:: olmo_core.utils
5+
:members:
6+
:member-order: bysource

src/olmo_core/distributed/checkpoint.py

Lines changed: 116 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,34 @@
1+
"""
2+
A low-overhead, fast, distributed checkpointing module with a unified API for saving and
3+
loading both local and remote checkpoints. Built on top of `safetensors <https://huggingface.co/docs/safetensors/>`_
4+
and inspired by :mod:`torch.distributed.checkpoint`, but better suited for handling distributed models and
5+
optimizer state without unnecessary distributed communication and GPU allocations.
6+
7+
Features
8+
--------
9+
10+
- Sharded distributed models, such as PyTorch's :class:`~torch.distributed.fsdp.FullyShardedDataParallel`
11+
are supported out-of-the-box.
12+
- Utilizes `safetensors <https://huggingface.co/docs/safetensors/>`_ under the hood for fast, efficient, and
13+
safe serialization/deserialization.
14+
- Save with one distributed topology, seamlessly load with a different one. For example,
15+
with FSDP you can save/load checkpoints with different world sizes or wrapping strategies.
16+
- Save/load directly to/from a remote object store like S3 or GCS. When loading from a remote object store each
17+
rank only downloads the fraction of the data it needs for its local (potentially sharded) tensors.
18+
- Checkpoints are always loaded in-place and one tensor at a time to avoid unnecessary allocations.
19+
This results in virtually no additional memory overhead.
20+
21+
Overview
22+
--------
23+
24+
Use :func:`save_model_and_optim_state()` to write a checkpoint with your model and optimizer's state, then
25+
use :func:`load_model_and_optim_state()` to load the checkpoint in-place. You can also generate unsharded, full
26+
state dictionaries from a checkpoint with :func:`unshard_model_state()` and :func:`unshard_optim_state()`.
27+
28+
API Reference
29+
-------------
30+
"""
31+
132
from __future__ import annotations
233

334
import json
@@ -51,6 +82,26 @@ def save_model_and_optim_state(
5182
a different distributed topology through :func:`load_model_and_optim_state()`.
5283
5384
Returns all of the files created by the current rank.
85+
86+
.. seealso::
87+
- :func:`load_model_and_optim_state()`
88+
- :func:`unshard_model_state()`
89+
- :func:`unshard_optim_state()`
90+
91+
.. tip::
92+
With :class:`~torch.distributed.fsdp.FullyShardedDataParallel` models it's not necessary
93+
to set the state dict type before calling this (or :func:`load_model_and_optim_state()`) via
94+
:meth:`~torch.distributed.fsdp.FullyShardedDataParallel.state_dict_type()` or other methods.
95+
In fact those settings will always be ignored.
96+
97+
.. attention::
98+
At the moment :class:`~torch.distributed.fsdp.FullyShardedDataParallel` models must have
99+
``use_orig_params=True``.
100+
101+
:param dir: Path/URL to save to.
102+
:param model: The model to save state from.
103+
:param optim: The optimizer to save state from.
104+
:param save_overwrite: Overwrite existing files.
54105
"""
55106
dir = str(dir).rstrip("/")
56107

@@ -77,10 +128,22 @@ def load_model_and_optim_state(
77128
"""
78129
Load model and optimizer state in-place from a checkpoint saved via :func:`save_model_and_optim_state()`.
79130
This method is agnostic to the distributed topology in that it can load checkpoints saved with a different
80-
distributed topology.
131+
distributed topology (e.g. FSDP vs DDP, or FSDP with a different world size).
132+
133+
.. seealso::
134+
- :func:`save_model_and_optim_state()`
135+
- :func:`unshard_model_state()`
136+
- :func:`unshard_optim_state()`
137+
138+
.. tip::
139+
Internally this function handles calling :meth:`torch.nn.Module.load_state_dict()` and
140+
:meth:`torch.optim.Optimizer.load_state_dict()` for you, hence the return type is ``None``.
141+
142+
:param dir: Path/URL to the checkpoint saved via :func:`save_model_and_optim_state()`.
143+
:param model: The model to load the state into.
144+
:param optim: The optimizer to load the state into.
81145
"""
82146
dir = str(dir).rstrip("/")
83-
84147
checkpointer = Checkpointer()
85148

86149
# Get model state in-place.
@@ -113,6 +176,51 @@ def load_model_and_optim_state(
113176
del optim_state_to_load
114177

115178

179+
@torch.no_grad()
180+
def unshard_model_state(
181+
dir: PathOrStr, device: Optional[torch.device] = None, rank0_only: bool = False, no_dist: bool = False
182+
) -> Dict[str, torch.Tensor]:
183+
"""
184+
Unshard model state saved via :func:`save_model_and_optim_state()`.
185+
186+
.. seealso::
187+
- :func:`unshard_optim_state()`
188+
189+
:param dir: Local or remote checkpoint directory.
190+
:param device: Device to load the checkpoint onto. Defaults to CPU.
191+
:param rank0_only: Set to true if you only want to load the unsharded state to rank 0 in a distributed
192+
context. Other ranks will receive an empty dictionary.
193+
:param no_dist: Set to true to avoid any distributed communication whatsoever.
194+
"""
195+
dir = str(dir).rstrip("/")
196+
checkpointer = Checkpointer()
197+
return checkpointer.unshard(f"{dir}/model", device=device, rank0_only=rank0_only, no_dist=no_dist)
198+
199+
200+
@torch.no_grad()
201+
def unshard_optim_state(
202+
dir: PathOrStr, device: Optional[torch.device] = None, rank0_only: bool = False, no_dist: bool = False
203+
) -> OptimStateDict:
204+
"""
205+
Unshard optimizer state saved via :func:`save_model_and_optim_state()`.
206+
207+
.. seealso::
208+
- :func:`unshard_model_state()`
209+
210+
:param dir: Local or remote checkpoint directory.
211+
:param device: Device to load the checkpoint onto. Defaults to CPU.
212+
:param rank0_only: Set to true if you only want to load the unsharded state to rank 0 in a distributed
213+
context. Other ranks will receive an empty dictionary.
214+
:param no_dist: Set to true to avoid any distributed communication whatsoever.
215+
"""
216+
dir = str(dir).rstrip("/")
217+
checkpointer = Checkpointer()
218+
flat_optim_state = checkpointer.unshard(f"{dir}/optim", device=device, rank0_only=rank0_only, no_dist=no_dist)
219+
optim_state = _unflatten_optimizer_state(flat_optim_state)
220+
del flat_optim_state
221+
return optim_state
222+
223+
116224
class Checkpointer:
117225
"""
118226
A distributed checkpointer for saving and loading *non-nested* state dictionaries,
@@ -363,6 +471,12 @@ def unshard(
363471
364472
Alternatively, setting ``no_dist=True`` will return a full state dict from whatever process
365473
calls this.
474+
475+
:param dir: Local or remote checkpoint directory.
476+
:param device: Device to load the checkpoint onto. Defaults to CPU.
477+
:param rank0_only: Set to true if you only want to load the unsharded state to rank 0 in a distributed
478+
context. Other ranks will receive an empty dictionary.
479+
:param no_dist: Set to true to avoid any distributed communication whatsoever.
366480
"""
367481
dir = self._normalize_dir(dir)
368482

src/olmo_core/io.py

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,19 @@
2727

2828

2929
def is_url(path: PathOrStr) -> bool:
30+
"""
31+
Check if a path is a URL.
32+
33+
:param path: Path-like object to check.
34+
"""
3035
return re.match(r"[a-z0-9]+://.*", str(path)) is not None
3136

3237

3338
def file_size(path: PathOrStr) -> int:
3439
"""
3540
Get the size of a local or remote file in bytes.
41+
42+
:param path: Path/URL to the file.
3643
"""
3744
if is_url(path):
3845
from urllib.parse import urlparse
@@ -52,31 +59,44 @@ def file_size(path: PathOrStr) -> int:
5259
return os.stat(path).st_size
5360

5461

55-
def get_bytes_range(source: PathOrStr, bytes_start: int, num_bytes: int) -> bytes:
56-
if is_url(source):
62+
def get_bytes_range(path: PathOrStr, bytes_start: int, num_bytes: int) -> bytes:
63+
"""
64+
Get a range of bytes from a file.
65+
66+
:param source: Path/URL to the file.
67+
:param bytes_start: Byte offset to start at.
68+
:param num_bytes: Number of bytes to get.
69+
"""
70+
if is_url(path):
5771
from urllib.parse import urlparse
5872

59-
parsed = urlparse(str(source))
73+
parsed = urlparse(str(path))
6074
if parsed.scheme == "gs":
6175
return _gcs_get_bytes_range(parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes)
6276
elif parsed.scheme in ("s3", "r2"):
6377
return _s3_get_bytes_range(
6478
parsed.scheme, parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes
6579
)
6680
elif parsed.scheme in ("http", "https"):
67-
return _http_get_bytes_range(str(source), bytes_start, num_bytes)
81+
return _http_get_bytes_range(str(path), bytes_start, num_bytes)
6882
elif parsed.scheme == "file":
69-
return get_bytes_range(str(source).replace("file://", "", 1), bytes_start, num_bytes)
83+
return get_bytes_range(str(path).replace("file://", "", 1), bytes_start, num_bytes)
7084
else:
7185
raise NotImplementedError(f"file size not implemented for '{parsed.scheme}' files")
7286
else:
73-
with open(source, "rb") as f:
87+
with open(path, "rb") as f:
7488
f.seek(bytes_start)
7589
return f.read(num_bytes)
7690

7791

7892
def upload(source: PathOrStr, target: str, save_overwrite: bool = False):
79-
"""Upload source file to a target location on GCS or S3."""
93+
"""
94+
Upload source file to a target location on GCS or S3.
95+
96+
:param source: Path to the file to upload.
97+
:param target: Target URL to upload to.
98+
:param save_overwrite: Overwrite any existing file.
99+
"""
80100
from urllib.parse import urlparse
81101

82102
source = Path(source)
@@ -93,6 +113,11 @@ def upload(source: PathOrStr, target: str, save_overwrite: bool = False):
93113

94114

95115
def dir_is_empty(dir: PathOrStr) -> bool:
116+
"""
117+
Check if a local directory is empty. This also returns true if the directory does not exist.
118+
119+
:param dir: Path to the local directory.
120+
"""
96121
dir = Path(dir)
97122
if not dir.is_dir():
98123
return True
@@ -104,6 +129,11 @@ def dir_is_empty(dir: PathOrStr) -> bool:
104129

105130

106131
def file_exists(path: PathOrStr) -> bool:
132+
"""
133+
Check if a file exists.
134+
135+
:param path: Path/URL to a file.
136+
"""
107137
if is_url(path):
108138
from urllib.parse import urlparse
109139

@@ -133,6 +163,11 @@ def file_exists(path: PathOrStr) -> bool:
133163

134164

135165
def clear_directory(dir: PathOrStr):
166+
"""
167+
Clear out the contents of a local or remote directory. GCS (``gs://``) and S3 (``s3://``) URLs are supported.
168+
169+
:param dir: Path/URL to the directory.
170+
"""
136171
if is_url(dir):
137172
from urllib.parse import urlparse
138173

@@ -153,11 +188,21 @@ def clear_directory(dir: PathOrStr):
153188

154189

155190
def serialize_to_tensor(x: Any) -> torch.Tensor:
191+
"""
192+
Serialize an object to a byte tensor using pickle.
193+
194+
:param x: The pickeable object to serialize.
195+
"""
156196
serialized_bytes = pickle.dumps(x)
157197
return torch.frombuffer(bytearray(serialized_bytes), dtype=torch.uint8)
158198

159199

160200
def deserialize_from_tensor(data: torch.Tensor) -> Any:
201+
"""
202+
Deserialize an object from a byte tensor using pickle.
203+
204+
:param data: The byte tensor to deserialize.
205+
"""
161206
assert data.dtype == torch.uint8
162207
return pickle.loads(bytearray([int(x.item()) for x in data.flatten()]))
163208

src/olmo_core/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def seed_all(seed: int):
141141

142142
def get_grad_norm(params: Iterable[nn.Parameter], norm_type: float) -> torch.Tensor:
143143
"""
144-
Return the gradient norm of parameters ``param`` s, where the gradients are viewed as a single vector.
144+
Return the gradient norm of parameters, where the gradients are viewed as a single vector.
145145
146146
The returned norm is in FP32 even if parameters/gradients are in a low precision. This is because the downstream
147147
use of this return value is a reduction across ranks.

src/test/distributed/checkpoint_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
init_optimizer_state,
1717
load_model_and_optim_state,
1818
save_model_and_optim_state,
19+
unshard_model_state,
20+
unshard_optim_state,
1921
)
2022
from olmo_core.distributed.fsdp import FSDP
2123
from olmo_core.distributed.sharded_flat_parameter import ShardedFlatParameter
@@ -446,6 +448,26 @@ def run_save_and_load_fsdp_model(dir, model_factory, model_data_factory, pre_ini
446448
for p1, p2 in zip(fsdp_model.parameters(), fsdp_model2.parameters()):
447449
torch.testing.assert_close(optim.state[p1], optim2.state[p2])
448450

451+
# Check unsharding model state.
452+
full_model_state = unshard_model_state(dir)
453+
assert full_model_state
454+
for name, param in fsdp_model.named_parameters():
455+
assert isinstance(param, ShardedFlatParameter)
456+
assert name in full_model_state
457+
assert full_model_state[name].shape == param.unsharded_shape
458+
459+
# Check unsharding optim state.
460+
full_optim_state = unshard_optim_state(dir)
461+
assert full_optim_state
462+
assert len(full_optim_state["param_groups"]) == len(optim.param_groups)
463+
for i, param in enumerate(fsdp_model.parameters()):
464+
assert isinstance(param, ShardedFlatParameter)
465+
assert i in full_optim_state["state"]
466+
state = full_optim_state["state"][i]
467+
assert state["step"].numel() == 1
468+
assert state["exp_avg"].shape == param.unsharded_shape
469+
assert state["exp_avg_sq"].shape == param.unsharded_shape
470+
449471

450472
@pytest.mark.parametrize("backend", BACKENDS)
451473
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)