diff --git a/csrc/ops/composite.cpp b/csrc/ops/composite.cpp index d2f0d9277d2..ed8986ff817 100644 --- a/csrc/ops/composite.cpp +++ b/csrc/ops/composite.cpp @@ -80,9 +80,9 @@ TensorView* triu(TensorView* tv, Val* offset) { NVF_CHECK( dims >= 2, - "triu is only supported for 2+D tensors, but got ", + "input tensor for triu must have 2 or more dims, but got ", dims, - "D tensor"); + " dims"); auto fusion = tv->fusion(); diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index ea061b094f1..fcfcd88e0a3 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -1693,6 +1693,34 @@ void initNvFuserPythonBindings(PyObject* module) { NVFUSER_PYTHON_BINDING_UNARY_OP("imag", imag) #undef NVFUSER_PYTHON_BINDING_UNARY_OP + nvf_ops.def( + "triu", + [](FusionDefinition::Operators& self, + Tensor input, + int64_t diagonal) -> Tensor { + FUSER_PERF_SCOPE("Operators.triu"); + NVF_CHECK( + self.validUse(), "Attempting to add to a completed definition!"); + FusionDefinition* fd = self.fusion_definition; + Tensor output = fd->defineTensor(input.dims); + + auto diagonal_ = fd->defineScalar(); + fd->defineRecord(new ScalarRecord( + {fd->recordingState(diagonal_())}, diagonal, DataType::Int, true)); + + fd->defineRecord(new OpRecord( + {fd->recordingState(input()), fd->recordingState(diagonal_())}, + {fd->recordingState(output())}, + ("ops.triu"), + serde::RecordType::Binary_TV_VAL, + static_cast<TensorView* (*)(TensorView*, Val*)>(triu))); + + return output; + }, + py::arg("input"), + py::arg("diagonal") = 0, + py::return_value_policy::reference); + // overload to nvf_ops.def( "stride_order", diff --git a/csrc/serde/fusion_record.cpp b/csrc/serde/fusion_record.cpp index 7e23adf2b69..253270f0b9d 100644 --- a/csrc/serde/fusion_record.cpp +++ b/csrc/serde/fusion_record.cpp @@ -653,6 +653,11 @@ void RecordFunctorFactory::setupFunctionMaps() { ("ops." op_str), static_cast<TensorView* (*)(TensorView*)>(op_name)); \ unary_val.emplace(("ops." 
op_str), static_cast<Val* (*)(Val*)>(op_name)); +#define NVFUSER_UNARY_TV_ALPHA_OP(op_str, op_name) \ + binary_tv_val.emplace( \ + ("ops." op_str), \ + static_cast<TensorView* (*)(TensorView*, Val*)>(op_name)); + #define NVFUSER_BINARY_TV_ONLY_OP(op_str, op_name) \ binary_tv.emplace( \ ("ops." op_str), \ @@ -808,6 +813,8 @@ void RecordFunctorFactory::setupFunctionMaps() { NVFUSER_UNARY_TV_OP("real", real) NVFUSER_UNARY_TV_OP("imag", imag) + NVFUSER_UNARY_TV_ALPHA_OP("triu", triu) + NVFUSER_BINARY_TV_ONLY_OP("matmul", matmul) NVFUSER_BINARY_TV_ONLY_OP("linear", linear) NVFUSER_TERNARY_TV_ONLY_OP("linear", linear) diff --git a/tests/python/opinfo_fusion_definitions.py b/tests/python/opinfo_fusion_definitions.py index 95abad9b7f4..c02ab27a620 100644 --- a/tests/python/opinfo_fusion_definitions.py +++ b/tests/python/opinfo_fusion_definitions.py @@ -28,10 +28,13 @@ def parse_inputs_fusion_definition(fd: FusionDefinition, opinfo: OpInfo, *args): ) num_symbolic_parameters = len(symbolic_parameter_list) - assert num_symbolic_parameters == len( + assert num_symbolic_parameters >= len( args ), f"{num_symbolic_parameters} vs {len(args)}" + if num_symbolic_parameters > len(args): + symbolic_parameter_list = symbolic_parameter_list[: len(args)] + for arg_type, a in zip(symbolic_parameter_list, args): if arg_type == ArgumentType.Symbolic: if isinstance(a, torch.Tensor): diff --git a/tests/python/opinfo_input_generators.py b/tests/python/opinfo_input_generators.py index d3222aea4b4..472d5109059 100644 --- a/tests/python/opinfo_input_generators.py +++ b/tests/python/opinfo_input_generators.py @@ -1591,3 +1591,39 @@ def div_input_generator( denom = torch.where(denom_is_small, denom_scaled_to_minabs, denom).detach() denom.requires_grad_(requires_grad) yield SampleInput(numer, denom) + + +def triu_input_generator(op: OpInfo, dtype: torch.dtype, requires_grad: bool = False): + offsets = (0, 1, -1, 2, 3, -3, 1024, -1024) + + for element in elementwise_unary_generator( + op, + dtype, + requires_grad, + 
enable_extremal_value_testing=False, + enable_large_value_testing=False, + enable_small_value_testing=False, + ): + if element.args[0].ndim < 2: + continue + # to test cases where offset is not passed as an argument + yield element + # to test cases where offset is passed as an argument + for offset in offsets: + yield SampleInput(*element.args, offset) + + +def triu_error_generator(op: OpInfo, dtype: torch.dtype, requires_grad: bool = False): + make_arg = partial( + make_tensor, device="cuda", dtype=dtype, requires_grad=requires_grad + ) + + invalid_shapes = ( + (), + (4,), + ) + + for shape in invalid_shapes: + yield SampleInput( + make_arg(shape), + ), RuntimeError, f"input tensor for triu must have 2 or more dims, but got {len(shape)} dims" diff --git a/tests/python/opinfos.py b/tests/python/opinfos.py index 9031a9bd091..f0bbd649b87 100644 --- a/tests/python/opinfos.py +++ b/tests/python/opinfos.py @@ -50,6 +50,8 @@ matmul_input_generator, linear_input_generator, linear_error_generator, + triu_input_generator, + triu_error_generator, ) from utils import ( bool_int_dtypes, @@ -1218,6 +1220,19 @@ def torch_reshape_sym_fn(input_tensor, output_shaped_tensor): ) linear_ops.append(linear_opinfo) +tv_val_ops = [] + +triu_opinfo = OpInfo( + lambda fd: fd.ops.triu, + "triu", + sample_input_generator=triu_input_generator, + error_input_generator=triu_error_generator, + reference=torch.triu, + symbolic_parameter_list=[ArgumentType.Symbolic, ArgumentType.Constant], +) + +tv_val_ops.append(triu_opinfo) + """ End Tensor Creation """ # Puts all opinfos into the "opinfos" list @@ -1231,3 +1246,4 @@ def torch_reshape_sym_fn(input_tensor, output_shaped_tensor): opinfos.extend(tensor_creation_ops) opinfos.extend(matmul_ops) opinfos.extend(linear_ops) +opinfos.extend(tv_val_ops) diff --git a/tests/python/test_ops.py b/tests/python/test_ops.py index d653e005736..ebd6a61114a 100644 --- a/tests/python/test_ops.py +++ b/tests/python/test_ops.py @@ -63,7 +63,9 @@ def 
parse_args_fusion_execution(opinfo: OpInfo, *args): else [ArgumentType.Symbolic] * len(args) ) - assert len(symbolic_parameter_list) == len(args) + assert len(symbolic_parameter_list) >= len(args) + if len(symbolic_parameter_list) > len(args): + symbolic_parameter_list = symbolic_parameter_list[: len(args)] result = [] for arg_type, a in zip(symbolic_parameter_list, args): diff --git a/tests/python/test_python_frontend.py b/tests/python/test_python_frontend.py index 0b74fddeae6..7b7e1a40218 100644 --- a/tests/python/test_python_frontend.py +++ b/tests/python/test_python_frontend.py @@ -1204,6 +1204,20 @@ def fusion_func(fd: FusionDefinition): self.assertEqual(eager_out2, nvf_out[1]) # self.assertEqual(eager_out3, nvf_out[2]) + def test_triu(self): + inputs = [ + torch.randn(4, 16, device="cuda", dtype=torch.float16), + ] + + def fusion_func(fd: FusionDefinition) -> None: + t0 = fd.from_pytorch(inputs[0]) + t1 = fd.ops.triu(t0, -1) + fd.add_output(t1) + + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + eager_out0 = torch.triu(inputs[0], -1) + self.assertEqual(eager_out0, nvf_out[0]) + def test_complex_rsqrt(self): inputs = [ torch.randn(4, device="cuda", dtype=torch.complex64),