|
| 1 | +// RUN: enzymemlir-opt --split-input-file --hoist-enzyme-regions %s | FileCheck %s |
| 2 | +// CHECK-LABEL: func.func @foo |
| 3 | +// CHECK-SAME: (%arg0: f64, %arg1: f64, %arg2: f64) -> f64 |
| 4 | +// CHECK: %c10 = arith.constant 10 : index |
| 5 | +// CHECK: %c1 = arith.constant 1 : index |
| 6 | +// CHECK: %cst = arith.constant 2.500000e+00 : f64 |
| 7 | +// CHECK: %cst_0 = arith.constant 2.000000e+00 : f64 |
| 8 | +// CHECK: %cst_1 = arith.constant 0.000000e+00 : f64 |
| 9 | +// CHECK: %cst_2 = arith.constant 1.000000e+02 : f64 |
| 10 | +// CHECK: %0 = arith.mulf %arg2, %cst_0 : f64 |
| 11 | +// CHECK: %1 = scf.for %{{.*}} = %c1 to %c10 step %c1 iter_args(%{{.*}} = %cst) -> (f64) { |
| 12 | +// CHECK: %{{.*}} = arith.mulf %{{.*}}, %cst_2 : f64 |
| 13 | +// CHECK: %{{.*}} = scf.for %{{.*}} = %c1 to %c10 step %c1 iter_args(%{{.*}} = %{{.*}}) -> (f64) { |
| 14 | +// CHECK: %{{.*}} = arith.addf %{{.*}}, %0 : f64 |
| 15 | +// CHECK: scf.yield %{{.*}} : f64 |
| 16 | +// CHECK: } |
| 17 | +// CHECK: scf.yield %{{.*}} : f64 |
| 18 | +// CHECK: } |
| 19 | +// CHECK: %2 = enzyme.autodiff_region(%arg0, %arg1) { |
| 20 | +// CHECK: ^bb0(%{{.*}}: f64): |
| 21 | +// CHECK: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f64 |
| 22 | +// CHECK: %{{.*}} = arith.mulf %{{.*}}, %0 : f64 |
| 23 | +// CHECK: %{{.*}} = arith.mulf %{{.*}}, %1 : f64 |
| 24 | +// CHECK: %{{.*}} = arith.addf %{{.*}}, %cst_1 : f64 |
| 25 | +// CHECK: %{{.*}} = scf.for %{{.*}} = %c1 to %c10 step %c1 iter_args(%{{.*}} = %{{.*}}) -> (f64) { |
| 26 | +// CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f64 |
| 27 | +// CHECK: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f64 |
| 28 | +// CHECK: scf.yield %{{.*}} : f64 |
| 29 | +// CHECK: } |
| 30 | +// CHECK: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f64 |
| 31 | +// CHECK: enzyme.yield %{{.*}} : f64 |
| 32 | +// CHECK: } attributes {{.*}} : (f64, f64) -> f64 |
| 33 | + |
| 34 | +func.func @foo(%arg0: f64, %arg1: f64,%xx: f64) -> f64 { |
| 35 | + |
| 36 | + %yy_cst = arith.constant 100.0 : f64 |
| 37 | + %0 = enzyme.autodiff_region(%arg0, %arg1) { |
| 38 | + ^bb0(%arg2: f64): |
| 39 | + // hoistable constant ops |
| 40 | + %c0 = arith.constant 0.0 : f64 |
| 41 | + %c1 = arith.constant 1.0 : f64 |
| 42 | + %c2 = arith.constant 2.0 : f64 |
| 43 | + %cx = arith.mulf %c2, %xx : f64 |
| 44 | + |
| 45 | + %sq = arith.mulf %arg2, %arg2 : f64 |
| 46 | + %sqx = arith.mulf %sq, %cx : f64 |
| 47 | + |
| 48 | + // hoistable loops |
| 49 | + %yy0 = arith.constant 2.5 : f64 |
| 50 | + %one = arith.constant 1 : index |
| 51 | + %ten = arith.constant 10 : index |
| 52 | + %yy = scf.for %iv = %one to %ten step %one iter_args(%yy_iter = %yy0) -> (f64) { |
| 53 | + %tm = arith.mulf %yy_iter, %yy_cst : f64 |
| 54 | + %ta = scf.for %jv = %one to %ten step %one iter_args(%tm_iter = %tm) -> (f64) { |
| 55 | + %ta = arith.addf %tm, %cx : f64 |
| 56 | + scf.yield %ta : f64 |
| 57 | + } |
| 58 | + scf.yield %ta : f64 |
| 59 | + } |
| 60 | + |
| 61 | + %sqxy = arith.mulf %sqx, %yy : f64 |
| 62 | + %zz0 = arith.addf %sqx, %c0 : f64 |
| 63 | + %zz = scf.for %iv = %one to %ten step %one iter_args(%zz_iter = %zz0) ->(f64) { |
| 64 | + %zm = arith.addf %zz_iter, %sqx : f64 |
| 65 | + %zout = arith.mulf %zm, %zz_iter : f64 |
| 66 | + scf.yield %zout : f64 |
| 67 | + } |
| 68 | + |
| 69 | + %sqxyz = arith.mulf %zz, %sqxy : f64 |
| 70 | + enzyme.yield %sqxyz : f64 |
| 71 | + } attributes {activity = [#enzyme<activity enzyme_active>], ret_activity = [#enzyme<activity enzyme_activenoneed>]} : (f64, f64) -> f64 |
| 72 | + return %0 : f64 |
| 73 | +} |
0 commit comments