
Python bindings and tests for Triu #3637

Merged: 12 commits, Jan 8, 2025
4 changes: 2 additions & 2 deletions csrc/ops/composite.cpp
@@ -80,9 +80,9 @@ TensorView* triu(TensorView* tv, Val* offset) {

NVF_CHECK(
dims >= 2,
"triu is only supported for 2+D tensors, but got ",
"input tensor for triu must have 2 or more dims, but got ",
dims,
"D tensor");
" dims");

auto fusion = tv->fusion();

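For context, a minimal sketch (assuming the standard `FusionDefinition` frontend; not part of this diff) of how the reworded check surfaces in Python:

```python
import torch
from nvfuser import FusionDefinition

inputs = [torch.randn(4, device="cuda")]  # 1-D input: triu requires >= 2 dims

with FusionDefinition() as fd:
    t0 = fd.from_pytorch(inputs[0])
    fd.add_output(fd.ops.triu(t0))

# Expected to raise a RuntimeError carrying the new message:
# "input tensor for triu must have 2 or more dims, but got 1 dims"
fd.execute(inputs)
```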
28 changes: 28 additions & 0 deletions csrc/python_frontend/python_bindings.cpp
@@ -1693,6 +1693,34 @@ void initNvFuserPythonBindings(PyObject* module) {
NVFUSER_PYTHON_BINDING_UNARY_OP("imag", imag)
#undef NVFUSER_PYTHON_BINDING_UNARY_OP

nvf_ops.def(
"triu",
[](FusionDefinition::Operators& self,
Tensor input,
int64_t diagonal) -> Tensor {
FUSER_PERF_SCOPE("Operators.triu");
NVF_CHECK(
self.validUse(), "Attempting to add to a completed definition!");
FusionDefinition* fd = self.fusion_definition;
Tensor output = fd->defineTensor(input.dims);

auto diagonal_ = fd->defineScalar();
fd->defineRecord(new ScalarRecord(
{fd->recordingState(diagonal_())}, diagonal, DataType::Int, true));

fd->defineRecord(new OpRecord<TensorView*, TensorView*, Val*>(
{fd->recordingState(input()), fd->recordingState(diagonal_())},
{fd->recordingState(output())},
("ops.triu"),
serde::RecordType::Binary_TV_VAL,
static_cast<TensorView* (*)(TensorView*, Val*)>(triu)));

return output;
},
py::arg("input"),
py::arg("diagonal") = 0,
py::return_value_policy::reference);
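A short usage sketch of the new binding (assumed, mirroring the test added below): `diagonal` is optional and defaults to 0, matching `torch.triu`.

```python
import torch
from nvfuser import FusionDefinition

inputs = [torch.randn(4, 16, device="cuda")]

with FusionDefinition() as fd:
    t0 = fd.from_pytorch(inputs[0])
    fd.add_output(fd.ops.triu(t0))     # diagonal defaults to 0
    fd.add_output(fd.ops.triu(t0, 2))  # keep elements on/above the 2nd superdiagonal

out_default, out_offset = fd.execute(inputs)
assert torch.equal(out_default, torch.triu(inputs[0]))
assert torch.equal(out_offset, torch.triu(inputs[0], 2))
```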

// overload to
nvf_ops.def(
"stride_order",
7 changes: 7 additions & 0 deletions csrc/serde/fusion_record.cpp
@@ -653,6 +653,11 @@ void RecordFunctorFactory::setupFunctionMaps() {
("ops." op_str), static_cast<TensorView* (*)(TensorView*)>(op_name)); \
unary_val.emplace(("ops." op_str), static_cast<Val* (*)(Val*)>(op_name));

#define NVFUSER_UNARY_TV_ALPHA_OP(op_str, op_name) \
binary_tv_val.emplace( \
("ops." op_str), \
static_cast<TensorView* (*)(TensorView*, Val*)>(op_name));

#define NVFUSER_BINARY_TV_ONLY_OP(op_str, op_name) \
binary_tv.emplace( \
("ops." op_str), \
@@ -808,6 +813,8 @@ void RecordFunctorFactory::setupFunctionMaps() {
NVFUSER_UNARY_TV_OP("real", real)
NVFUSER_UNARY_TV_OP("imag", imag)

NVFUSER_UNARY_TV_ALPHA_OP("triu", triu)

NVFUSER_BINARY_TV_ONLY_OP("matmul", matmul)
NVFUSER_BINARY_TV_ONLY_OP("linear", linear)
NVFUSER_TERNARY_TV_ONLY_OP("linear", linear)
5 changes: 4 additions & 1 deletion tests/python/opinfo_fusion_definitions.py
@@ -28,10 +28,13 @@ def parse_inputs_fusion_definition(fd: FusionDefinition, opinfo: OpInfo, *args):
)

num_symbolic_parameters = len(symbolic_parameter_list)
-assert num_symbolic_parameters == len(
+assert num_symbolic_parameters >= len(
args
), f"{num_symbolic_parameters} vs {len(args)}"

if num_symbolic_parameters > len(args):
symbolic_parameter_list = symbolic_parameter_list[: len(args)]

for arg_type, a in zip(symbolic_parameter_list, args):
if arg_type == ArgumentType.Symbolic:
if isinstance(a, torch.Tensor):
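A hypothetical illustration of the relaxed assertion: `triu` declares two parameter types, but its samples may omit the optional `diagonal`, so the list is truncated to pair one-to-one with the args actually supplied.

```python
import torch

# Stand-ins for the ArgumentType members used in this file (assumed names).
symbolic_parameter_list = ["Symbolic", "Constant"]
args = (torch.randn(4, 4),)  # triu sample without the optional diagonal

assert len(symbolic_parameter_list) >= len(args)
if len(symbolic_parameter_list) > len(args):
    symbolic_parameter_list = symbolic_parameter_list[: len(args)]

# zip now pairs each provided argument with its declared type; the unused
# trailing Constant entry (the defaulted `diagonal`) is simply dropped.
for arg_type, a in zip(symbolic_parameter_list, args):
    print(arg_type, type(a))
```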
36 changes: 36 additions & 0 deletions tests/python/opinfo_input_generators.py
@@ -1591,3 +1591,39 @@ def div_input_generator(
denom = torch.where(denom_is_small, denom_scaled_to_minabs, denom).detach()
denom.requires_grad_(requires_grad)
yield SampleInput(numer, denom)


def triu_input_generator(op: OpInfo, dtype: torch.dtype, requires_grad: bool = False):
offsets = (0, 1, -1, 2, 3, -3, 1024, -1024)

for element in elementwise_unary_generator(
op,
dtype,
requires_grad,
enable_extremal_value_testing=False,
enable_large_value_testing=False,
enable_small_value_testing=False,
):
if element.args[0].ndim < 2:
continue
# to test cases where offset is not passed as an argument
yield element
# to test cases where offset is passed as an argument
for offset in offsets:
yield SampleInput(*element.args, offset)
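The large offsets (1024 and -1024) probe degenerate diagonals; as a quick reference check in plain PyTorch (an assumed illustration, not part of this diff), an offset past the last column zeroes the tensor while a sufficiently negative one is a no-op:

```python
import torch

t = torch.randn(4, 16)
assert torch.equal(torch.triu(t, 1024), torch.zeros_like(t))  # no elements kept
assert torch.equal(torch.triu(t, -1024), t)                   # everything kept
```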


def triu_error_generator(op: OpInfo, dtype: torch.dtype, requires_grad: bool = False):
make_arg = partial(
make_tensor, device="cuda", dtype=dtype, requires_grad=requires_grad
)

invalid_shapes = (
(),
(4,),
)

for shape in invalid_shapes:
yield SampleInput(
make_arg(shape),
), RuntimeError, f"input tensor for triu must have 2 or more dims, but got {len(shape)} dims"
16 changes: 16 additions & 0 deletions tests/python/opinfos.py
@@ -50,6 +50,8 @@
matmul_input_generator,
linear_input_generator,
linear_error_generator,
triu_input_generator,
triu_error_generator,
)
from utils import (
bool_int_dtypes,
@@ -1218,6 +1220,19 @@ def torch_reshape_sym_fn(input_tensor, output_shaped_tensor):
)
linear_ops.append(linear_opinfo)

tv_val_ops = []

triu_opinfo = OpInfo(
lambda fd: fd.ops.triu,
"triu",
sample_input_generator=triu_input_generator,
error_input_generator=triu_error_generator,
reference=torch.triu,
symbolic_parameter_list=[ArgumentType.Symbolic, ArgumentType.Constant],
)

tv_val_ops.append(triu_opinfo)

""" End Tensor Creation """

# Puts all opinfos into the "opinfos" list
@@ -1231,3 +1246,4 @@ def torch_reshape_sym_fn(input_tensor, output_shaped_tensor):
opinfos.extend(tensor_creation_ops)
opinfos.extend(matmul_ops)
opinfos.extend(linear_ops)
opinfos.extend(tv_val_ops)
4 changes: 3 additions & 1 deletion tests/python/test_ops.py
@@ -63,7 +63,9 @@ def parse_args_fusion_execution(opinfo: OpInfo, *args):
else [ArgumentType.Symbolic] * len(args)
)

-assert len(symbolic_parameter_list) == len(args)
+assert len(symbolic_parameter_list) >= len(args)
if len(symbolic_parameter_list) > len(args):
symbolic_parameter_list = symbolic_parameter_list[: len(args)]

result = []
for arg_type, a in zip(symbolic_parameter_list, args):
14 changes: 14 additions & 0 deletions tests/python/test_python_frontend.py
@@ -1204,6 +1204,20 @@ def fusion_func(fd: FusionDefinition):
self.assertEqual(eager_out2, nvf_out[1])
# self.assertEqual(eager_out3, nvf_out[2])

def test_triu(self):
inputs = [
torch.randn(4, 16, device="cuda", dtype=torch.float16),
]

def fusion_func(fd: FusionDefinition) -> None:
t0 = fd.from_pytorch(inputs[0])
t1 = fd.ops.triu(t0, -1)
fd.add_output(t1)

nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
eager_out0 = torch.triu(inputs[0], -1)
self.assertEqual(eager_out0, nvf_out[0])

def test_complex_rsqrt(self):
inputs = [
torch.randn(4, device="cuda", dtype=torch.complex64),