@@ -1,17 +1,20 @@
 # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
 
-""" Helpers for defining sharding for optimizer states based on existing sharding for model parameters. """
+""" Helpers for defining sharding for optimizer states based on existing sharding
+for model parameters.
+"""
 
 import logging
 from copy import deepcopy
 from dataclasses import replace
-from itertools import chain
-from typing import Dict, Iterable, List, Tuple, Union
+from typing import Dict, Iterable, Tuple, Union
 
 logger = logging.getLogger(__name__)
 
 import torch
 
+from megatron.core.utils import to_local_if_dtensor
+
 from .dict_utils import nested_values
 from .mapping import (
     LocalNonpersistentObject,
@@ -24,8 +27,10 @@
 
 
 def get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter]) -> Dict[int, int]:
+    """Generate mapping from optimizer param to optimizer state id."""
     param_mappings = {}
     for i, param in enumerate(optim_params_iter):
+        param = to_local_if_dtensor(param)
         if id(param) not in param_mappings:
             param_mappings[id(param)] = i
     return param_mappings
@@ -37,7 +42,8 @@ def get_param_id_to_sharded_param_map(
     """Generate mapping from optimizer state ids to model sharded parameters.
 
     Args:
-        model_sharded_state_dict: sharded state dict with all model sharded tensors (can have any structure)
+        model_sharded_state_dict: sharded state dict with all model sharded tensors
+            (can have any structure)
         optim_params_iter: iterable which iterates over model parameters tracked by the optimizer.
             The iteration must be in the same order as in the optimizer parameters.
 
@@ -48,6 +54,9 @@ def get_param_id_to_sharded_param_map(
     model_sharded_state_dict, _ = extract_sharded_tensors_and_factories(model_sharded_state_dict)
     id_to_sharded_param_map = {}
     param_to_id_map = get_optim_param_to_id_map(optim_params_iter)
+    # If using PyTorch FSDP2 the values in model_sharded_state_dict would
+    # have been converted to local tensors during initialization.
+    # See the make_(tp)_sharded_tensor_for_checkpoint functions.
     for ten in nested_values(model_sharded_state_dict):
         if id(ten.data) in param_to_id_map:
             id_to_sharded_param_map[param_to_id_map[id(ten.data)]] = ten
@@ -76,12 +85,14 @@ def make_sharded_optimizer_tensor(
     Returns:
         Union[ShardedTensor, ShardedTensorFactory]: wrapped optimizer parameter
     """
+    optim_param = to_local_if_dtensor(optim_param)
     if isinstance(model_param, ShardedTensorFactory):
         return replace(model_param, key=f'{prefix}.{model_param.key}', data=optim_param)
 
-    assert (
-        tuple(optim_param.shape) == model_param.local_shape
-    ), f'Optimizer shape ({tuple(optim_param.shape)} does not match model shape ({model_param.local_shape})'
+    assert tuple(optim_param.shape) == model_param.local_shape, (
+        f'Optimizer shape ({tuple(optim_param.shape)} does not match model shape '
+        f'({model_param.local_shape})'
+    )
     sh_ten = replace(
         model_param, key=f'{prefix}.{model_param.key}', data=optim_param, dtype=optim_param.dtype
     )
@@ -102,9 +113,11 @@ def optim_state_to_sharding_state(
 
     Args:
         optim_state_dict (StateDict): optimizer state dict with
-            state parameters under `state` key and group hyperparameters under `param_groups` -> `params` key.
-        id_to_sharded_param_map (Dict[int, ShardedTensor]): mapping from optimizer param ids to model sharded tensors.
-            Can be generated with `get_param_id_to_sharded_param_map` function
+            state parameters under `state` key and group hyperparameters under
+            `param_groups` -> `params` key.
+        id_to_sharded_param_map (Dict[int, ShardedTensor]): mapping from optimizer param ids
+            to model sharded tensors. Can be generated with `get_param_id_to_sharded_param_map`
+            function.
         exclude_keys (Tuple[str]): optimizer state keys to exclude from the final state dict.
 
     Returns:
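
For context, a minimal usage sketch of how these helpers chain together when building an optimizer sharded state dict. The `model` and `optimizer` objects here are assumptions for illustration, not part of this diff; with this change, FSDP2 `DTensor` params are reduced to their local shards via `to_local_if_dtensor` inside the helpers:

```python
# Sharded tensors for all model params (the structure can be arbitrary).
# `model.sharded_state_dict()` is assumed to exist on the module.
model_sharded_sd = model.sharded_state_dict()

# Map optimizer param ids to the model's ShardedTensors. The iteration
# order must match the order of params in the optimizer param groups.
id_to_sharded_param_map = get_param_id_to_sharded_param_map(
    model_sharded_sd,
    (p for group in optimizer.param_groups for p in group['params']),
)

# Convert the optimizer state dict in place: plain state tensors
# (e.g. Adam's exp_avg) are wrapped as ShardedTensors so they can be
# saved and loaded with the distributed checkpointing machinery.
optim_sd = optimizer.state_dict()
optim_state_to_sharding_state(optim_sd, id_to_sharded_param_map)
```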