Skip to content

Commit

Permalink
adding python bindings for triu
Browse files Browse the repository at this point in the history
  • Loading branch information
protonu committed Jan 2, 2025
1 parent ad6021c commit 130551d
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 0 deletions.
4 changes: 4 additions & 0 deletions csrc/ops/composite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ TensorView* triu(TensorView* tv, Val* offset) {
return where(mask, tv, IrBuilder::create<Val>(0, tv->dtype()));
}

TensorView* triu(TensorView* tv) {
return triu(tv, IrBuilder::create<Val>(0));
}

namespace {

TensorView* newForLinear(
Expand Down
1 change: 1 addition & 0 deletions csrc/ops/composite.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ NVF_API TensorView* dropout_backward(
Val* scale);

NVF_API TensorView* triu(TensorView* tv, Val* offset);
NVF_API TensorView* triu(TensorView* tv);

struct LstmResult {
TensorView* cell = nullptr;
Expand Down
32 changes: 32 additions & 0 deletions csrc/python_frontend/python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1693,6 +1693,38 @@ void initNvFuserPythonBindings(PyObject* module) {
NVFUSER_PYTHON_BINDING_UNARY_OP("imag", imag)
#undef NVFUSER_PYTHON_BINDING_UNARY_OP

nvf_ops.def(
"triu",
[](FusionDefinition::Operators& self,
Tensor input,
std::optional<Scalar> offset) -> Tensor {
FUSER_PERF_SCOPE("Operators.triu");
NVF_CHECK(
self.validUse(), "Attempting to add to a completed definition!");
FusionDefinition* fd = self.fusion_definition;
Tensor output = fd->defineTensor(input.dims);
if (offset.has_value()) {
fd->defineRecord(new OpRecord<TensorView*, TensorView*, Val*>(
{fd->recordingState(input()),
fd->recordingState(offset.value()())},
{fd->recordingState(output())},
("ops.triu"),
serde::RecordType::Binary_TV_VAL,
static_cast<TensorView* (*)(TensorView*, Val*)>(triu)));
} else {
fd->defineRecord(new OpRecord<TensorView*, TensorView*>(
{fd->recordingState(input())},
{fd->recordingState(output())},
("ops.triu"),
serde::RecordType::Unary_TV,
static_cast<TensorView* (*)(TensorView*)>(triu)));
}
return output;
},
py::arg("input"),
py::arg("offset") = std::nullopt,
py::return_value_policy::reference);

// overload to
nvf_ops.def(
"stride_order",
Expand Down
9 changes: 9 additions & 0 deletions csrc/serde/fusion_record.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,13 @@ void RecordFunctorFactory::setupFunctionMaps() {
("ops." op_str), static_cast<TensorView* (*)(TensorView*)>(op_name)); \
unary_val.emplace(("ops." op_str), static_cast<Val* (*)(Val*)>(op_name));

#define NVFUSER_UNARY_TV_ALPHA_OP(op_str, op_name) \
unary_tv.emplace( \
("ops." op_str), static_cast<TensorView* (*)(TensorView*)>(op_name)); \
binary_tv_val.emplace( \
("ops." op_str), \
static_cast<TensorView* (*)(TensorView*, Val*)>(op_name));

#define NVFUSER_BINARY_TV_ONLY_OP(op_str, op_name) \
binary_tv.emplace( \
("ops." op_str), \
Expand Down Expand Up @@ -808,6 +815,8 @@ void RecordFunctorFactory::setupFunctionMaps() {
NVFUSER_UNARY_TV_OP("real", real)
NVFUSER_UNARY_TV_OP("imag", imag)

NVFUSER_UNARY_TV_ALPHA_OP("triu", triu)

NVFUSER_BINARY_TV_ONLY_OP("matmul", matmul)
NVFUSER_BINARY_TV_ONLY_OP("linear", linear)
NVFUSER_TERNARY_TV_ONLY_OP("linear", linear)
Expand Down
36 changes: 36 additions & 0 deletions tests/python/opinfo_input_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,8 @@ def elementwise_unary_generator(

# Typical inputs
for shape in shapes:
if op.name == "triu" and len(shape) < 2:
continue
yield SampleInput(make_arg(shape))
yield SampleInput(make_arg(shape, noncontiguous=True))

Expand Down Expand Up @@ -1591,3 +1593,37 @@ def div_input_generator(
denom = torch.where(denom_is_small, denom_scaled_to_minabs, denom).detach()
denom.requires_grad_(requires_grad)
yield SampleInput(numer, denom)


def triu_input_generator(op: OpInfo, dtype: torch.dtype, requires_grad: bool = False):
offsets = (0, 1, -1, 2, 3, -3, 1024, -1024)

for element in elementwise_unary_generator(
op,
dtype,
requires_grad,
enable_extremal_value_testing=False,
enable_large_value_testing=False,
enable_small_value_testing=False,
):
yield element
for offset in offsets:
yield SampleInput(*element.args, offset)


def triu_error_generator(op: OpInfo, dtype: torch.dtype, requires_grad: bool = False):
make_arg = partial(
make_tensor, device="cuda", dtype=dtype, requires_grad=requires_grad
)

invalid_shapes = (
(),
(4,),
)
yield SampleInput(
make_arg((4, 16)), 5.6
), RuntimeError, "offset must have type Int",
for shape in invalid_shapes:
yield SampleInput(
make_arg(shape),
), RuntimeError, f"input tensor for triu must have 2 or more dims, but got {len(shape)} dims",
15 changes: 15 additions & 0 deletions tests/python/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
matmul_input_generator,
linear_input_generator,
linear_error_generator,
triu_input_generator,
triu_error_generator,
)
from utils import (
bool_int_dtypes,
Expand Down Expand Up @@ -1218,6 +1220,18 @@ def torch_reshape_sym_fn(input_tensor, output_shaped_tensor):
)
linear_ops.append(linear_opinfo)

tv_val_ops = []

triu_opinfo = OpInfo(
lambda fd: fd.ops.triu,
"triu",
sample_input_generator=triu_input_generator,
error_input_generator=triu_error_generator,
reference=torch.triu,
)

tv_val_ops.append(triu_opinfo)

""" End Tensor Creation """

# Puts all opinfos into the "opinfos" list
Expand All @@ -1231,3 +1245,4 @@ def torch_reshape_sym_fn(input_tensor, output_shaped_tensor):
opinfos.extend(tensor_creation_ops)
opinfos.extend(matmul_ops)
opinfos.extend(linear_ops)
opinfos.extend(tv_val_ops)
15 changes: 15 additions & 0 deletions tests/python/test_python_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,6 +1204,21 @@ def fusion_func(fd: FusionDefinition):
self.assertEqual(eager_out2, nvf_out[1])
# self.assertEqual(eager_out3, nvf_out[2])

def test_triu(self):
inputs = [
torch.randn(4, 16, device="cuda", dtype=torch.float16),
]

def fusion_func(fd: FusionDefinition) -> None:
t0 = fd.from_pytorch(inputs[0])
tt = fd.define_scalar(-1, dtype=DataType.Int)
t1 = fd.ops.triu(t0, tt)
fd.add_output(t1)

nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
eager_out0 = torch.triu(inputs[0], -1)
self.assertEqual(eager_out0, nvf_out[0])

def test_complex_rsqrt(self):
inputs = [
torch.randn(4, device="cuda", dtype=torch.complex64),
Expand Down

0 comments on commit 130551d

Please sign in to comment.