From e92c5758182d048a547b6ff26cade1a9f96d8539 Mon Sep 17 00:00:00 2001
From: protonu
Date: Fri, 20 Dec 2024 10:18:42 -0800
Subject: [PATCH] new unit test

---
 csrc/ops/arith.cpp                  | 40 ++++++++++++++++++++++++++++++++++++++++
 csrc/ops/arith.h                    |  1 +
 tests/cpp/test_tensor_factories.cpp | 22 ++++++++++++++++++++++
 3 files changed, 63 insertions(+)

diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp
index 7684e406fef..a858d723861 100644
--- a/csrc/ops/arith.cpp
+++ b/csrc/ops/arith.cpp
@@ -494,6 +494,46 @@ TensorView* eye(Val* size, DataType dtype) {
   return eye(size, size, dtype);
 }
 
+// Returns `tv` with every element below the `offset`-th diagonal of its two
+// innermost logical dimensions replaced by zero; element (i, j) is kept when
+// j - i >= offset. The mask built here is 2D, so inputs with more dimensions
+// presumably rely on where() broadcasting it over the leading dimensions.
+TensorView* triu(TensorView* tv, Val* offset) {
+  NVF_CHECK(
+      offset->getDataType() == DataType::Int, "offset must have type Int");
+
+  // Address the matrix dimensions relative to the logical rank so that 2D
+  // inputs do not index past the end of the logical domain.
+  const auto& logical = tv->domain()->logical();
+  const auto rank = logical.size();
+  NVF_CHECK(
+      rank >= 2,
+      "triu is only supported for 2+D tensors, but got ",
+      rank,
+      "D tensor");
+
+  // Row indices 0, 1, ... along the second-to-last dimension.
+  auto tv_rows = iota(
+      logical[rank - 2]->extent(),
+      IrBuilder::create<Val>(0, DataType::Index),
+      IrBuilder::create<Val>(1, DataType::Index),
+      DataType::Int);
+
+  // Column indices along the last dimension, starting at -offset so that the
+  // comparison below keeps (i, j) exactly when i <= j - offset.
+  auto tv_columns = iota(
+      logical[rank - 1]->extent(),
+      SimplifyingIrBuilder::mulExpr(
+          offset, IrBuilder::create<Val>(-1, DataType::Index)),
+      IrBuilder::create<Val>(1, DataType::Index),
+      DataType::Int);
+
+  auto tv_rows_b = broadcast(tv_rows, {false, true});
+  auto tv_cols_b = broadcast(tv_columns, {true, false});
+  auto mask = le(tv_rows_b, tv_cols_b);
+  return where(mask, tv, IrBuilder::create<Val>(0, tv->dtype()));
+}
+
 // UNARY OPERATIONS
 
 #define NVFUSER_DEFINE_UNARY_OP(operator_name, operator_type) \
diff --git a/csrc/ops/arith.h b/csrc/ops/arith.h
index d8ea10038ad..46134de1c95 100644
--- a/csrc/ops/arith.h
+++ b/csrc/ops/arith.h
@@ -251,6 +251,7 @@ NVF_API TensorView* arange(
     DataType dtype = DataType::Int);
 NVF_API TensorView* eye(Val* size, DataType dtype);
 NVF_API TensorView* eye(Val* rows, Val* cols, DataType dtype);
+NVF_API TensorView* triu(TensorView* tv, Val* offset);
 
 // UNARY OPERATIONS
 // abs
diff --git a/tests/cpp/test_tensor_factories.cpp
b/tests/cpp/test_tensor_factories.cpp
index 2eabde38b3b..b31a1506de3 100644
--- a/tests/cpp/test_tensor_factories.cpp
+++ b/tests/cpp/test_tensor_factories.cpp
@@ -230,6 +230,28 @@ TEST_F(TensorFactoryTest, StandaloneIota) {
   }
 }
 
+// Checks triu() with a negative offset on a 3D half tensor against at::triu.
+TEST_F(TensorFactoryTest, SimpleTriu) {
+  auto fusion = std::make_unique<Fusion>();
+  FusionGuard gf(fusion.get());
+
+  auto tv_to_triu_on = makeSymbolicTensor(3, DataType::Half);
+  fusion->addInput(tv_to_triu_on);
+
+  int64_t k_factor = -2;
+  auto out =
+      triu(tv_to_triu_on, IrBuilder::create<Val>(k_factor, DataType::Int));
+  fusion->addOutput(out);
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
+  auto in_tensor = at::randn({4, 4, 8}, options);
+
+  FusionExecutorCache executor_cache(std::move(fusion));
+  auto cg_outputs = executor_cache.runFusionWithInputs({in_tensor});
+  EXPECT_TRUE(
+      cg_outputs[0].allclose(at::triu(in_tensor, k_factor), .001, .001));
+}
+
 TEST_F(TensorFactoryTest, StandaloneARange) {
   auto starts_ends = {-1., 0., 10.3, 1024. * 256};
   auto steps = {-1.5, 1., 2.};