Commit 5af5f03

Add tests
1 parent 8604769 commit 5af5f03

4 files changed: +203, -0 lines changed

requirements.txt

Lines changed: 2 additions & 0 deletions
@@ -1,2 +1,4 @@
+einops==0.6.1
+pytest==7.4.0
 timm==0.9.2
 torch==2.0.1
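The pytest pin added here is what drives the new tests/ package. As a convenience sketch (not part of this commit), the suite can also be launched programmatically, which is equivalent to running pytest tests -q from the repository root:

import sys

import pytest

# Run the new test package programmatically; pytest.main accepts the same
# arguments as the pytest command line and returns an exit code.
if __name__ == "__main__":
    sys.exit(pytest.main(["tests", "-q"]))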

tests/__init__.py

Whitespace-only changes.

tests/test_softmoe.py

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
import random
from functools import partial

import pytest
import torch
from einops import rearrange
from timm.models.vision_transformer import Attention
from torch import nn

from soft_moe_pytorch import SoftMoELayerWrapper
from soft_moe_pytorch.soft_moe import softmax


def test_softmax():
    """
    Test between custom multi-dim softmax and naive impl.
    """
    for _ in range(20):
        # Single-dim
        x = torch.randn(2, 10, 10)
        y1 = softmax(x, dim=-1)
        y2 = torch.softmax(x, dim=-1)
        assert y1.size() == y2.size()
        assert torch.all(torch.isclose(y1, y2))

        # Multi-dim
        x = torch.randn(2, 10, 10, 10)
        y1 = softmax(x, dim=(2, 3))
        y2 = rearrange(
            x.flatten(start_dim=2).softmax(dim=-1), "b m (n p) -> b m n p", n=10
        )
        assert y1.size() == y2.size()
        assert torch.all(torch.isclose(y1, y2))


def test_soft_moe_layer_forward():
    """
    Test forward with different layers
    """
    for num_experts in [1, 4]:
        for slots_per_experts in [1, 2]:
            for dim in [16, 128]:
                f = SoftMoELayerWrapper(
                    dim=dim,
                    slots_per_expert=slots_per_experts,
                    num_experts=num_experts,
                    layer=nn.Linear,
                    in_features=dim,
                    out_features=32,
                )
                n = random.randint(1, 128)
                inp = torch.randn(1, n, dim)
                out = f(inp)
                assert list(out.shape) == [1, n, 32]
                assert not torch.isnan(out).any(), "Output included NaNs"

    for num_experts in [1, 4]:
        for slots_per_experts in [1, 2]:
            for dim in [16, 128]:
                f = SoftMoELayerWrapper(
                    dim=dim,
                    slots_per_expert=slots_per_experts,
                    num_experts=num_experts,
                    layer=partial(Attention, dim=dim),
                )
                n = random.randint(1, 128)
                inp = torch.randn(1, n, dim)
                out = f(inp)
                assert list(out.shape) == [1, n, dim]
                assert not torch.isnan(out).any(), "Output included NaNs"


def test_soft_moe_layer_input_wrong_features_channels():
    """
    Test for error when input has wrong feature dim
    """
    f = SoftMoELayerWrapper(
        dim=128,
        slots_per_expert=1,
        num_experts=16,
        layer=nn.Linear,
        in_features=128,
        out_features=32,
    )

    with pytest.raises(AssertionError):
        inp = torch.randn(1, 16, 64)
        f(inp)


def test_soft_moe_layer_input_wrong_dim():
    """
    Test for error when input is not 3-dim
    """
    f = SoftMoELayerWrapper(
        dim=128,
        slots_per_expert=1,
        num_experts=16,
        layer=nn.Linear,
        in_features=128,
        out_features=32,
    )

    with pytest.raises(AssertionError):
        inp = torch.randn(1, 16, 64, 64)
        f(inp)

    with pytest.raises(AssertionError):
        inp = torch.randn(1, 16)
        f(inp)
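For context on test_softmax above: the custom softmax is checked against a flatten-then-reshape baseline, so a minimal sketch of an equivalent multi-dim softmax (an illustrative reimplementation, not necessarily what soft_moe_pytorch.soft_moe.softmax does internally) looks like this:

import torch


def multi_dim_softmax(x: torch.Tensor, dim) -> torch.Tensor:
    # Softmax jointly over several dims: move them to the end, flatten,
    # apply a regular softmax, then restore the original layout.
    if isinstance(dim, int):
        return torch.softmax(x, dim=dim)
    dims = [d % x.ndim for d in dim]
    other = [d for d in range(x.ndim) if d not in dims]
    perm = other + dims
    flat = x.permute(*perm).reshape(*[x.shape[d] for d in other], -1)
    out = flat.softmax(dim=-1).reshape([x.shape[d] for d in perm])
    # Invert the permutation so the output has the input's shape.
    inv = [perm.index(i) for i in range(x.ndim)]
    return out.permute(*inv)

The layer tests also document the wrapper's calling convention: extra keyword arguments are forwarded to the expert layer, and the output keeps the token dimension of a (batch, tokens, dim) input. A standalone usage sketch based on those tests (dimensions chosen arbitrarily):

import torch
from torch import nn

from soft_moe_pytorch import SoftMoELayerWrapper

# Constructor pattern exercised in the tests above: `layer` is the expert
# class and the remaining kwargs (in_features, out_features) are passed to it.
moe = SoftMoELayerWrapper(
    dim=128,
    num_experts=4,
    slots_per_expert=2,
    layer=nn.Linear,
    in_features=128,
    out_features=128,
)

x = torch.randn(1, 196, 128)  # (batch, tokens, dim)
y = moe(x)                    # -> (1, 196, 128), matching the forward test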

tests/test_vision_transformer.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
import pytest
import torch

from soft_moe_pytorch import (soft_moe_vit_base, soft_moe_vit_huge,
                              soft_moe_vit_large, soft_moe_vit_small,
                              soft_moe_vit_tiny)


@pytest.mark.parametrize(
    "model",
    [soft_moe_vit_tiny],
    # [soft_moe_vit_tiny, soft_moe_vit_small, soft_moe_vit_base, soft_moe_vit_large, soft_moe_vit_huge],
)
def test_soft_moe_vit_forward(model):
    """
    Test network forward pass
    """
    for image_size in [128, 224]:
        for in_chans in [1, 3]:
            net = model(
                img_size=image_size,
                in_chans=in_chans,
                num_classes=10,
            )
            net.eval()

            inp = torch.randn(1, in_chans, image_size, image_size)
            out = net(inp)

            assert out.shape[0] == 1
            assert not torch.isnan(out).any(), "Output included NaNs"


@pytest.mark.parametrize(
    "model",
    [soft_moe_vit_tiny],
    # [soft_moe_vit_tiny, soft_moe_vit_small, soft_moe_vit_base, soft_moe_vit_large, soft_moe_vit_huge],
)
def test_soft_moe_vit_backward(model):
    """
    Test network backward pass
    """
    image_size = 224
    num_classes = 10

    net = model(img_size=image_size, num_classes=num_classes)
    num_params = sum([x.numel() for x in net.parameters()])
    net.train()

    inp = torch.randn(1, 3, image_size, image_size)
    out = net(inp)

    out.mean().backward()
    for n, x in net.named_parameters():
        assert x.grad is not None, f"No gradient for {n}"
    num_grad = sum([x.grad.numel() for x in net.parameters() if x.grad is not None])

    assert out.shape[-1] == num_classes
    assert num_params == num_grad, "Some parameters are missing gradients"
    assert not torch.isnan(out).any(), "Output included NaNs"


@pytest.mark.parametrize(
    "model",
    [soft_moe_vit_tiny],
    # [soft_moe_vit_tiny, soft_moe_vit_small, soft_moe_vit_base, soft_moe_vit_large, soft_moe_vit_huge],
)
def test_soft_moe_vit_forward_num_experts(model):
    """
    Test network soft-moe arguments
    """
    image_size = 224
    in_chans = 3
    for num_experts in [1, 4]:
        for slots_per_experts in [1, 2]:
            for moe_layer_index in [6, [0, 2, 10]]:
                net = model(
                    img_size=image_size,
                    in_chans=in_chans,
                    num_classes=10,
                    num_experts=num_experts,
                    slots_per_expert=slots_per_experts,
                    moe_layer_index=moe_layer_index,
                )
                net.eval()

                inp = torch.randn(1, in_chans, image_size, image_size)
                out = net(inp)

                assert out.shape[0] == 1
                assert not torch.isnan(out).any(), "Output included NaNs"
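The arguments exercised in this file translate directly into a standalone usage sketch (values picked from the tests; moe_layer_index accepts either a single block index or a list of indices, as parametrized above):

import torch

from soft_moe_pytorch import soft_moe_vit_tiny

# Configuration mirroring test_soft_moe_vit_forward_num_experts above.
net = soft_moe_vit_tiny(
    img_size=224,
    in_chans=3,
    num_classes=10,
    num_experts=4,
    slots_per_expert=2,
    moe_layer_index=6,  # or a list of block indices such as [0, 2, 10]
)
net.eval()

with torch.no_grad():
    logits = net(torch.randn(1, 3, 224, 224))

print(logits.shape)  # (1, 10), per the forward/backward shape assertions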
