// DEFINE: %{compile} = mlir-opt %s \
// DEFINE:   -transform-interpreter -test-transform-dialect-erase-schedule |\
// DEFINE: mlir-opt \
// DEFINE:   -test-lower-to-llvm -o %t
// DEFINE: %{entry_point} = main
// DEFINE: %{run} = mlir-cpu-runner %t -e %{entry_point} -entry-point-result=void \
// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils

// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s

/// End-to-end test for tensor.unpack where one of the inner tile sizes is
/// dynamic. See pack-dynamic-inner-tile.mlir for a similar test for tensor.pack.
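///
/// Conceptually (a sketch of the semantics, not IR that appears verbatim in
/// this file), the tensor.unpack below computes:
///
///   out[i, j] = packed[i floordiv 8, j, i mod 8, 0]
///
/// i.e. it re-assembles the [8, 1]-tiled data back into a 7x3 tensor and
/// drops the elements (123) that were added as padding when packing.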

func.func @main() {
  // Allocate and initialise the inputs
  %A_alloc = tensor.empty() : tensor<7x3xi32>

  %A = arith.constant dense<[
    [[[1],
      [2],
      [3],
      [4],
      [5],
      [6],
      [7],
      [123]],
     [[8],
      [9],
      [10],
      [11],
      [12],
      [13],
      [14],
      [123]],
     [[15],
      [16],
      [17],
      [18],
      [19],
      [20],
      [21],
      [123]]]
  ]> : tensor<1x3x8x1xi32>

  %A_cast = tensor.cast %A : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
  func.call @unpack(%A_cast) : (tensor<?x3x?x1xi32>) -> ()

  return
}

func.func private @unpack(%A: tensor<?x3x?x1xi32>) {
  %c1 = arith.constant 1 : index
  %pad_val = arith.constant 123 : i32

  // Dynamic tile size
  %tile_size = arith.constant 8 : index
  %A_unpack_empty = tensor.empty() : tensor<7x3xi32>

  %A_unpack = tensor.unpack %A
    inner_dims_pos = [0, 1]
    inner_tiles = [%tile_size, 1]
    into %A_unpack_empty : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
  %A_cast = tensor.cast %A_unpack : tensor<7x3xi32> to tensor<*xi32>

  // Print the results
  // CHECK: Unranked Memref base@ = 0x{{.*}} rank = 2 offset = 0 sizes = [7, 3] strides = [3, 1] data =
  // CHECK-NEXT: [1, 8, 15],
  // CHECK-NEXT: [2, 9, 16],
  // CHECK-NEXT: [3, 10, 17],
  // CHECK-NEXT: [4, 11, 18],
  // CHECK-NEXT: [5, 12, 19],
  // CHECK-NEXT: [6, 13, 20],
  // CHECK-NEXT: [7, 14, 21]
  call @printMemrefI32(%A_cast) : (tensor<*xi32>) -> ()

  return
}

module @transforms attributes { transform.with_named_sequence } {
  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.consume}) {
    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module : (!transform.any_op) -> !transform.any_op

    // 1. Tile so that the tensor.unpack op can be decomposed (see step 2).
    %c8 = transform.param.constant 8 : i64 -> !transform.param<i64>
    %tiled_unpack_op_p, %loops:2 = transform.structured.tile_using_for %unpack tile_sizes [%c8, 1]
       : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op, !transform.any_op)
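
    // After this step the unpack op is wrapped in two scf.for loops, one per
    // tiled dimension. A rough sketch of the tiled IR (an assumption for
    // illustration; the actual output differs in details such as iter_args
    // and slice sizes):
    //
    //   scf.for %i = 0 to 7 step 8 {
    //     scf.for %j = 0 to 3 step 1 {
    //       %in_tile  = tensor.extract_slice %A ...
    //       %out_tile = tensor.unpack %in_tile ...   // unpacks a single tile
    //       ... tensor.insert_slice %out_tile ...
    //     }
    //   }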

    // 2. Decompose the tiled unpack op into tensor.extract_slice + tensor.insert_slice:
    %func_op = transform.get_parent_op %tiled_unpack_op_p {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op {
      transform.apply_patterns.linalg.decompose_pack_unpack
      transform.apply_patterns.linalg.decompose_pad
    } : !transform.op<"func.func">
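
    // Roughly (a hedged sketch, not the exact produced IR), each tiled unpack
    // op, whose outer dimensions are all 1 after tiling, is rewritten into
    // plain slice manipulation, e.g.:
    //
    //   %payload = tensor.extract_slice %in_tile ...   // drop the unit outer dims
    //   %valid   = tensor.extract_slice %payload ...   // crop away the padding
    //
    // with the surrounding loop nest inserting %valid into the 7x3 result.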

    // 3. Bufferize before lowering to LLVM
    %bufferize = transform.bufferization.one_shot_bufferize %module
      {bufferize_function_boundaries=true} : (!transform.any_op) -> !transform.any_op
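    // Note: bufferize_function_boundaries=true also converts function
    // arguments and results from tensors to memrefs, so the whole module can
    // later be lowered by -test-lower-to-llvm.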

    // 4. Canonicalize
    %func_op_bufferized = transform.structured.match ops{["func.func"]} in %bufferize : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op_bufferized {
      transform.apply_patterns.canonicalization
    } : !transform.op<"func.func">

    transform.yield
  }
}

func.func private @printMemrefI32(%ptr : tensor<*xi32>)