@@ -638,6 +638,8 @@ def elementwise_unary_generator(
638
638
639
639
# Typical inputs
640
640
for shape in shapes :
641
+ if op .name == "triu" and len (shape ) < 2 :
642
+ continue
641
643
yield SampleInput (make_arg (shape ))
642
644
yield SampleInput (make_arg (shape , noncontiguous = True ))
643
645
@@ -1591,3 +1593,37 @@ def div_input_generator(
1591
1593
denom = torch .where (denom_is_small , denom_scaled_to_minabs , denom ).detach ()
1592
1594
denom .requires_grad_ (requires_grad )
1593
1595
yield SampleInput (numer , denom )
1596
+
1597
+
1598
+ def triu_input_generator (op : OpInfo , dtype : torch .dtype , requires_grad : bool = False ):
1599
+ offsets = (0 , 1 , - 1 , 2 , 3 , - 3 , 1024 , - 1024 )
1600
+
1601
+ for element in elementwise_unary_generator (
1602
+ op ,
1603
+ dtype ,
1604
+ requires_grad ,
1605
+ enable_extremal_value_testing = False ,
1606
+ enable_large_value_testing = False ,
1607
+ enable_small_value_testing = False ,
1608
+ ):
1609
+ yield element
1610
+ for offset in offsets :
1611
+ yield SampleInput (* element .args , offset )
1612
+
1613
+
1614
+ def triu_error_generator (op : OpInfo , dtype : torch .dtype , requires_grad : bool = False ):
1615
+ make_arg = partial (
1616
+ make_tensor , device = "cuda" , dtype = dtype , requires_grad = requires_grad
1617
+ )
1618
+
1619
+ invalid_shapes = (
1620
+ (),
1621
+ (4 ,),
1622
+ )
1623
+ yield SampleInput (
1624
+ make_arg ((4 , 16 )), 5.6
1625
+ ), RuntimeError , "offset must have type Int" ,
1626
+ for shape in invalid_shapes :
1627
+ yield SampleInput (
1628
+ make_arg (shape ),
1629
+ ), RuntimeError , f"input tensor for triu must have 2 or more dims, but got { len (shape )} dims" ,
0 commit comments