import pytest
import torch
import torch.nn as nn
+from torch.distributed._tensor import DTensor

+from olmo_core.distributed.parallel import DataParallelType
from olmo_core.nn.layer_norm import LayerNorm
-from olmo_core.nn.transformer import TransformerConfig
+from olmo_core.nn.transformer import TransformerConfig, TransformerDataParallelConfig

+from ...distributed.utils import requires_multi_gpu, run_distributed_test
from ...utils import GPU_MARKS

log = logging.getLogger(__name__)
@@ -19,7 +22,7 @@
        pytest.param("cpu", "cpu", id="cpu->cpu"),
    ],
)
-def test_small_llama2_config_builder(init_device, device):
+def test_small_llama2_builder_config(init_device, device):
    config = TransformerConfig.llama2_271M(vocab_size=50257)
    log.info(config)
    model = config.build(init_device=init_device, device=torch.device(device))
@@ -44,14 +47,35 @@ def test_small_llama2_config_builder(init_device, device):
    assert model.blocks[-1].block_idx == len(model.blocks) - 1


+def check_ngpt_matrices(model: nn.Module, d_model: int):
+    for name, module in model.named_modules():
+        if isinstance(module, nn.Linear):
+            assert module.bias is None
+
+            w = module.weight
+            if isinstance(w, DTensor):
+                w = w.full_tensor()
+
+            if w.shape[1] == d_model and "attention.w_out" not in name:
+                pass
+            elif w.shape[0] == d_model:
+                w = w.transpose(0, 1)
+            else:
+                continue
+
+            log.info(f"Checking norm for '{name}'")
+            norm = torch.linalg.vector_norm(w, dim=1)
+            torch.testing.assert_close(norm, torch.ones_like(norm))
+
+
@pytest.mark.parametrize(
    "init_device, device",
    [
        pytest.param("cpu", "cuda", id="cpu->cuda", marks=GPU_MARKS),
        pytest.param("cpu", "cpu", id="cpu->cpu"),
    ],
)
-def test_small_ngpt_config_builder(init_device, device):
+def test_small_ngpt_builder_config(init_device, device):
    config = TransformerConfig.ngpt_271M(vocab_size=50257)
    model = config.build(init_device=init_device, device=torch.device(device))
@@ -67,17 +91,29 @@ def test_small_ngpt_config_builder(init_device, device):
    assert model.blocks[-1].block_idx == len(model.blocks) - 1

    # Make sure all weights are normalized in the embedding dimension.
-    for name, module in model.named_modules():
-        if isinstance(module, nn.Linear):
-            assert module.bias is None
-            w = module.weight
-            if w.shape[1] == config.d_model and "attention.w_out" not in name:
-                pass
-            elif w.shape[0] == config.d_model:
-                w = w.transpose(0, 1)
-            else:
-                continue
+    check_ngpt_matrices(model, config.d_model)

-            log.info(f"Checking norm for '{name}'")
-            norm = torch.linalg.vector_norm(w, dim=1)
-            torch.testing.assert_close(norm, torch.ones_like(norm))
+
+def run_ngpt_with_fsdp2():
+    config = TransformerConfig.ngpt_271M(
+        vocab_size=50257,
+        use_flash=False,
+        dp_config=TransformerDataParallelConfig(name=DataParallelType.fsdp),
+    )
+    model = config.build(init_device="meta", max_seq_len=1024)
+    optim = torch.optim.Adam(model.parameters())
+
+    # Take an optimizer step.
+    model(input_ids=torch.randint(0, 50257, (2, 128))).sum().backward()
+    optim.step()
+
+    # Re-normalize weights.
+    model.normalize_matrices()
+
+    # Check that the re-normalization was successful.
+    check_ngpt_matrices(model, config.d_model)
+
+
+@requires_multi_gpu
+def test_ngpt_with_fsdp2():
+    run_distributed_test(run_ngpt_with_fsdp2, backend="nccl", start_method="spawn")