Skip to content

Commit e0184bf

Browse files
galrotem authored and facebook-github-bot committed
extract state dict fetching logic to a separate method
Reviewed By: JKSenthil
Differential Revision: D54447843
fbshipit-source-id: 45281c6d5ccd3803652fb953a4664f9622bcaea5
1 parent 628cb97 commit e0184bf

File tree

1 file changed

+24
-7
lines changed

1 file changed

+24
-7
lines changed

torchsnapshot/snapshot.py

Lines changed: 24 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -16,11 +16,12 @@
1616
import random
1717
import sys
1818
import traceback
19+
from asyncio import AbstractEventLoop
1920

2021
from collections import defaultdict
2122
from datetime import timedelta
2223
from threading import Thread
23-
from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, TypeVar
24+
from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, TypeVar, Union
2425

2526
import torch
2627
import torch.distributed as dist
@@ -722,6 +723,27 @@ def _load_stateful( # noqa
722723
tensor_requests=list(flattened.keys()),
723724
)
724725

726+
# Build the originally saved state dict and use it to restore the stateful
727+
state_dict = self._get_state_dict_for_manifest(
728+
stateful_key, manifest, flattened, pg, storage, event_loop
729+
)
730+
731+
if isinstance(stateful, torch.nn.Module):
732+
stateful.load_state_dict(state_dict, strict=strict)
733+
else:
734+
stateful.load_state_dict(state_dict)
735+
736+
@staticmethod
737+
# pyre-fixme: inflate returns Dict[Any,Any]
738+
# Missing return annotation [3]: Return type must be specified as type that does not contain `Any`
739+
def _get_state_dict_for_manifest(
740+
stateful_key: str,
741+
manifest: Manifest,
742+
flattened: Dict[str, Union[torch.Tensor, ShardedTensor, DTensor]],
743+
pg: PGWrapper,
744+
storage: StoragePlugin,
745+
event_loop: AbstractEventLoop,
746+
) -> Dict[Any, Any]:
725747
container_entries = {}
726748
read_reqs: List[ReadReq] = []
727749
futs = {}
@@ -754,17 +776,12 @@ def _load_stateful( # noqa
754776
)
755777

756778
# Build the originally saved state dict and use it to restore the stateful
757-
state_dict = inflate(
779+
return inflate(
758780
manifest=container_entries,
759781
flattened={k: fut.obj for k, fut in futs.items()},
760782
prefix=stateful_key,
761783
)
762784

763-
if isinstance(stateful, torch.nn.Module):
764-
stateful.load_state_dict(state_dict, strict=strict)
765-
else:
766-
stateful.load_state_dict(state_dict)
767-
768785
@staticmethod
769786
def _write_snapshot_metadata(
770787
snapshot_metadata: SnapshotMetadata,

0 commit comments

Comments
 (0)