Skip to content

Commit df9c33b

Browse files
authored
Enable TensorIndexer with some of the reduction-related tests (#5573)
Confirmed code diff results are benign
1 parent 02e0055 commit df9c33b

File tree

6 files changed

+43
-6
lines changed

6 files changed

+43
-6
lines changed

tests/cpp/test_outer_reduction.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,13 @@
3232

3333
namespace nvfuser {
3434

35-
using OuterReductionTest = NVFuserTest;
35+
class OuterReductionTest : public NVFuserTest {
36+
protected:
37+
void SetUp() override {
38+
NVFuserTest::SetUp();
39+
EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"});
40+
}
41+
};
3642

3743
using namespace at::indexing;
3844

tests/cpp/test_persistent_buffer.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,14 @@ namespace nvfuser {
2323

2424
using testing::Contains;
2525
using testing::UnorderedElementsAre;
26-
using PersistentBufferTest = NVFuserTest;
26+
27+
class PersistentBufferTest : public NVFuserTest {
28+
protected:
29+
void SetUp() override {
30+
NVFuserTest::SetUp();
31+
EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"});
32+
}
33+
};
2734

2835
TEST_F(PersistentBufferTest, FusionPersistentBufferCalculation1_CUDA) {
2936
Fusion fusion;

tests/cpp/test_reduction.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,13 @@ void validateNoParallelBroadcastExist(kir::Kernel* kernel) {
7373

7474
} // namespace
7575

76-
using ReductionTest = NVFuserTest;
76+
class ReductionTest : public NVFuserTest {
77+
protected:
78+
void SetUp() override {
79+
NVFuserTest::SetUp();
80+
EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"});
81+
}
82+
};
7783

7884
TEST_F(ReductionTest, GridAllreduce1) {
7985
const int nx = 999;

tests/cpp/test_reduction_pointwise.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,13 @@
1616
#include <tests/cpp/validator.h>
1717
namespace nvfuser {
1818

19-
using PointwiseFusedReductionTest = NVFuserTest;
19+
class PointwiseFusedReductionTest : public NVFuserTest {
20+
protected:
21+
void SetUp() override {
22+
NVFuserTest::SetUp();
23+
EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"});
24+
}
25+
};
2026

2127
// inner reduction + non-broadcast epilogue, can't be fused
2228
// outer reduction + non-broadcast epilogue, can be fused

tests/cpp/test_serial_gridreduce.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,13 @@
3333

3434
namespace nvfuser {
3535

36-
using SerialGridReductionTest = NVFuserTest;
36+
class SerialGridReductionTest : public NVFuserTest {
37+
protected:
38+
void SetUp() override {
39+
NVFuserTest::SetUp();
40+
EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"});
41+
}
42+
};
3743

3844
TEST_F(SerialGridReductionTest, Scheduling) {
3945
for (bool serial : {true, false}) {

tests/cpp/test_welford.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,13 @@
1515

1616
namespace nvfuser {
1717

18-
using WelfordTest = NVFuserTest;
18+
class WelfordTest : public NVFuserTest {
19+
protected:
20+
void SetUp() override {
21+
NVFuserTest::SetUp();
22+
EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"});
23+
}
24+
};
1925

2026
TEST_F(WelfordTest, SerialWelford) {
2127
int x = 128, y = 64, z = 64;

0 commit comments

Comments
 (0)