From f3081cb7a92814a38debefb76061a9d739d5e117 Mon Sep 17 00:00:00 2001 From: protonu Date: Fri, 20 Dec 2024 10:18:42 -0800 Subject: [PATCH 01/10] new unit test --- csrc/ops/arith.cpp | 29 +++++++++++++++++++++++++++++ csrc/ops/arith.h | 1 + 2 files changed, 30 insertions(+) diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 7684e406fef..a858d723861 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -494,6 +494,35 @@ TensorView* eye(Val* size, DataType dtype) { return eye(size, size, dtype); } +TensorView* triu(TensorView* tv, Val* offset) { + NVF_CHECK( + offset->getDataType() == DataType::Int, "offset must have type Int"); + + NVF_CHECK( + tv->nDims() >= 2, + "triu is only supported for 2+D tensors, but got ", + tv->nDims(), + "D tensor"); + + auto tv_rows = iota( + tv->domain()->logical()[1]->extent(), + IrBuilder::create(0, DataType::Index), + IrBuilder::create(1, DataType::Index), + DataType::Int); + + auto tv_columns = iota( + tv->domain()->logical()[2]->extent(), + SimplifyingIrBuilder::mulExpr( + offset, IrBuilder::create(-1, DataType::Index)), + IrBuilder::create(1, DataType::Index), + DataType::Int); + + auto tv_rows_b = broadcast(tv_rows, {false, true}); + auto tv_cols_b = broadcast(tv_columns, {true, false}); + auto mask = le(tv_rows_b, tv_cols_b); + return where(mask, tv, IrBuilder::create(0, tv->dtype())); +} + // UNARY OPERATIONS #define NVFUSER_DEFINE_UNARY_OP(operator_name, operator_type) \ diff --git a/csrc/ops/arith.h b/csrc/ops/arith.h index d8ea10038ad..46134de1c95 100644 --- a/csrc/ops/arith.h +++ b/csrc/ops/arith.h @@ -251,6 +251,7 @@ NVF_API TensorView* arange( DataType dtype = DataType::Int); NVF_API TensorView* eye(Val* size, DataType dtype); NVF_API TensorView* eye(Val* rows, Val* cols, DataType dtype); +NVF_API TensorView* triu(TensorView* tv, Val* offset); // UNARY OPERATIONS // abs From 3ee5d1c7cb08fd37dcaa928ff230a760bef3bd65 Mon Sep 17 00:00:00 2001 From: protonu Date: Fri, 20 Dec 2024 22:35:46 -0800 Subject: [PATCH 02/10] creating a C++ API for triu --- csrc/ops/arith.cpp | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index a858d723861..093f02286b5 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -504,14 +504,30 @@ TensorView* triu(TensorView* tv, Val* offset) { tv->nDims(), "D tensor"); + // Let's say we want a triu of a 2D tensor of shape [2, 4] + // We broadcast the iota of the outer dim + // [0 [0, 0, 0, 0] + // 1] -> [1, 1, 1, 1] + // We broadcast the iota of the inner dim + // [0, 1, 2, 3] -> [0, 1, 2, 3] + // [0, 1, 2, 3] + // Using LE on the bcast tensors we get the mask + //[0, 0, 0, 0] LE [0, 1, 2, 3] + //[1, 1, 1, 1] [0, 1, 2, 3] + // Gives: + //[1, 0, 0, 0] + //[0, 1, 0, 0] + // If triu has an offset of k, we shift/subtract the iota of the columns by k + // before broadcasting and comparing with the iota of the rows. + auto dims = tv->domain()->logical().size(); auto tv_rows = iota( - tv->domain()->logical()[1]->extent(), + tv->domain()->logical()[dims - 2]->extent(), IrBuilder::create(0, DataType::Index), IrBuilder::create(1, DataType::Index), DataType::Int); auto tv_columns = iota( - tv->domain()->logical()[2]->extent(), + tv->domain()->logical()[dims - 1]->extent(), SimplifyingIrBuilder::mulExpr( offset, IrBuilder::create(-1, DataType::Index)), IrBuilder::create(1, DataType::Index), From fb1dd5c97450fb3b6591256386ce1c8ec345285f Mon Sep 17 00:00:00 2001 From: Protonu Date: Mon, 23 Dec 2024 16:49:05 -0500 Subject: [PATCH 03/10] correct comment --- csrc/ops/arith.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 093f02286b5..6fa7447d611 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -515,8 +515,8 @@ TensorView* triu(TensorView* tv, Val* offset) { //[0, 0, 0, 0] LE [0, 1, 2, 3] //[1, 1, 1, 1] [0, 1, 2, 3] // Gives: - //[1, 0, 0, 0] - //[0, 1, 0, 0] + //[1, 1, 1, 1] + //[0, 1, 1, 1] // If triu has an offset of k, we shift/subtract the iota of the columns by k // before broadcasting and comparing with the iota of the rows. auto dims = tv->domain()->logical().size(); From eb2556cf59970f75c535a80db7083105b992a6d8 Mon Sep 17 00:00:00 2001 From: protonu Date: Thu, 2 Jan 2025 12:55:02 -0800 Subject: [PATCH 04/10] moving from arith.cpp to composite.cpp --- csrc/ops/arith.cpp | 45 --------------------------------------------- csrc/ops/arith.h | 1 - 2 files changed, 46 deletions(-) diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 6fa7447d611..7684e406fef 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -494,51 +494,6 @@ TensorView* eye(Val* size, DataType dtype) { return eye(size, size, dtype); } -TensorView* triu(TensorView* tv, Val* offset) { - NVF_CHECK( - offset->getDataType() == DataType::Int, "offset must have type Int"); - - NVF_CHECK( - tv->nDims() >= 2, - "triu is only supported for 2+D tensors, but got ", - tv->nDims(), - "D tensor"); - - // Let's say we want a triu of a 2D tensor of shape [2, 4] - // We broadcast the iota of the outer dim - // [0 [0, 0, 0, 0] - // 1] -> [1, 1, 1, 1] - // We broadcast the iota of the inner dim - // [0, 1, 2, 3] -> [0, 1, 2, 3] - // [0, 1, 2, 3] - // Using LE on the bcast tensors we get the mask - //[0, 0, 0, 0] LE [0, 1, 2, 3] - //[1, 1, 1, 1] [0, 1, 2, 3] - // Gives: - //[1, 1, 1, 1] - //[0, 1, 1, 1] - // If triu has an offset of k, we shift/subtract the iota of the columns by k - // before broadcasting and comparing with the iota of the rows. - auto dims = tv->domain()->logical().size(); - auto tv_rows = iota( - tv->domain()->logical()[dims - 2]->extent(), - IrBuilder::create(0, DataType::Index), - IrBuilder::create(1, DataType::Index), - DataType::Int); - - auto tv_columns = iota( - tv->domain()->logical()[dims - 1]->extent(), - SimplifyingIrBuilder::mulExpr( - offset, IrBuilder::create(-1, DataType::Index)), - IrBuilder::create(1, DataType::Index), - DataType::Int); - - auto tv_rows_b = broadcast(tv_rows, {false, true}); - auto tv_cols_b = broadcast(tv_columns, {true, false}); - auto mask = le(tv_rows_b, tv_cols_b); - return where(mask, tv, IrBuilder::create(0, tv->dtype())); -} - // UNARY OPERATIONS #define NVFUSER_DEFINE_UNARY_OP(operator_name, operator_type) \ diff --git a/csrc/ops/arith.h b/csrc/ops/arith.h index 46134de1c95..d8ea10038ad 100644 --- a/csrc/ops/arith.h +++ b/csrc/ops/arith.h @@ -251,7 +251,6 @@ NVF_API TensorView* arange( DataType dtype = DataType::Int); NVF_API TensorView* eye(Val* size, DataType dtype); NVF_API TensorView* eye(Val* rows, Val* cols, DataType dtype); -NVF_API TensorView* triu(TensorView* tv, Val* offset); // UNARY OPERATIONS // abs From a86f0a1d149e8d156545202626f1713b9d5c831b Mon Sep 17 00:00:00 2001 From: protonu Date: Thu, 2 Jan 2025 13:19:06 -0800 Subject: [PATCH 05/10] fixes based on reviewer comments --- tests/cpp/test_tensor_factories.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/cpp/test_tensor_factories.cpp b/tests/cpp/test_tensor_factories.cpp index 3d95ad7d3c4..e0fdce92448 100644 --- a/tests/cpp/test_tensor_factories.cpp +++ b/tests/cpp/test_tensor_factories.cpp @@ -244,9 +244,15 @@ TEST_F(TensorFactoryTest, SimpleTriu) { auto input_offset = IrBuilder::create(DataType::Int); auto out = triu(tv_to_triu_on, input_offset); +<<<<<<< HEAD fusion->addInput(tv_to_triu_on); fusion->addInput(input_offset); fusion->addOutput(out); +======= + auto out = + triu(tv_to_triu_on, IrBuilder::create(offset, DataType::Index)); + fusion->addOutput(out); +>>>>>>> ad6021cf (fixes based on reviewer comments) FusionExecutorCache executor_cache(std::move(fusion)); From 7af3eabccd995a0a6501874589d98d873372c88f Mon Sep 17 00:00:00 2001 From: protonu Date: Fri, 3 Jan 2025 08:50:18 -0800 Subject: [PATCH 06/10] reviewer comments --- tests/cpp/test_tensor_factories.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/cpp/test_tensor_factories.cpp b/tests/cpp/test_tensor_factories.cpp index e0fdce92448..bfb33481e29 100644 --- a/tests/cpp/test_tensor_factories.cpp +++ b/tests/cpp/test_tensor_factories.cpp @@ -241,6 +241,7 @@ TEST_F(TensorFactoryTest, SimpleTriu) { FusionGuard fg(fusion.get()); auto tv_to_triu_on = makeSymbolicTensor(in.at(0).size(), DataType::Half); +<<<<<<< HEAD auto input_offset = IrBuilder::create(DataType::Int); auto out = triu(tv_to_triu_on, input_offset); @@ -253,6 +254,14 @@ TEST_F(TensorFactoryTest, SimpleTriu) { triu(tv_to_triu_on, IrBuilder::create(offset, DataType::Index)); fusion->addOutput(out); >>>>>>> ad6021cf (fixes based on reviewer comments) +======= + auto input_offset = IrBuilder::create(DataType::Index); + auto out = triu(tv_to_triu_on, input_offset); + + fusion->addInput(tv_to_triu_on); + fusion->addInput(input_offset); + fusion->addOutput(out); +>>>>>>> c90b21cd (reviewer comments) FusionExecutorCache executor_cache(std::move(fusion)); From 956ec3e0f5e1c944d35483da78c4e040c943f9f7 Mon Sep 17 00:00:00 2001 From: protonu Date: Fri, 3 Jan 2025 09:33:44 -0800 Subject: [PATCH 07/10] offset should be dtype int --- tests/cpp/test_tensor_factories.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/cpp/test_tensor_factories.cpp b/tests/cpp/test_tensor_factories.cpp index bfb33481e29..f4d89c3677f 100644 --- a/tests/cpp/test_tensor_factories.cpp +++ b/tests/cpp/test_tensor_factories.cpp @@ -241,6 +241,7 @@ TEST_F(TensorFactoryTest, SimpleTriu) { FusionGuard fg(fusion.get()); auto tv_to_triu_on = makeSymbolicTensor(in.at(0).size(), DataType::Half); +<<<<<<< HEAD <<<<<<< HEAD auto input_offset = IrBuilder::create(DataType::Int); auto out = triu(tv_to_triu_on, input_offset); @@ -256,6 +257,9 @@ TEST_F(TensorFactoryTest, SimpleTriu) { >>>>>>> ad6021cf (fixes based on reviewer comments) ======= auto input_offset = IrBuilder::create(DataType::Index); +======= + auto input_offset = IrBuilder::create(DataType::Int); +>>>>>>> e40db84f (offset should be dtype int) auto out = triu(tv_to_triu_on, input_offset); fusion->addInput(tv_to_triu_on); From c0a8d3dfda5220f9c594744d85afa6cfda11b5b8 Mon Sep 17 00:00:00 2001 From: protonu Date: Sat, 21 Dec 2024 17:07:51 -0800 Subject: [PATCH 08/10] adding python bindings for triu --- csrc/ops/composite.cpp | 8 +++-- csrc/ops/composite.h | 1 + csrc/python_frontend/python_bindings.cpp | 32 ++++++++++++++++++++ csrc/serde/fusion_record.cpp | 9 ++++++ tests/python/opinfo_input_generators.py | 38 ++++++++++++++++++++++++ tests/python/opinfos.py | 15 ++++++++++ tests/python/test_python_frontend.py | 15 ++++++++++ 7 files changed, 116 insertions(+), 2 deletions(-) diff --git a/csrc/ops/composite.cpp b/csrc/ops/composite.cpp index d2f0d9277d2..2288493e659 100644 --- a/csrc/ops/composite.cpp +++ b/csrc/ops/composite.cpp @@ -80,9 +80,9 @@ TensorView* triu(TensorView* tv, Val* offset) { NVF_CHECK( dims >= 2, - "triu is only supported for 2+D tensors, but got ", + "input tensor for triu must have 2 or more dims, but got ", dims, - "D tensor"); + " dims"); auto fusion = tv->fusion(); @@ -109,6 +109,10 @@ TensorView* triu(TensorView* tv, Val* offset) { return where(mask, tv, fusion->zeroVal(DataType::Index)); } +TensorView* triu(TensorView* tv) { + return triu(tv, IrBuilder::create(0)); +} + namespace { TensorView* newForLinear( diff --git a/csrc/ops/composite.h b/csrc/ops/composite.h index b67015b994d..8f1c994a242 100644 --- a/csrc/ops/composite.h +++ b/csrc/ops/composite.h @@ -36,6 +36,7 @@ NVF_API TensorView* dropout_backward( Val* scale); NVF_API TensorView* triu(TensorView* tv, Val* offset); +NVF_API TensorView* triu(TensorView* tv); struct LstmResult { TensorView* cell = nullptr; diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index ea061b094f1..436991f9b15 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -1693,6 +1693,38 @@ void initNvFuserPythonBindings(PyObject* module) { NVFUSER_PYTHON_BINDING_UNARY_OP("imag", imag) #undef NVFUSER_PYTHON_BINDING_UNARY_OP + nvf_ops.def( + "triu", + [](FusionDefinition::Operators& self, + Tensor input, + std::optional offset) -> Tensor { + FUSER_PERF_SCOPE("Operators.triu"); + NVF_CHECK( + self.validUse(), "Attempting to add to a completed definition!"); + FusionDefinition* fd = self.fusion_definition; + Tensor output = fd->defineTensor(input.dims); + if (offset.has_value()) { + fd->defineRecord(new OpRecord( + {fd->recordingState(input()), + fd->recordingState(offset.value()())}, + {fd->recordingState(output())}, + ("ops.triu"), + serde::RecordType::Binary_TV_VAL, + static_cast(triu))); + } else { + fd->defineRecord(new OpRecord( + {fd->recordingState(input())}, + {fd->recordingState(output())}, + ("ops.triu"), + serde::RecordType::Unary_TV, + static_cast(triu))); + } + return output; + }, + py::arg("input"), + py::arg("offset") = std::nullopt, + py::return_value_policy::reference); + // overload to nvf_ops.def( "stride_order", diff --git a/csrc/serde/fusion_record.cpp b/csrc/serde/fusion_record.cpp index 7e23adf2b69..b732fee82d7 100644 --- a/csrc/serde/fusion_record.cpp +++ b/csrc/serde/fusion_record.cpp @@ -653,6 +653,13 @@ void RecordFunctorFactory::setupFunctionMaps() { ("ops." op_str), static_cast(op_name)); \ unary_val.emplace(("ops." op_str), static_cast(op_name)); +#define NVFUSER_UNARY_TV_ALPHA_OP(op_str, op_name) \ + unary_tv.emplace( \ + ("ops." op_str), static_cast(op_name)); \ + binary_tv_val.emplace( \ + ("ops." op_str), \ + static_cast(op_name)); + #define NVFUSER_BINARY_TV_ONLY_OP(op_str, op_name) \ binary_tv.emplace( \ ("ops." op_str), \ @@ -808,6 +815,8 @@ void RecordFunctorFactory::setupFunctionMaps() { NVFUSER_UNARY_TV_OP("real", real) NVFUSER_UNARY_TV_OP("imag", imag) + NVFUSER_UNARY_TV_ALPHA_OP("triu", triu) + NVFUSER_BINARY_TV_ONLY_OP("matmul", matmul) NVFUSER_BINARY_TV_ONLY_OP("linear", linear) NVFUSER_TERNARY_TV_ONLY_OP("linear", linear) diff --git a/tests/python/opinfo_input_generators.py b/tests/python/opinfo_input_generators.py index d3222aea4b4..bcd0a7008ec 100644 --- a/tests/python/opinfo_input_generators.py +++ b/tests/python/opinfo_input_generators.py @@ -638,6 +638,8 @@ def elementwise_unary_generator( # Typical inputs for shape in shapes: + if op.name == "triu" and len(shape) < 2: + continue yield SampleInput(make_arg(shape)) yield SampleInput(make_arg(shape, noncontiguous=True)) @@ -1591,3 +1593,39 @@ def div_input_generator( denom = torch.where(denom_is_small, denom_scaled_to_minabs, denom).detach() denom.requires_grad_(requires_grad) yield SampleInput(numer, denom) + + +def triu_input_generator(op: OpInfo, dtype: torch.dtype, requires_grad: bool = False): + offsets = (0, 1, -1, 2, 3, -3, 1024, -1024) + + for element in elementwise_unary_generator( + op, + dtype, + requires_grad, + enable_extremal_value_testing=False, + enable_large_value_testing=False, + enable_small_value_testing=False, + ): + yield element + for offset in offsets: + yield SampleInput(*element.args, offset) + + +def triu_error_generator(op: OpInfo, dtype: torch.dtype, requires_grad: bool = False): + make_arg = partial( + make_tensor, device="cuda", dtype=dtype, requires_grad=requires_grad + ) + + yield SampleInput( + make_arg((4, 16)), 5.6 + ), RuntimeError, "offset must have integral type" + + invalid_shapes = ( + (), + (4,), + ) + + for shape in invalid_shapes: + yield SampleInput( + make_arg(shape), + ), RuntimeError, f"input tensor for triu must have 2 or more dims, but got {len(shape)} dims" diff --git a/tests/python/opinfos.py b/tests/python/opinfos.py index 9031a9bd091..d7502784a89 100644 --- a/tests/python/opinfos.py +++ b/tests/python/opinfos.py @@ -50,6 +50,8 @@ matmul_input_generator, linear_input_generator, linear_error_generator, + triu_input_generator, + triu_error_generator, ) from utils import ( bool_int_dtypes, @@ -1218,6 +1220,18 @@ def torch_reshape_sym_fn(input_tensor, output_shaped_tensor): ) linear_ops.append(linear_opinfo) +tv_val_ops = [] + +triu_opinfo = OpInfo( + lambda fd: fd.ops.triu, + "triu", + sample_input_generator=triu_input_generator, + error_input_generator=triu_error_generator, + reference=torch.triu, +) + +tv_val_ops.append(triu_opinfo) + """ End Tensor Creation """ # Puts all opinfos into the "opinfos" list @@ -1231,3 +1245,4 @@ def torch_reshape_sym_fn(input_tensor, output_shaped_tensor): opinfos.extend(tensor_creation_ops) opinfos.extend(matmul_ops) opinfos.extend(linear_ops) +opinfos.extend(tv_val_ops) diff --git a/tests/python/test_python_frontend.py b/tests/python/test_python_frontend.py index 0b74fddeae6..29fb0b897e7 100644 --- a/tests/python/test_python_frontend.py +++ b/tests/python/test_python_frontend.py @@ -1204,6 +1204,21 @@ def fusion_func(fd: FusionDefinition): self.assertEqual(eager_out2, nvf_out[1]) # self.assertEqual(eager_out3, nvf_out[2]) + def test_triu(self): + inputs = [ + torch.randn(4, 16, device="cuda", dtype=torch.float16), + ] + + def fusion_func(fd: FusionDefinition) -> None: + t0 = fd.from_pytorch(inputs[0]) + tt = fd.define_scalar(-1, dtype=DataType.Int) + t1 = fd.ops.triu(t0, tt) + fd.add_output(t1) + + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + eager_out0 = torch.triu(inputs[0], -1) + self.assertEqual(eager_out0, nvf_out[0]) + def test_complex_rsqrt(self): inputs = [ torch.randn(4, device="cuda", dtype=torch.complex64), From 8c6400b39ccf0f2806f792d25e0868dbb111389b Mon Sep 17 00:00:00 2001 From: protonu Date: Sat, 4 Jan 2025 08:05:06 -0800 Subject: [PATCH 09/10] minor change to pytho op test --- tests/python/opinfo_input_generators.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/opinfo_input_generators.py b/tests/python/opinfo_input_generators.py index bcd0a7008ec..696dade595e 100644 --- a/tests/python/opinfo_input_generators.py +++ b/tests/python/opinfo_input_generators.py @@ -638,8 +638,6 @@ def elementwise_unary_generator( # Typical inputs for shape in shapes: - if op.name == "triu" and len(shape) < 2: - continue yield SampleInput(make_arg(shape)) yield SampleInput(make_arg(shape, noncontiguous=True)) @@ -1606,6 +1604,8 @@ def triu_input_generator(op: OpInfo, dtype: torch.dtype, requires_grad: bool = F enable_large_value_testing=False, enable_small_value_testing=False, ): + if element.args[0].ndim < 2: + continue yield element for offset in offsets: yield SampleInput(*element.args, offset) From 8339c17f4e125ed29b4825c60fbeb0b63199af03 Mon Sep 17 00:00:00 2001 From: protonu Date: Mon, 6 Jan 2025 16:05:49 -0800 Subject: [PATCH 10/10] rebase and address reviwer comments --- csrc/ops/composite.cpp | 4 --- csrc/ops/composite.h | 1 - csrc/python_frontend/python_bindings.cpp | 32 ++++++++++------------- csrc/serde/fusion_record.cpp | 8 +++--- tests/cpp/test_tensor_factories.cpp | 19 -------------- tests/python/opinfo_fusion_definitions.py | 5 +++- tests/python/opinfo_input_generators.py | 6 ++--- tests/python/opinfos.py | 1 + tests/python/test_ops.py | 4 ++- tests/python/test_python_frontend.py | 3 +-- 10 files changed, 28 insertions(+), 55 deletions(-) diff --git a/csrc/ops/composite.cpp b/csrc/ops/composite.cpp index 2288493e659..ed8986ff817 100644 --- a/csrc/ops/composite.cpp +++ b/csrc/ops/composite.cpp @@ -109,10 +109,6 @@ TensorView* triu(TensorView* tv, Val* offset) { return where(mask, tv, fusion->zeroVal(DataType::Index)); } -TensorView* triu(TensorView* tv) { - return triu(tv, IrBuilder::create(0)); -} - namespace { TensorView* newForLinear( diff --git a/csrc/ops/composite.h b/csrc/ops/composite.h index 8f1c994a242..b67015b994d 100644 --- a/csrc/ops/composite.h +++ b/csrc/ops/composite.h @@ -36,7 +36,6 @@ NVF_API TensorView* dropout_backward( Val* scale); NVF_API TensorView* triu(TensorView* tv, Val* offset); -NVF_API TensorView* triu(TensorView* tv); struct LstmResult { TensorView* cell = nullptr; diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index 436991f9b15..fcfcd88e0a3 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -1697,32 +1697,28 @@ void initNvFuserPythonBindings(PyObject* module) { "triu", [](FusionDefinition::Operators& self, Tensor input, - std::optional offset) -> Tensor { + int64_t diagonal) -> Tensor { FUSER_PERF_SCOPE("Operators.triu"); NVF_CHECK( self.validUse(), "Attempting to add to a completed definition!"); FusionDefinition* fd = self.fusion_definition; Tensor output = fd->defineTensor(input.dims); - if (offset.has_value()) { - fd->defineRecord(new OpRecord( - {fd->recordingState(input()), - fd->recordingState(offset.value()())}, - {fd->recordingState(output())}, - ("ops.triu"), - serde::RecordType::Binary_TV_VAL, - static_cast(triu))); - } else { - fd->defineRecord(new OpRecord( - {fd->recordingState(input())}, - {fd->recordingState(output())}, - ("ops.triu"), - serde::RecordType::Unary_TV, - static_cast(triu))); - } + + auto diagonal_ = fd->defineScalar(); + fd->defineRecord(new ScalarRecord( + {fd->recordingState(diagonal_())}, diagonal, DataType::Int, true)); + + fd->defineRecord(new OpRecord( + {fd->recordingState(input()), fd->recordingState(diagonal_())}, + {fd->recordingState(output())}, + ("ops.triu"), + serde::RecordType::Binary_TV_VAL, + static_cast(triu))); + return output; }, py::arg("input"), - py::arg("offset") = std::nullopt, + py::arg("diagonal") = 0, py::return_value_policy::reference); // overload to diff --git a/csrc/serde/fusion_record.cpp b/csrc/serde/fusion_record.cpp index b732fee82d7..253270f0b9d 100644 --- a/csrc/serde/fusion_record.cpp +++ b/csrc/serde/fusion_record.cpp @@ -653,11 +653,9 @@ void RecordFunctorFactory::setupFunctionMaps() { ("ops." op_str), static_cast(op_name)); \ unary_val.emplace(("ops." op_str), static_cast(op_name)); -#define NVFUSER_UNARY_TV_ALPHA_OP(op_str, op_name) \ - unary_tv.emplace( \ - ("ops." op_str), static_cast(op_name)); \ - binary_tv_val.emplace( \ - ("ops." op_str), \ +#define NVFUSER_UNARY_TV_ALPHA_OP(op_str, op_name) \ + binary_tv_val.emplace( \ + ("ops." op_str), \ static_cast(op_name)); #define NVFUSER_BINARY_TV_ONLY_OP(op_str, op_name) \ diff --git a/tests/cpp/test_tensor_factories.cpp b/tests/cpp/test_tensor_factories.cpp index f4d89c3677f..3d95ad7d3c4 100644 --- a/tests/cpp/test_tensor_factories.cpp +++ b/tests/cpp/test_tensor_factories.cpp @@ -241,31 +241,12 @@ TEST_F(TensorFactoryTest, SimpleTriu) { FusionGuard fg(fusion.get()); auto tv_to_triu_on = makeSymbolicTensor(in.at(0).size(), DataType::Half); -<<<<<<< HEAD -<<<<<<< HEAD auto input_offset = IrBuilder::create(DataType::Int); auto out = triu(tv_to_triu_on, input_offset); -<<<<<<< HEAD fusion->addInput(tv_to_triu_on); fusion->addInput(input_offset); fusion->addOutput(out); -======= - auto out = - triu(tv_to_triu_on, IrBuilder::create(offset, DataType::Index)); - fusion->addOutput(out); ->>>>>>> ad6021cf (fixes based on reviewer comments) -======= - auto input_offset = IrBuilder::create(DataType::Index); -======= - auto input_offset = IrBuilder::create(DataType::Int); ->>>>>>> e40db84f (offset should be dtype int) - auto out = triu(tv_to_triu_on, input_offset); - - fusion->addInput(tv_to_triu_on); - fusion->addInput(input_offset); - fusion->addOutput(out); ->>>>>>> c90b21cd (reviewer comments) FusionExecutorCache executor_cache(std::move(fusion)); diff --git a/tests/python/opinfo_fusion_definitions.py b/tests/python/opinfo_fusion_definitions.py index 95abad9b7f4..c02ab27a620 100644 --- a/tests/python/opinfo_fusion_definitions.py +++ b/tests/python/opinfo_fusion_definitions.py @@ -28,10 +28,13 @@ def parse_inputs_fusion_definition(fd: FusionDefinition, opinfo: OpInfo, *args): ) num_symbolic_parameters = len(symbolic_parameter_list) - assert num_symbolic_parameters == len( + assert num_symbolic_parameters >= len( args ), f"{num_symbolic_parameters} vs {len(args)}" + if num_symbolic_parameters > len(args): + symbolic_parameter_list = symbolic_parameter_list[: len(args)] + for arg_type, a in zip(symbolic_parameter_list, args): if arg_type == ArgumentType.Symbolic: if isinstance(a, torch.Tensor): diff --git a/tests/python/opinfo_input_generators.py b/tests/python/opinfo_input_generators.py index 696dade595e..472d5109059 100644 --- a/tests/python/opinfo_input_generators.py +++ b/tests/python/opinfo_input_generators.py @@ -1606,7 +1606,9 @@ def triu_input_generator(op: OpInfo, dtype: torch.dtype, requires_grad: bool = F ): if element.args[0].ndim < 2: continue + # to test cases where offset is not passed as an argument yield element + # to test cases where offset is passed as an argument for offset in offsets: yield SampleInput(*element.args, offset) @@ -1616,10 +1618,6 @@ def triu_error_generator(op: OpInfo, dtype: torch.dtype, requires_grad: bool = F make_tensor, device="cuda", dtype=dtype, requires_grad=requires_grad ) - yield SampleInput( - make_arg((4, 16)), 5.6 - ), RuntimeError, "offset must have integral type" - invalid_shapes = ( (), (4,), diff --git a/tests/python/opinfos.py b/tests/python/opinfos.py index d7502784a89..f0bbd649b87 100644 --- a/tests/python/opinfos.py +++ b/tests/python/opinfos.py @@ -1228,6 +1228,7 @@ def torch_reshape_sym_fn(input_tensor, output_shaped_tensor): sample_input_generator=triu_input_generator, error_input_generator=triu_error_generator, reference=torch.triu, + symbolic_parameter_list=[ArgumentType.Symbolic, ArgumentType.Constant], ) tv_val_ops.append(triu_opinfo) diff --git a/tests/python/test_ops.py b/tests/python/test_ops.py index d653e005736..ebd6a61114a 100644 --- a/tests/python/test_ops.py +++ b/tests/python/test_ops.py @@ -63,7 +63,9 @@ def parse_args_fusion_execution(opinfo: OpInfo, *args): else [ArgumentType.Symbolic] * len(args) ) - assert len(symbolic_parameter_list) == len(args) + assert len(symbolic_parameter_list) >= len(args) + if len(symbolic_parameter_list) > len(args): + symbolic_parameter_list = symbolic_parameter_list[: len(args)] result = [] for arg_type, a in zip(symbolic_parameter_list, args): diff --git a/tests/python/test_python_frontend.py b/tests/python/test_python_frontend.py index 29fb0b897e7..7b7e1a40218 100644 --- a/tests/python/test_python_frontend.py +++ b/tests/python/test_python_frontend.py @@ -1211,8 +1211,7 @@ def test_triu(self): def fusion_func(fd: FusionDefinition) -> None: t0 = fd.from_pytorch(inputs[0]) - tt = fd.define_scalar(-1, dtype=DataType.Int) - t1 = fd.ops.triu(t0, tt) + t1 = fd.ops.triu(t0, -1) fd.add_output(t1) nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)