Python bindings and tests for Triu #3637

Open · wants to merge 2 commits into base: pbasu_iota_experiment
csrc/ops/composite.cpp (6 additions, 2 deletions)

@@ -80,9 +80,9 @@ TensorView* triu(TensorView* tv, Val* offset) {

   NVF_CHECK(
       dims >= 2,
-      "triu is only supported for 2+D tensors, but got ",
+      "input tensor for triu must have 2 or more dims, but got ",
       dims,
-      "D tensor");
+      " dims");

   auto tv_rows = iota(
       tv_logical_no_reductions[dims - 2]->extent(),

@@ -108,6 +108,10 @@ TensorView* triu(TensorView* tv, Val* offset) {
   return where(mask, tv, IrBuilder::create<Val>(0, tv->dtype()));
 }

+TensorView* triu(TensorView* tv) {
+  return triu(tv, IrBuilder::create<Val>(0));
+}

 namespace {

 TensorView* newForLinear(
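For orientation: the composite triu builds an explicit coordinate mask rather than calling a dedicated kernel; iota produces row and column index tensors, a comparison forms the mask, and where selects between the input and zero. A rough eager-PyTorch analogue of that approach (illustrative sketch only; triu_via_where is not nvFuser API):

```python
import torch

def triu_via_where(t: torch.Tensor, offset: int = 0) -> torch.Tensor:
    # Index tensors along the last two dims, analogous to the iota() calls.
    rows = torch.arange(t.shape[-2], device=t.device).unsqueeze(-1)  # (rows, 1)
    cols = torch.arange(t.shape[-1], device=t.device).unsqueeze(0)   # (1, cols)
    # Keep entries on or above the offset diagonal (col - row >= offset).
    mask = cols - rows >= offset
    return torch.where(mask, t, torch.zeros_like(t))

t = torch.randn(4, 16)
assert torch.equal(triu_via_where(t, -1), torch.triu(t, -1))
```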
csrc/ops/composite.h (1 addition)

@@ -36,6 +36,7 @@ NVF_API TensorView* dropout_backward(
     Val* scale);

 NVF_API TensorView* triu(TensorView* tv, Val* offset);
+NVF_API TensorView* triu(TensorView* tv);

 struct LstmResult {
   TensorView* cell = nullptr;
csrc/python_frontend/python_bindings.cpp (32 additions)

@@ -1693,6 +1693,38 @@ void initNvFuserPythonBindings(PyObject* module) {
   NVFUSER_PYTHON_BINDING_UNARY_OP("imag", imag)
 #undef NVFUSER_PYTHON_BINDING_UNARY_OP

+  nvf_ops.def(
+      "triu",
+      [](FusionDefinition::Operators& self,
+         Tensor input,
+         std::optional<Scalar> offset) -> Tensor {
+        FUSER_PERF_SCOPE("Operators.triu");
+        NVF_CHECK(
+            self.validUse(), "Attempting to add to a completed definition!");
+        FusionDefinition* fd = self.fusion_definition;
+        Tensor output = fd->defineTensor(input.dims);
+        if (offset.has_value()) {
+          fd->defineRecord(new OpRecord<TensorView*, TensorView*, Val*>(
+              {fd->recordingState(input()),
+               fd->recordingState(offset.value()())},
+              {fd->recordingState(output())},
+              ("ops.triu"),
+              serde::RecordType::Binary_TV_VAL,
+              static_cast<TensorView* (*)(TensorView*, Val*)>(triu)));
+        } else {
+          fd->defineRecord(new OpRecord<TensorView*, TensorView*>(
+              {fd->recordingState(input())},
+              {fd->recordingState(output())},
+              ("ops.triu"),
+              serde::RecordType::Unary_TV,
+              static_cast<TensorView* (*)(TensorView*)>(triu)));
+        }
+        return output;
+      },
+      py::arg("input"),
+      py::arg("offset") = std::nullopt,
+      py::return_value_policy::reference);
+
   // overload to
   nvf_ops.def(
       "stride_order",
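With this binding, offset becomes an optional argument on the Python side: omitting it records the unary form, passing a scalar records the (tensor, value) form. A minimal usage sketch, assuming the standard nvfuser python-frontend entry points (FusionDefinition, fd.execute):

```python
import torch
from nvfuser import FusionDefinition

t = torch.randn(4, 16, device="cuda")

with FusionDefinition() as fd:
    t0 = fd.from_pytorch(t)
    out = fd.ops.triu(t0)  # no offset given: defaults to the main diagonal
    fd.add_output(out)

(result,) = fd.execute([t])
assert torch.equal(result, torch.triu(t))
```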
csrc/serde/fusion_record.cpp (9 additions)

@@ -653,6 +653,13 @@ void RecordFunctorFactory::setupFunctionMaps() {
       ("ops." op_str), static_cast<TensorView* (*)(TensorView*)>(op_name)); \
   unary_val.emplace(("ops." op_str), static_cast<Val* (*)(Val*)>(op_name));

+#define NVFUSER_UNARY_TV_ALPHA_OP(op_str, op_name)                          \
+  unary_tv.emplace(                                                         \
+      ("ops." op_str), static_cast<TensorView* (*)(TensorView*)>(op_name)); \
+  binary_tv_val.emplace(                                                    \
+      ("ops." op_str),                                                      \
+      static_cast<TensorView* (*)(TensorView*, Val*)>(op_name));

 #define NVFUSER_BINARY_TV_ONLY_OP(op_str, op_name) \
   binary_tv.emplace(                               \

@@ -808,6 +815,8 @@ void RecordFunctorFactory::setupFunctionMaps() {
 NVFUSER_UNARY_TV_OP("real", real)
 NVFUSER_UNARY_TV_OP("imag", imag)

+NVFUSER_UNARY_TV_ALPHA_OP("triu", triu)

 NVFUSER_BINARY_TV_ONLY_OP("matmul", matmul)
 NVFUSER_BINARY_TV_ONLY_OP("linear", linear)
 NVFUSER_TERNARY_TV_ONLY_OP("linear", linear)
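The new macro registers "ops.triu" in both function maps, so deserialization can resolve either form of a recorded triu: unary_tv covers Unary_TV records (no offset), while binary_tv_val covers Binary_TV_VAL records (explicit offset), matching the two RecordTypes emitted by the Python binding above.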
tests/python/opinfo_input_generators.py (38 additions)

@@ -1591,3 +1591,41 @@ def div_input_generator(
     denom = torch.where(denom_is_small, denom_scaled_to_minabs, denom).detach()
     denom.requires_grad_(requires_grad)
     yield SampleInput(numer, denom)
+
+
+def triu_input_generator(op: OpInfo, dtype: torch.dtype, requires_grad: bool = False):
+    offsets = (0, 1, -1, 2, 3, -3, 1024, -1024)
+
+    for element in elementwise_unary_generator(
+        op,
+        dtype,
+        requires_grad,
+        enable_extremal_value_testing=False,
+        enable_large_value_testing=False,
+        enable_small_value_testing=False,
+    ):
+        if element.args[0].ndim < 2:
+            continue
+        yield element
+        for offset in offsets:
+            yield SampleInput(*element.args, offset)
+
+
+def triu_error_generator(op: OpInfo, dtype: torch.dtype, requires_grad: bool = False):
+    make_arg = partial(
+        make_tensor, device="cuda", dtype=dtype, requires_grad=requires_grad
+    )
+
+    yield SampleInput(
+        make_arg((4, 16)), 5.6
+    ), RuntimeError, "offset must have integral type"
+
+    invalid_shapes = (
+        (),
+        (4,),
+    )
+
+    for shape in invalid_shapes:
+        yield SampleInput(
+            make_arg(shape),
+        ), RuntimeError, f"input tensor for triu must have 2 or more dims, but got {len(shape)} dims"
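For context on the extreme offsets exercised above (1024 and -1024 on much smaller tensors): torch.triu accepts any integer diagonal and simply saturates once the offset leaves the tensor's bounds, so these samples pin down the edge behavior the nvFuser implementation must match. A quick eager check:

```python
import torch

t = torch.arange(12.0).reshape(3, 4)
assert torch.count_nonzero(torch.triu(t, 1024)) == 0  # offset past the last column: all zeros
assert torch.equal(torch.triu(t, -1024), t)           # offset below the first row: input unchanged
```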
tests/python/opinfos.py (15 additions)

@@ -50,6 +50,8 @@
     matmul_input_generator,
     linear_input_generator,
     linear_error_generator,
+    triu_input_generator,
+    triu_error_generator,
 )
 from utils import (
     bool_int_dtypes,

@@ -1218,6 +1220,18 @@ def torch_reshape_sym_fn(input_tensor, output_shaped_tensor):
 )
 linear_ops.append(linear_opinfo)

+tv_val_ops = []
+
+triu_opinfo = OpInfo(
+    lambda fd: fd.ops.triu,
+    "triu",
+    sample_input_generator=triu_input_generator,
+    error_input_generator=triu_error_generator,
+    reference=torch.triu,
+)
+
+tv_val_ops.append(triu_opinfo)
+
 """ End Tensor Creation """

 # Puts all opinfos into the "opinfos" list

@@ -1231,3 +1245,4 @@ def torch_reshape_sym_fn(input_tensor, output_shaped_tensor):
 opinfos.extend(tensor_creation_ops)
 opinfos.extend(matmul_ops)
 opinfos.extend(linear_ops)
+opinfos.extend(tv_val_ops)
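Appending triu_opinfo to the global opinfos list is what hooks triu into the generic opinfo-driven tests: they should then use triu_input_generator to compare fusion outputs against the torch.triu reference, and triu_error_generator to assert the expected RuntimeError messages.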
tests/python/test_python_frontend.py (15 additions)

@@ -1204,6 +1204,21 @@ def fusion_func(fd: FusionDefinition):
         self.assertEqual(eager_out2, nvf_out[1])
         # self.assertEqual(eager_out3, nvf_out[2])

+    def test_triu(self):
+        inputs = [
+            torch.randn(4, 16, device="cuda", dtype=torch.float16),
+        ]
+
+        def fusion_func(fd: FusionDefinition) -> None:
+            t0 = fd.from_pytorch(inputs[0])
+            tt = fd.define_scalar(-1, dtype=DataType.Int)
+            t1 = fd.ops.triu(t0, tt)
+            fd.add_output(t1)
+
+        nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
+        eager_out0 = torch.triu(inputs[0], -1)
+        self.assertEqual(eager_out0, nvf_out[0])
+
     def test_complex_rsqrt(self):
         inputs = [
             torch.randn(4, device="cuda", dtype=torch.complex64),