Skip to content

Commit a3ab248

Browse files
committed
adding python bindings for triu
1 parent 5f82c65 commit a3ab248

File tree

7 files changed

+114
-2
lines changed

csrc/ops/arith.cpp

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -500,9 +500,9 @@ TensorView* triu(TensorView* tv, Val* offset) {
500500

501501
NVF_CHECK(
502502
tv->nDims() >= 2,
503-
"triu is only supported for 2+D tensors, but got ",
503+
"input tensor for triu must have 2 or more dims, but got ",
504504
tv->nDims(),
505-
"D tensor");
505+
" dims");
506506

507507
// Let's say we want a triu of a 2D tensor of shape [2, 4]
508508
// We broadcast the iota of the outer dim
@@ -539,6 +539,10 @@ TensorView* triu(TensorView* tv, Val* offset) {
539539
return where(mask, tv, IrBuilder::create<Val>(0, tv->dtype()));
540540
}
541541

542+
TensorView* triu(TensorView* tv) {
543+
return triu(tv, IrBuilder::create<Val>(0));
544+
}
545+
542546
// UNARY OPERATIONS
543547

544548
#define NVFUSER_DEFINE_UNARY_OP(operator_name, operator_type) \

csrc/ops/arith.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -252,6 +252,7 @@ NVF_API TensorView* arange(
252252
NVF_API TensorView* eye(Val* size, DataType dtype);
253253
NVF_API TensorView* eye(Val* rows, Val* cols, DataType dtype);
254254
NVF_API TensorView* triu(TensorView* tv, Val* offset);
255+
NVF_API TensorView* triu(TensorView* tv);
255256

256257
// UNARY OPERATIONS
257258
// abs

csrc/python_frontend/python_bindings.cpp

Lines changed: 32 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1693,6 +1693,38 @@ void initNvFuserPythonBindings(PyObject* module) {
16931693
NVFUSER_PYTHON_BINDING_UNARY_OP("imag", imag)
16941694
#undef NVFUSER_PYTHON_BINDING_UNARY_OP
16951695

1696+
nvf_ops.def(
1697+
"triu",
1698+
[](FusionDefinition::Operators& self,
1699+
Tensor input,
1700+
std::optional<Scalar> offset) -> Tensor {
1701+
FUSER_PERF_SCOPE("Operators.triu");
1702+
NVF_CHECK(
1703+
self.validUse(), "Attempting to add to a completed definition!");
1704+
FusionDefinition* fd = self.fusion_definition;
1705+
Tensor output = fd->defineTensor(input.dims);
1706+
if (offset.has_value()) {
1707+
fd->defineRecord(new OpRecord<TensorView*, TensorView*, Val*>(
1708+
{fd->recordingState(input()),
1709+
fd->recordingState(offset.value()())},
1710+
{fd->recordingState(output())},
1711+
("ops.triu"),
1712+
serde::RecordType::Binary_TV_VAL,
1713+
static_cast<TensorView* (*)(TensorView*, Val*)>(triu)));
1714+
} else {
1715+
fd->defineRecord(new OpRecord<TensorView*, TensorView*>(
1716+
{fd->recordingState(input())},
1717+
{fd->recordingState(output())},
1718+
("ops.triu"),
1719+
serde::RecordType::Unary_TV,
1720+
static_cast<TensorView* (*)(TensorView*)>(triu)));
1721+
}
1722+
return output;
1723+
},
1724+
py::arg("input"),
1725+
py::arg("offset") = std::nullopt,
1726+
py::return_value_policy::reference);
1727+
16961728
// overload to
16971729
nvf_ops.def(
16981730
"stride_order",

csrc/serde/fusion_record.cpp

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -653,6 +653,13 @@ void RecordFunctorFactory::setupFunctionMaps() {
653653
("ops." op_str), static_cast<TensorView* (*)(TensorView*)>(op_name)); \
654654
unary_val.emplace(("ops." op_str), static_cast<Val* (*)(Val*)>(op_name));
655655

656+
#define NVFUSER_UNARY_TV_ALPHA_OP(op_str, op_name) \
657+
unary_tv.emplace( \
658+
("ops." op_str), static_cast<TensorView* (*)(TensorView*)>(op_name)); \
659+
binary_tv_val.emplace( \
660+
("ops." op_str), \
661+
static_cast<TensorView* (*)(TensorView*, Val*)>(op_name));
662+
656663
#define NVFUSER_BINARY_TV_ONLY_OP(op_str, op_name) \
657664
binary_tv.emplace( \
658665
("ops." op_str), \
@@ -808,6 +815,8 @@ void RecordFunctorFactory::setupFunctionMaps() {
808815
NVFUSER_UNARY_TV_OP("real", real)
809816
NVFUSER_UNARY_TV_OP("imag", imag)
810817

818+
NVFUSER_UNARY_TV_ALPHA_OP("triu", triu)
819+
811820
NVFUSER_BINARY_TV_ONLY_OP("matmul", matmul)
812821
NVFUSER_BINARY_TV_ONLY_OP("linear", linear)
813822
NVFUSER_TERNARY_TV_ONLY_OP("linear", linear)

tests/python/opinfo_input_generators.py

Lines changed: 36 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -638,6 +638,8 @@ def elementwise_unary_generator(
638638

639639
# Typical inputs
640640
for shape in shapes:
641+
if op.name == "triu" and len(shape) < 2:
642+
continue
641643
yield SampleInput(make_arg(shape))
642644
yield SampleInput(make_arg(shape, noncontiguous=True))
643645

@@ -1591,3 +1593,37 @@ def div_input_generator(
15911593
denom = torch.where(denom_is_small, denom_scaled_to_minabs, denom).detach()
15921594
denom.requires_grad_(requires_grad)
15931595
yield SampleInput(numer, denom)
1596+
1597+
1598+
def triu_input_generator(op: OpInfo, dtype: torch.dtype, requires_grad: bool = False):
1599+
offsets = (0, 1, -1, 2, 3, -3, 1024, -1024)
1600+
1601+
for element in elementwise_unary_generator(
1602+
op,
1603+
dtype,
1604+
requires_grad,
1605+
enable_extremal_value_testing=False,
1606+
enable_large_value_testing=False,
1607+
enable_small_value_testing=False,
1608+
):
1609+
yield element
1610+
for offset in offsets:
1611+
yield SampleInput(*element.args, offset)
1612+
1613+
1614+
def triu_error_generator(op: OpInfo, dtype: torch.dtype, requires_grad: bool = False):
1615+
make_arg = partial(
1616+
make_tensor, device="cuda", dtype=dtype, requires_grad=requires_grad
1617+
)
1618+
1619+
invalid_shapes = (
1620+
(),
1621+
(4,),
1622+
)
1623+
yield SampleInput(
1624+
make_arg((4, 16)), 5.6
1625+
), RuntimeError, "offset must have type Int",
1626+
for shape in invalid_shapes:
1627+
yield SampleInput(
1628+
make_arg(shape),
1629+
), RuntimeError, f"input tensor for triu must have 2 or more dims, but got {len(shape)} dims",

tests/python/opinfos.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -50,6 +50,8 @@
5050
matmul_input_generator,
5151
linear_input_generator,
5252
linear_error_generator,
53+
triu_input_generator,
54+
triu_error_generator,
5355
)
5456
from utils import (
5557
bool_int_dtypes,
@@ -1218,6 +1220,18 @@ def torch_reshape_sym_fn(input_tensor, output_shaped_tensor):
12181220
)
12191221
linear_ops.append(linear_opinfo)
12201222

1223+
tv_val_ops = []
1224+
1225+
triu_opinfo = OpInfo(
1226+
lambda fd: fd.ops.triu,
1227+
"triu",
1228+
sample_input_generator=triu_input_generator,
1229+
error_input_generator=triu_error_generator,
1230+
reference=torch.triu,
1231+
)
1232+
1233+
tv_val_ops.append(triu_opinfo)
1234+
12211235
""" End Tensor Creation """
12221236

12231237
# Puts all opinfos into the "opinfos" list
@@ -1231,3 +1245,4 @@ def torch_reshape_sym_fn(input_tensor, output_shaped_tensor):
12311245
opinfos.extend(tensor_creation_ops)
12321246
opinfos.extend(matmul_ops)
12331247
opinfos.extend(linear_ops)
1248+
opinfos.extend(tv_val_ops)

tests/python/test_python_frontend.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1204,6 +1204,21 @@ def fusion_func(fd: FusionDefinition):
12041204
self.assertEqual(eager_out2, nvf_out[1])
12051205
# self.assertEqual(eager_out3, nvf_out[2])
12061206

1207+
def test_triu(self):
1208+
inputs = [
1209+
torch.randn(4, 16, device="cuda", dtype=torch.float16),
1210+
]
1211+
1212+
def fusion_func(fd: FusionDefinition) -> None:
1213+
t0 = fd.from_pytorch(inputs[0])
1214+
tt = fd.define_scalar(-1, dtype=DataType.Int)
1215+
t1 = fd.ops.triu(t0, tt)
1216+
fd.add_output(t1)
1217+
1218+
nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
1219+
eager_out0 = torch.triu(inputs[0], -1)
1220+
self.assertEqual(eager_out0, nvf_out[0])
1221+
12071222
def test_complex_rsqrt(self):
12081223
inputs = [
12091224
torch.randn(4, device="cuda", dtype=torch.complex64),

0 commit comments

Comments (0)