Skip to content

Commit 260141f

Browse files
committed
creating a C++ API for triu
1 parent e92c575 commit 260141f

File tree

2 files changed

+23
-15
lines changed

2 files changed

+23
-15
lines changed

csrc/ops/arith.cpp

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -504,14 +504,15 @@ TensorView* triu(TensorView* tv, Val* offset) {
504504
tv->nDims(),
505505
"D tensor");
506506

507+
auto dims = tv->domain()->logical().size();
507508
auto tv_rows = iota(
508-
tv->domain()->logical()[1]->extent(),
509+
tv->domain()->logical()[dims - 2]->extent(),
509510
IrBuilder::create<Val>(0, DataType::Index),
510511
IrBuilder::create<Val>(1, DataType::Index),
511512
DataType::Int);
512513

513514
auto tv_columns = iota(
514-
tv->domain()->logical()[2]->extent(),
515+
tv->domain()->logical()[dims - 1]->extent(),
515516
SimplifyingIrBuilder::mulExpr(
516517
offset, IrBuilder::create<Val>(-1, DataType::Index)),
517518
IrBuilder::create<Val>(1, DataType::Index),

tests/cpp/test_tensor_factories.cpp

Lines changed: 20 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -231,25 +231,32 @@ TEST_F(TensorFactoryTest, StandaloneIota) {
231231
}
232232

233233
TEST_F(TensorFactoryTest, SimpleTriu) {
234-
auto fusion = std::make_unique<Fusion>();
235-
236-
FusionGuard gf(fusion.get());
234+
std::vector<std::vector<int64_t>> input_sizes = {
235+
{64, 64}, {4, 16}, {16, 4}, {16, 8, 32}};
236+
auto offsets = {0, 1, 2, -1, -2, 200, -200};
237237

238-
auto tv_to_triu_on = makeSymbolicTensor(3, DataType::Half);
239-
fusion->addInput(tv_to_triu_on);
238+
for (auto input_size : input_sizes) {
239+
for (auto offset : offsets) {
240+
auto fusion = std::make_unique<Fusion>();
241+
FusionGuard fg(fusion.get());
240242

241-
int64_t k_factor = -2;
242-
auto out = triu(tv_to_triu_on, IrBuilder::create<Val>(k_factor, DataType::Int));
243-
fusion->addOutput(out);
243+
auto tv_to_triu_on =
244+
makeSymbolicTensor(input_size.size(), DataType::Half);
245+
fusion->addInput(tv_to_triu_on);
244246

247+
auto out =
248+
triu(tv_to_triu_on, IrBuilder::create<Val>(offset, DataType::Int));
249+
fusion->addOutput(out);
245250

246-
auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
247-
auto in_tensor = at::randn({4, 4, 8}, options);
251+
auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
252+
auto in_tensor = at::randn(input_size, options);
248253

249-
FusionExecutorCache executor_cache(std::move(fusion));
250-
auto cg_outputs = executor_cache.runFusionWithInputs({in_tensor});
254+
FusionExecutorCache executor_cache(std::move(fusion));
255+
auto cg_outputs = executor_cache.runFusionWithInputs({in_tensor});
251256

252-
EXPECT_TRUE(cg_outputs[0].allclose(at::triu(in_tensor, k_factor), .001, .001));
257+
EXPECT_TRUE(at::equal(cg_outputs[0], at::triu(in_tensor, offset)));
258+
}
259+
}
253260
}
254261

255262
TEST_F(TensorFactoryTest, StandaloneARange) {

0 commit comments

Comments
 (0)