
Commit 1647f78

Add some more tests for nGPT (#113)
1 parent 37e0e88 commit 1647f78

1 file changed (+52 lines, -16 lines)

src/test/nn/transformer/model_test.py

Lines changed: 52 additions & 16 deletions
@@ -3,10 +3,13 @@
 import pytest
 import torch
 import torch.nn as nn
+from torch.distributed._tensor import DTensor
 
+from olmo_core.distributed.parallel import DataParallelType
 from olmo_core.nn.layer_norm import LayerNorm
-from olmo_core.nn.transformer import TransformerConfig
+from olmo_core.nn.transformer import TransformerConfig, TransformerDataParallelConfig
 
+from ...distributed.utils import requires_multi_gpu, run_distributed_test
 from ...utils import GPU_MARKS
 
 log = logging.getLogger(__name__)
@@ -19,7 +22,7 @@
         pytest.param("cpu", "cpu", id="cpu->cpu"),
     ],
 )
-def test_small_llama2_config_builder(init_device, device):
+def test_small_llama2_builder_config(init_device, device):
     config = TransformerConfig.llama2_271M(vocab_size=50257)
     log.info(config)
     model = config.build(init_device=init_device, device=torch.device(device))
@@ -44,14 +47,35 @@ def test_small_llama2_config_builder(init_device, device):
     assert model.blocks[-1].block_idx == len(model.blocks) - 1
 
 
+def check_ngpt_matrices(model: nn.Module, d_model: int):
+    for name, module in model.named_modules():
+        if isinstance(module, nn.Linear):
+            assert module.bias is None
+
+            w = module.weight
+            if isinstance(w, DTensor):
+                w = w.full_tensor()
+
+            if w.shape[1] == d_model and "attention.w_out" not in name:
+                pass
+            elif w.shape[0] == d_model:
+                w = w.transpose(0, 1)
+            else:
+                continue
+
+            log.info(f"Checking norm for '{name}'")
+            norm = torch.linalg.vector_norm(w, dim=1)
+            torch.testing.assert_close(norm, torch.ones_like(norm))
+
+
 @pytest.mark.parametrize(
     "init_device, device",
     [
         pytest.param("cpu", "cuda", id="cpu->cuda", marks=GPU_MARKS),
         pytest.param("cpu", "cpu", id="cpu->cpu"),
     ],
 )
-def test_small_ngpt_config_builder(init_device, device):
+def test_small_ngpt_builder_config(init_device, device):
     config = TransformerConfig.ngpt_271M(vocab_size=50257)
     model = config.build(init_device=init_device, device=torch.device(device))
 
@@ -67,17 +91,29 @@ def test_small_ngpt_config_builder(init_device, device):
     assert model.blocks[-1].block_idx == len(model.blocks) - 1
 
     # Make sure all weights are normalized in the embedding dimension.
-    for name, module in model.named_modules():
-        if isinstance(module, nn.Linear):
-            assert module.bias is None
-            w = module.weight
-            if w.shape[1] == config.d_model and "attention.w_out" not in name:
-                pass
-            elif w.shape[0] == config.d_model:
-                w = w.transpose(0, 1)
-            else:
-                continue
+    check_ngpt_matrices(model, config.d_model)
 
-            log.info(f"Checking norm for '{name}'")
-            norm = torch.linalg.vector_norm(w, dim=1)
-            torch.testing.assert_close(norm, torch.ones_like(norm))
+
+def run_ngpt_with_fsdp2():
+    config = TransformerConfig.ngpt_271M(
+        vocab_size=50257,
+        use_flash=False,
+        dp_config=TransformerDataParallelConfig(name=DataParallelType.fsdp),
+    )
+    model = config.build(init_device="meta", max_seq_len=1024)
+    optim = torch.optim.Adam(model.parameters())
+
+    # Take an optimizer step.
+    model(input_ids=torch.randint(0, 50257, (2, 128))).sum().backward()
+    optim.step()
+
+    # Re-normalize weights.
+    model.normalize_matrices()
+
+    # Check that the re-normalization was successful.
+    check_ngpt_matrices(model, config.d_model)
+
+
+@requires_multi_gpu
+def test_ngpt_with_fsdp2():
+    run_distributed_test(run_ngpt_with_fsdp2, backend="nccl", start_method="spawn")
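For context on what the new check_ngpt_matrices helper verifies: every nn.Linear weight in the model (with DTensor shards first gathered via full_tensor()) must have unit L2 norm along the embedding dimension. The snippet below is a minimal standalone sketch of that invariant on a toy projection layer; the layer and its sizes are hypothetical and not part of this diff.

import torch
import torch.nn as nn

# Toy layer standing in for one of the model's projections (hypothetical sizes).
d_model = 8
proj = nn.Linear(d_model, 4 * d_model, bias=False)

# Normalize each row along the embedding (d_model) dimension, as nGPT expects.
with torch.no_grad():
    proj.weight.div_(torch.linalg.vector_norm(proj.weight, dim=1, keepdim=True))

# The same assertion the test makes: every row norm is 1.
norm = torch.linalg.vector_norm(proj.weight, dim=1)
torch.testing.assert_close(norm, torch.ones_like(norm))

The helper in the diff additionally transposes matrices whose d_model axis is dim 0 (and routes attention.w_out through that branch), so the norm is always taken over a d_model-sized axis.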

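The new test_ngpt_with_fsdp2 case relies on the repo's run_distributed_test helper to spawn one process per GPU and call run_ngpt_with_fsdp2 in each rank. As a rough sketch of that general pattern only (this is not olmo_core's implementation; the helper names below are made up):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def _worker(rank: int, world_size: int, func):
    # Minimal single-node rendezvous; address and port are placeholder values.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    try:
        func()
    finally:
        dist.destroy_process_group()

def run_distributed(func, world_size: int = 2):
    # Spawn one subprocess per rank, mirroring start_method="spawn" in the test.
    mp.spawn(_worker, args=(world_size, func), nprocs=world_size, join=True)

The real helper presumably handles more (port selection, alternate backends, error propagation), which is why the test delegates to it rather than spawning processes itself.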