Skip to content

Commit 6bd9dc6

Browse files
galrotemfacebook-github-bot
authored andcommitted
add method to get state dict
Reviewed By: JKSenthil Differential Revision: D54447983 fbshipit-source-id: b458639aab4bdf2825865304eda6a06d70600393
1 parent e0184bf commit 6bd9dc6

File tree

3 files changed

+169
-0
lines changed

3 files changed

+169
-0
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import os
9+
from pathlib import Path
10+
11+
import pytest
12+
13+
import torch
14+
import torch.distributed as dist
15+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
16+
from torchsnapshot import Snapshot
17+
from torchsnapshot.test_utils import check_state_dict_eq, run_with_pet
18+
19+
20+
def _create_fsdp_model(
21+
seed: int,
22+
device: torch.device,
23+
) -> torch.nn.Module:
24+
torch.manual_seed(seed)
25+
model = torch.nn.Linear(32, 32)
26+
27+
fsdp_model = FSDP(
28+
module=model,
29+
device_id=device,
30+
)
31+
FSDP.set_state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT)
32+
return fsdp_model
33+
34+
35+
@pytest.mark.skipif(
36+
bool(not torch.cuda.is_available()), reason="The test requires GPUs to run."
37+
)
38+
@pytest.mark.skipif(
39+
bool(torch.cuda.device_count() < 2), reason="At least two GPUs are required."
40+
)
41+
@run_with_pet(nproc=2)
42+
def test_model_and_optim_fsdp(tmp_path: Path) -> None:
43+
dist.init_process_group(backend="nccl")
44+
local_rank = int(os.environ["LOCAL_RANK"])
45+
device = torch.device(f"cuda:{local_rank}")
46+
torch.cuda.set_device(device)
47+
48+
fsdp_model = _create_fsdp_model(17, device)
49+
50+
snapshot = Snapshot.take(
51+
path=str(tmp_path),
52+
app_state={"fsdp_model": fsdp_model},
53+
)
54+
state_dict_from_method = snapshot.get_state_dict_for_key("fsdp_model")
55+
FSDP.set_state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT)
56+
57+
full_state_dict = fsdp_model.state_dict()
58+
for k, v in full_state_dict.items():
59+
full_state_dict[k] = v.cpu()
60+
61+
assert check_state_dict_eq(full_state_dict, state_dict_from_method)

tests/test_state_dict.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import tempfile
9+
import unittest
10+
from typing import cast, Dict
11+
12+
import torch
13+
import torchsnapshot
14+
from torchsnapshot import Stateful
15+
16+
17+
class MyModule(torch.nn.Module):
18+
def __init__(self) -> None:
19+
super().__init__()
20+
self.foo = torch.nn.Parameter(torch.randn(20, 20))
21+
22+
23+
class MyStateful(Stateful):
24+
def __init__(self) -> None:
25+
self.foo = 1
26+
self.bar = "bar"
27+
28+
def state_dict(self) -> Dict[str, object]:
29+
return {"foo": self.foo, "bar": self.bar}
30+
31+
def load_state_dict(self, state_dict: Dict[str, object]) -> None:
32+
self.foo = cast(int, state_dict["foo"])
33+
self.bar = cast(str, state_dict["bar"])
34+
35+
36+
class StateDictTest(unittest.TestCase):
37+
def test_get_state_dict(self) -> None:
38+
my_module = MyModule()
39+
with tempfile.TemporaryDirectory() as path:
40+
torchsnapshot.Snapshot.take(
41+
path=path,
42+
app_state={"my_module": my_module},
43+
)
44+
snapshot = torchsnapshot.Snapshot(path)
45+
state_dict = snapshot.get_state_dict_for_key("my_module")
46+
self.assertTrue(torch.allclose(state_dict["foo"], my_module.foo))
47+
48+
def test_get_state_dict_with_invalid_key(self) -> None:
49+
my_module = MyModule()
50+
with tempfile.TemporaryDirectory() as path:
51+
torchsnapshot.Snapshot.take(
52+
path=path,
53+
app_state={"my_module": my_module},
54+
)
55+
snapshot = torchsnapshot.Snapshot(path)
56+
with self.assertRaisesRegex(
57+
AssertionError, "is absent in both manifest and flattened"
58+
):
59+
snapshot.get_state_dict_for_key("invalid_key")
60+
61+
def test_generic_stateful(self) -> None:
62+
my_stateful = MyStateful()
63+
my_stateful.foo = 2
64+
my_stateful.bar = "baz"
65+
with tempfile.TemporaryDirectory() as path:
66+
snapshot = torchsnapshot.Snapshot(path)
67+
snapshot.take(path, app_state={"my_stateful": my_stateful})
68+
state_dict = snapshot.get_state_dict_for_key("my_stateful")
69+
self.assertDictEqual(state_dict, my_stateful.state_dict())

torchsnapshot/snapshot.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,45 @@ def _validate_app_state(app_state: AppState) -> None:
680680
f"Expected Stateful in app_state for key {key}, got {value_type}."
681681
)
682682

683+
# pyre-fixme: inflate returns Dict[Any,Any]
684+
# Missing return annotation [3]: Return type must be specified as type that does not contain `Any`
685+
def get_state_dict_for_key(self, key: str) -> Dict[Any, Any]:
686+
"""
687+
Gets the state dict for a selected key in the snapshot.
688+
This is useful in case you want to get the state dict without loading it to the stateful.
689+
690+
Args:
691+
key (str): The key to get the state dict for. Assumes the key was stored as a topline
692+
key in the snapshot.
693+
694+
Returns:
695+
The state dict associated with the key.
696+
697+
Below is a usage example
698+
699+
.. code-block:: python
700+
701+
snapshot = Snapshot.take(path=..., app_state={"stateful_key": module})
702+
module_state_dict = snapshot.get_state_dict_for_key("stateful_key")
703+
"""
704+
event_loop = asyncio.new_event_loop()
705+
pg = PGWrapper(self.pg)
706+
707+
manifest, _ = get_manifest_for_rank(metadata=self.metadata, rank=pg.get_rank())
708+
709+
# filter out irrelevant entries from the manifest
710+
manifest = {k: v for k, v in manifest.items() if k.split("/")[0] == key}
711+
712+
storage = url_to_storage_plugin_in_event_loop(
713+
url_path=self.path,
714+
event_loop=event_loop,
715+
storage_options=self._storage_options,
716+
)
717+
718+
return self._get_state_dict_for_manifest(
719+
key, manifest, {}, pg, storage, event_loop
720+
)
721+
683722
def _load_stateful( # noqa
684723
self,
685724
stateful_key: str,

0 commit comments

Comments
 (0)