
Commit e0e76e8

parametrize in python test

1 parent 3d7b559 · commit e0e76e8

2 files changed: +63 −79 lines


tests/cpp/test_embedding_node.cpp

Lines changed: 2 additions & 16 deletions
@@ -20,7 +20,7 @@ using EmbeddingTest = NVFuserTest;
 
 constexpr int64_t n = 5, s = 2;
 
-TEST_F(EmbeddingTest, Basic) {
+TEST_F(EmbeddingTest, EmbeddingNode) {
   auto fusion = std::make_unique<Fusion>();
   FusionGuard fg(fusion.get());
   std::vector<int64_t> inp_shape({s});
@@ -45,19 +45,5 @@ TEST_F(EmbeddingTest, Basic) {
   FusionExecutorCache executor_cache(std::move(fusion));
   auto nvf_out = executor_cache.runFusionWithInputs({input, weight});
   EXPECT_TRUE(at::allclose(nvf_out[0], aten_out));
-}
-
-// INSTANTIATE_TEST_SUITE_P(
-//     LinearWithoutBias,
-//     LinearNodeParametrizedTest,
-//     testing::Combine(
-//         testing::Values(
-//             Sizes({k}),
-//             Sizes({m, k}),
-//             Sizes({b, m, k}),
-//             Sizes({1, k}),
-//             Sizes({b, 1, k})),
-//         testing::Values(Sizes({n, k}), Sizes({1, k})),
-//         testing::Values(std::nullopt)));
-
+}
 } // namespace nvfuser

tests/python/test_embedding.py

Lines changed: 61 additions & 63 deletions
@@ -11,69 +11,67 @@
 from functools import partial
 import torch.nn.functional as F
 
-class TestEmbedding(NVFuserTest):
-    def test_embedding(self):
-        def fusion_func(
-            fd: FusionDefinition,
-            has_optional_inputs: list[bool],
-            optional_inputs_dtypes: list[DataType]
-        ):
-            input = fd.define_tensor(
-                shape=[-1],
-                contiguity=[True],
-                dtype=DataType.Int,
-                is_cpu=False,
-            )
-            weight = fd.define_tensor(
-                shape=[-1, -1],
-                contiguity=[True, True],
-                dtype=DataType.BFloat16,
-                is_cpu=False,
-            )
-            # padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse
-            optional_inputs = [None] * 5
-            for idx in range(len(optional_inputs)):
-                if has_optional_inputs[idx]:
-                    optional_inputs[idx] = fd.define_scalar(value=None, dtype=optional_inputs_dtypes[idx])
-            out = fd.ops.embedding(input, weight, *optional_inputs)
-            fd.add_output(out)
 
-        N, S = 10, 3
-        input = torch.randint(N, (S,), dtype=torch.int64, device='cuda', requires_grad=False)
-        weight = torch.randn(N, S, dtype=torch.bfloat16, device='cuda', requires_grad=True)
-
-        padding_idx_vals = [None, -1, -2]
-        max_norm_vals = [None, 1e-5]
-        norm_type_vals = [None, 2.0, 1.0]
-        scale_grad_by_freq = [None, True]
-        sparse = [None, False, True]
-        optional_inputs_dtypes = [DataType.Int, DataType.Float, DataType.Float, DataType.Bool, DataType.Bool]
+@pytest.mark.parametrize("padding_idx", [None, -2])
+@pytest.mark.parametrize("max_norm", [None, 1e-5])
+@pytest.mark.parametrize("norm_type", [None, 1.0])
+@pytest.mark.parametrize("scale_grad_by_freq", [None, True])
+@pytest.mark.parametrize("sparse", [None, True])
+def test_embedding(
+    padding_idx: None | int,
+    max_norm: None | float,
+    norm_type: None | float,
+    scale_grad_by_freq: None | bool,
+    sparse: None | bool
+):
+    def fusion_func(
+        fd: FusionDefinition,
+        has_optional_inputs: list[bool],
+        optional_inputs_dtypes: list[DataType]
+    ):
+        input = fd.define_tensor(
+            shape=[-1],
+            contiguity=[True],
+            dtype=DataType.Int,
+            is_cpu=False,
+        )
+        weight = fd.define_tensor(
+            shape=[-1, -1],
+            contiguity=[True, True],
+            dtype=DataType.BFloat16,
+            is_cpu=False,
+        )
+        # padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse
+        optional_inputs = [None] * 5
+        for idx in range(len(optional_inputs)):
+            if has_optional_inputs[idx]:
+                optional_inputs[idx] = fd.define_scalar(value=None, dtype=optional_inputs_dtypes[idx])
+        out = fd.ops.embedding(input, weight, *optional_inputs)
+        fd.add_output(out)
 
-
-        # TODO: Try to move this to pytest_ops.py. Currently, it does not work since the API between nvFuser and torch differs.
-        for padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse in itertools.product(
-            padding_idx_vals, max_norm_vals, norm_type_vals, scale_grad_by_freq, sparse
-        ):
-            with self.subTest(padding_idx=padding_idx, max_norm=max_norm, norm_type=norm_type, scale_grad_by_freq=scale_grad_by_freq, sparse=sparse):
-                # Reset the FusionCache or the fusion would not recompile for all subtests, failing checks in exec_nvfuser.
-                FusionCache.reset()
-                optional_inputs = [padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse]
-                has_optional_inputs = [None] * 5
-                inputs = [input, weight]
-                for idx, param in enumerate(optional_inputs):
-                    if param is not None:
-                        has_optional_inputs[idx] = True
-                        inputs.append(param)
-
-                with FusionDefinition() as fd:
-                    fusion_func(fd,
-                        has_optional_inputs=has_optional_inputs,
-                        optional_inputs_dtypes = optional_inputs_dtypes)
-                nvf_out = fd.execute(inputs)
+    N, S = 10, 3
+    input = torch.randint(N, (S,), dtype=torch.int64, device='cuda', requires_grad=False)
+    weight = torch.randn(N, S, dtype=torch.bfloat16, device='cuda', requires_grad=True)
+    optional_inputs_dtypes = [DataType.Int, DataType.Float, DataType.Float, DataType.Bool, DataType.Bool]
 
-                torch.manual_seed(0)
-                norm_type = 2.0 if norm_type is None else norm_type
-                scale_grad_by_freq = False if scale_grad_by_freq is None else scale_grad_by_freq
-                sparse = False if sparse is None else sparse
-                ref_out = F.embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
-                torch.testing.assert_close(nvf_out[0], ref_out)
+    # This is not in pytest_ops.py since the torch API does not accept None values for some arguments.
+    # Different inputs for nvFuser and the torch API cannot be handled within OpInfo.
+    optional_inputs = [padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse]
+    has_optional_inputs = [None] * 5
+    inputs = [input, weight]
+    for idx, param in enumerate(optional_inputs):
+        if param is not None:
+            has_optional_inputs[idx] = True
+            inputs.append(param)
 
+    with FusionDefinition() as fd:
+        fusion_func(fd,
+            has_optional_inputs=has_optional_inputs,
+            optional_inputs_dtypes=optional_inputs_dtypes)
+    nvf_out = fd.execute(inputs)
+
+    norm_type = 2.0 if norm_type is None else norm_type
+    scale_grad_by_freq = False if scale_grad_by_freq is None else scale_grad_by_freq
+    sparse = False if sparse is None else sparse
+    ref_out = F.embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
+    torch.testing.assert_close(nvf_out[0], ref_out)
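
Note on the change: stacked pytest.mark.parametrize decorators combine as a cross product, so pytest collects one test per combination — here 2·2·2·2·2 = 32 cases, replacing the 3·2·3·2·3 = 108 subTest combinations the old itertools.product loop generated. A minimal, self-contained sketch of that mechanism (the test name below is hypothetical, not part of this commit):

import pytest

# Stacked decorators multiply: 2 values x 2 values = 4 collected tests.
@pytest.mark.parametrize("max_norm", [None, 1e-5])
@pytest.mark.parametrize("padding_idx", [None, -2])
def test_cross_product_sketch(padding_idx, max_norm):
    # Each collected case receives one value from every parametrized axis.
    assert padding_idx in (None, -2)
    assert max_norm in (None, 1e-5)

Because each combination is a separate collected test rather than a subTest inside one method, failures are reported per combination and the explicit per-iteration bookkeeping in the old loop goes away.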
