@@ -64,3 +64,45 @@ func.func @main3(%arg0: tensor<2x2xcomplex<f64>> {enzymexla.memory_effects = []}
6464// CHECK-NEXT: %8 = chlo.conj %7 : tensor<2x2xcomplex<f64>> -> tensor<2x2xcomplex<f64>>
6565// CHECK-NEXT: return %8, %4, %0 : tensor<2x2xcomplex<f64>>, tensor<2x2xcomplex<f64>>, tensor<2x2xcomplex<f64>>
6666// CHECK-NEXT: }
67+
68+ func.func @main4 (%arg0: tensor <2 x16 xf32 > {enzymexla.memory_effects = []}, %arg1: tensor <16 xf32 > {enzymexla.memory_effects = []}, %arg2: tensor <16 x16 xf32 > {enzymexla.memory_effects = []}, %arg3: tensor <16 xf32 > {enzymexla.memory_effects = []}, %arg4: tensor <16 x1 xf32 > {enzymexla.memory_effects = []}, %arg5: tensor <1 xf32 > {enzymexla.memory_effects = []}, %arg6: tensor <2 xf32 > {enzymexla.memory_effects = []}) -> tensor <2 xf32 > attributes {enzymexla.memory_effects = []} {
69+ %cst = stablehlo.constant dense <1.000000e+00 > : tensor <1 xf32 >
70+ %cst_0 = stablehlo.constant dense <1.000000e+00 > : tensor <16 xf32 >
71+ %0 = stablehlo.reshape %arg4 : (tensor <16 x1 xf32 >) -> tensor <1 x16 xf32 >
72+ %1 = stablehlo.dot_general %arg0 , %arg6 , contracting_dims = [0 ] x [0 ], precision = [DEFAULT , DEFAULT ] : (tensor <2 x16 xf32 >, tensor <2 xf32 >) -> tensor <16 xf32 >
73+ %2 = stablehlo.add %1 , %arg1 : tensor <16 xf32 >
74+ %3 = stablehlo.tanh %2 : tensor <16 xf32 >
75+ %4 = stablehlo.dot_general %arg2 , %3 , contracting_dims = [0 ] x [0 ], precision = [DEFAULT , DEFAULT ] : (tensor <16 x16 xf32 >, tensor <16 xf32 >) -> tensor <16 xf32 >
76+ %5 = stablehlo.add %4 , %arg3 : tensor <16 xf32 >
77+ %6 = stablehlo.dot_general %cst , %0 , contracting_dims = [0 ] x [0 ], precision = [DEFAULT , DEFAULT ] : (tensor <1 xf32 >, tensor <1 x16 xf32 >) -> tensor <16 xf32 >
78+ %7 = stablehlo.tanh %5 : tensor <16 xf32 >
79+ %8 = stablehlo.multiply %7 , %7 : tensor <16 xf32 >
80+ %9 = stablehlo.subtract %cst_0 , %8 : tensor <16 xf32 >
81+ %10 = stablehlo.multiply %6 , %9 : tensor <16 xf32 >
82+ %11 = stablehlo.dot_general %10 , %arg2 , contracting_dims = [0 ] x [1 ], precision = [DEFAULT , DEFAULT ] : (tensor <16 xf32 >, tensor <16 x16 xf32 >) -> tensor <16 xf32 >
83+ %12 = stablehlo.multiply %3 , %3 : tensor <16 xf32 >
84+ %13 = stablehlo.subtract %cst_0 , %12 : tensor <16 xf32 >
85+ %14 = stablehlo.multiply %11 , %13 : tensor <16 xf32 >
86+ %15 = stablehlo.dot_general %14 , %arg0 , contracting_dims = [0 ] x [1 ], precision = [DEFAULT , DEFAULT ] : (tensor <16 xf32 >, tensor <2 x16 xf32 >) -> tensor <2 xf32 >
87+ return %15 : tensor <2 xf32 >
88+ }
89+
90+ // CHECK: func.func @main4(%arg0: tensor<2x16xf32> {enzymexla.memory_effects = []}, %arg1: tensor<16xf32> {enzymexla.memory_effects = []}, %arg2: tensor<16x16xf32> {enzymexla.memory_effects = []}, %arg3: tensor<16xf32> {enzymexla.memory_effects = []}, %arg4: tensor<16x1xf32> {enzymexla.memory_effects = []}, %arg5: tensor<1xf32> {enzymexla.memory_effects = []}, %arg6: tensor<2xf32> {enzymexla.memory_effects = []}) -> tensor<2xf32> attributes {enzymexla.memory_effects = []} {
91+ // CHECK-NEXT: %cst = stablehlo.constant dense<1.000000e+00> : tensor<16xf32>
92+ // CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg6, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x16xf32>, tensor<2xf32>) -> tensor<16xf32>
93+ // CHECK-NEXT: %1 = stablehlo.add %0, %arg1 : tensor<16xf32>
94+ // CHECK-NEXT: %2 = stablehlo.tanh %1 : tensor<16xf32>
95+ // CHECK-NEXT: %3 = stablehlo.dot_general %arg2, %2, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x16xf32>, tensor<16xf32>) -> tensor<16xf32>
96+ // CHECK-NEXT: %4 = stablehlo.add %3, %arg3 : tensor<16xf32>
97+ // CHECK-NEXT: %5 = stablehlo.reshape %arg4 : (tensor<16x1xf32>) -> tensor<16xf32>
98+ // CHECK-NEXT: %6 = stablehlo.tanh %4 : tensor<16xf32>
99+ // CHECK-NEXT: %7 = stablehlo.multiply %6, %6 : tensor<16xf32>
100+ // CHECK-NEXT: %8 = stablehlo.subtract %cst, %7 : tensor<16xf32>
101+ // CHECK-NEXT: %9 = stablehlo.multiply %5, %8 : tensor<16xf32>
102+ // CHECK-NEXT: %10 = stablehlo.dot_general %9, %arg2, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<16xf32>, tensor<16x16xf32>) -> tensor<16xf32>
103+ // CHECK-NEXT: %11 = stablehlo.multiply %2, %2 : tensor<16xf32>
104+ // CHECK-NEXT: %12 = stablehlo.subtract %cst, %11 : tensor<16xf32>
105+ // CHECK-NEXT: %13 = stablehlo.multiply %10, %12 : tensor<16xf32>
106+ // CHECK-NEXT: %14 = stablehlo.dot_general %13, %arg0, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<16xf32>, tensor<2x16xf32>) -> tensor<2xf32>
107+ // CHECK-NEXT: return %14 : tensor<2xf32>
108+ // CHECK-NEXT: }
0 commit comments