Skip to content

Commit

Permalink
new unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
protonu committed Dec 21, 2024
1 parent 49b0862 commit e92c575
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 0 deletions.
29 changes: 29 additions & 0 deletions csrc/ops/arith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,35 @@ TensorView* eye(Val* size, DataType dtype) {
return eye(size, size, dtype);
}

// Returns the upper-triangular part of `tv`, taken over its two innermost
// (non-reduction) logical dimensions. With `i` the row index and `j` the
// column index in those two dimensions, an element is kept when
// `j - i >= offset` and replaced by zero otherwise. This mirrors the
// `diagonal` argument of torch.triu: offset == 0 keeps the main diagonal and
// everything above it, a positive offset excludes diagonals above the main
// one, and a negative offset includes diagonals below it.
//
// tv:     input tensor; must have at least two non-reduction logical
//         dimensions. Any leading dimensions are treated as batch dimensions.
// offset: scalar of type DataType::Int selecting the diagonal.
//
// Fails an NVF_CHECK if `offset` is not of type Int or if `tv` has fewer than
// two dimensions.
TensorView* triu(TensorView* tv, Val* offset) {
  NVF_CHECK(
      offset->getDataType() == DataType::Int, "offset must have type Int");

  // Work on the logical domain without reduction axes so that the rank check
  // and the axis lookup below refer to the same list of dimensions.
  const auto logical = TensorDomain::noReductions(tv->domain()->logical());
  const auto ndims = logical.size();

  NVF_CHECK(
      ndims >= 2,
      "triu is only supported for 2+D tensors, but got ",
      ndims,
      "D tensor");

  // The matrix lives in the two innermost dimensions. Hard-coding axes 1 and
  // 2 would read out of bounds for a 2D input and pick the wrong axes for
  // anything other than a 3D input.
  auto* row_extent = logical[ndims - 2]->extent();
  auto* col_extent = logical[ndims - 1]->extent();

  // Row indices: 0, 1, ..., rows - 1.
  // Start/step use the same dtype as the iota output.
  // NOTE(review): IotaOp is believed to require start, step and output dtypes
  // to match -- confirm against the IotaOp constructor.
  auto* tv_rows = iota(
      row_extent,
      IrBuilder::create<Val>(0, DataType::Int),
      IrBuilder::create<Val>(1, DataType::Int),
      DataType::Int);

  // Column indices shifted by the offset: -offset, -offset + 1, ...
  // Comparing `i <= j - offset` is the same as `j - i >= offset`, so the
  // offset is folded into the start of the column iota instead of being
  // applied as a separate tensor op.
  auto* tv_columns = iota(
      col_extent,
      SimplifyingIrBuilder::mulExpr(
          offset, IrBuilder::create<Val>(-1, DataType::Int)),
      IrBuilder::create<Val>(1, DataType::Int),
      DataType::Int);

  // Rows become a column vector [rows, 1], columns a row vector [1, cols];
  // the comparison then yields a [rows, cols] boolean mask.
  auto* tv_rows_b = broadcast(tv_rows, {false, true});
  auto* tv_cols_b = broadcast(tv_columns, {true, false});
  auto* mask = le(tv_rows_b, tv_cols_b);

  // Keep the masked-in elements of `tv`, zero elsewhere.
  return where(mask, tv, IrBuilder::create<Val>(0, tv->dtype()));
}

// UNARY OPERATIONS

#define NVFUSER_DEFINE_UNARY_OP(operator_name, operator_type) \
Expand Down
1 change: 1 addition & 0 deletions csrc/ops/arith.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ NVF_API TensorView* arange(
DataType dtype = DataType::Int);
NVF_API TensorView* eye(Val* size, DataType dtype);
NVF_API TensorView* eye(Val* rows, Val* cols, DataType dtype);
// Returns the upper-triangular part of `tv`: elements for which
// (column index - row index) >= offset are kept, all others are set to zero.
// `offset` must be a scalar of type DataType::Int and `tv` must have at least
// two dimensions; both are enforced with NVF_CHECK in the definition.
NVF_API TensorView* triu(TensorView* tv, Val* offset);

// UNARY OPERATIONS
// abs
Expand Down
22 changes: 22 additions & 0 deletions tests/cpp/test_tensor_factories.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,28 @@ TEST_F(TensorFactoryTest, StandaloneIota) {
}
}

// Builds a fusion that applies triu() with a negative diagonal offset to a
// symbolic 3D half-precision input and compares the result with at::triu on
// a random {4, 4, 8} CUDA tensor.
TEST_F(TensorFactoryTest, SimpleTriu) {
  auto fusion_ptr = std::make_unique<Fusion>();
  FusionGuard fusion_guard(fusion_ptr.get());

  // Fusion definition: out = triu(input, diagonal).
  TensorView* input_tv = makeSymbolicTensor(3, DataType::Half);
  fusion_ptr->addInput(input_tv);

  int64_t diagonal = -2;
  Val* diagonal_val = IrBuilder::create<Val>(diagonal, DataType::Int);
  TensorView* triu_tv = triu(input_tv, diagonal_val);
  fusion_ptr->addOutput(triu_tv);

  // Concrete input for execution.
  const auto tensor_options =
      at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
  at::Tensor input = at::randn({4, 4, 8}, tensor_options);

  FusionExecutorCache executor_cache(std::move(fusion_ptr));
  auto outputs = executor_cache.runFusionWithInputs({input});

  // Reference result from ATen.
  at::Tensor expected = at::triu(input, diagonal);
  EXPECT_TRUE(outputs[0].allclose(expected, .001, .001));
}

TEST_F(TensorFactoryTest, StandaloneARange) {
auto starts_ends = {-1., 0., 10.3, 1024. * 256};
auto steps = {-1.5, 1., 2.};
Expand Down

0 comments on commit e92c575

Please sign in to comment.