Skip to content

Commit 7f8be5d

Browse files
committed
test: elementwise loop fission
1 parent 5714270 commit 7f8be5d

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// RUN: enzymexlamlir-opt --enzyme-hlo-opt --auto-batching --enzyme-hlo-opt %s | FileCheck %s
2+
3+
func.func @main(%arg0: tensor<10xf64>) -> tensor<10xf64> {
4+
%c = stablehlo.constant dense<1> : tensor<i32>
5+
%c_0 = stablehlo.constant dense<0> : tensor<i64>
6+
%c_1 = stablehlo.constant dense<10> : tensor<i64>
7+
%c_2 = stablehlo.constant dense<1> : tensor<i64>
8+
%cst = stablehlo.constant dense<0.000000e+00> : tensor<10xf64>
9+
%0:2 = stablehlo.while(%iterArg = %c_0, %iterArg_3 = %cst) : tensor<i64>, tensor<10xf64> attributes {enzymexla.disable_min_cut}
10+
cond {
11+
%1 = stablehlo.compare LT, %iterArg, %c_1 : (tensor<i64>, tensor<i64>) -> tensor<i1>
12+
stablehlo.return %1 : tensor<i1>
13+
} do {
14+
%1 = stablehlo.add %c_2, %iterArg : tensor<i64>
15+
%2 = stablehlo.convert %1 : (tensor<i64>) -> tensor<i32>
16+
%3 = stablehlo.subtract %2, %c : tensor<i32>
17+
%4 = stablehlo.dynamic_slice %arg0, %3, sizes = [1] : (tensor<10xf64>, tensor<i32>) -> tensor<1xf64>
18+
19+
%sin_res = stablehlo.sine %4 : tensor<1xf64>
20+
%neg_res = stablehlo.negate %sin_res : tensor<1xf64>
21+
%cos_res = stablehlo.cosine %4 : tensor<1xf64>
22+
%5 = stablehlo.add %neg_res, %cos_res : tensor<1xf64>
23+
24+
%6 = stablehlo.remainder %iterArg, %c_1 : tensor<i64>
25+
%7 = stablehlo.add %6, %c_2 : tensor<i64>
26+
%8 = stablehlo.convert %7 : (tensor<i64>) -> tensor<i32>
27+
%9 = stablehlo.subtract %8, %c : tensor<i32>
28+
%10 = stablehlo.dynamic_update_slice %iterArg_3, %5, %9 : (tensor<10xf64>, tensor<1xf64>, tensor<i32>) -> tensor<10xf64>
29+
30+
stablehlo.return %1, %10 : tensor<i64>, tensor<10xf64>
31+
}
32+
return %0#1 : tensor<10xf64>
33+
}
34+
35+
// CHECK: func.func @main(%arg0: tensor<10xf64>) -> tensor<10xf64> {
36+
// CHECK-NEXT: %0 = stablehlo.sine %arg0 : tensor<10xf64>
37+
// CHECK-NEXT: %1 = stablehlo.cosine %arg0 : tensor<10xf64>
38+
// CHECK-NEXT: %2 = stablehlo.subtract %1, %0 : tensor<10xf64>
39+
// CHECK-NEXT: return %2 : tensor<10xf64>
40+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)