|
16 | 16 | import random |
17 | 17 | import sys |
18 | 18 | import traceback |
| 19 | +from asyncio import AbstractEventLoop |
19 | 20 |
|
20 | 21 | from collections import defaultdict |
21 | 22 | from datetime import timedelta |
22 | 23 | from threading import Thread |
23 | | -from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, TypeVar |
| 24 | +from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, TypeVar, Union |
24 | 25 |
|
25 | 26 | import torch |
26 | 27 | import torch.distributed as dist |
@@ -722,6 +723,27 @@ def _load_stateful( # noqa |
722 | 723 | tensor_requests=list(flattened.keys()), |
723 | 724 | ) |
724 | 725 |
|
| 726 | + # Build the originally saved state dict and use it to restore the stateful |
| 727 | + state_dict = self._get_state_dict_for_manifest( |
| 728 | + stateful_key, manifest, flattened, pg, storage, event_loop |
| 729 | + ) |
| 730 | + |
| 731 | + if isinstance(stateful, torch.nn.Module): |
| 732 | + stateful.load_state_dict(state_dict, strict=strict) |
| 733 | + else: |
| 734 | + stateful.load_state_dict(state_dict) |
| 735 | + |
| 736 | + @staticmethod |
| 737 | + # pyre-fixme: inflate returns Dict[Any,Any] |
| 738 | + # Missing return annotation [3]: Return type must be specified as type that does not contain `Any` |
| 739 | + def _get_state_dict_for_manifest( |
| 740 | + stateful_key: str, |
| 741 | + manifest: Manifest, |
| 742 | + flattened: Dict[str, Union[torch.Tensor, ShardedTensor, DTensor]], |
| 743 | + pg: PGWrapper, |
| 744 | + storage: StoragePlugin, |
| 745 | + event_loop: AbstractEventLoop, |
| 746 | + ) -> Dict[Any, Any]: |
725 | 747 | container_entries = {} |
726 | 748 | read_reqs: List[ReadReq] = [] |
727 | 749 | futs = {} |
@@ -754,17 +776,12 @@ def _load_stateful( # noqa |
754 | 776 | ) |
755 | 777 |
|
756 | 778 | # Build the originally saved state dict and use it to restore the stateful |
757 | | - state_dict = inflate( |
| 779 | + return inflate( |
758 | 780 | manifest=container_entries, |
759 | 781 | flattened={k: fut.obj for k, fut in futs.items()}, |
760 | 782 | prefix=stateful_key, |
761 | 783 | ) |
762 | 784 |
|
763 | | - if isinstance(stateful, torch.nn.Module): |
764 | | - stateful.load_state_dict(state_dict, strict=strict) |
765 | | - else: |
766 | | - stateful.load_state_dict(state_dict) |
767 | | - |
768 | 785 | @staticmethod |
769 | 786 | def _write_snapshot_metadata( |
770 | 787 | snapshot_metadata: SnapshotMetadata, |
|
0 commit comments