Include values used in regions
Signed-off-by: Ian Wood <[email protected]>
IanWood1 committed Oct 22, 2024
1 parent 89d8862 commit 6dbb655
Showing 3 changed files with 83 additions and 22 deletions.
35 changes: 19 additions & 16 deletions compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
@@ -101,11 +101,24 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand,
  return true;
}

static bool usesValuesDefinedAbove(Operation *op) {
  bool usesValuesFromAbove = false;
  mlir::visitUsedValuesDefinedAbove(
      op->getRegions(), [&](void *) { usesValuesFromAbove = true; });
  return usesValuesFromAbove;
void getBackwardSliceIncludingUsesFromAbove(
    Operation *op, SetVector<Operation *> *backwardSlice,
    const BackwardSliceOptions &options) {
  BackwardSliceOptions wrappedOptions(options);
  wrappedOptions.filter = [&](Operation *op) {
    if (!options.filter(op)) {
      return false;
    }
    BackwardSliceOptions regionOptions(wrappedOptions);
    regionOptions.inclusive = true;
    mlir::visitUsedValuesDefinedAbove(
        op->getRegions(), [&](OpOperand *operand) {
          getBackwardSlice(operand->get(), backwardSlice, regionOptions);
        });
    return true;
  };

  getBackwardSlice(op, backwardSlice, wrappedOptions);
}

bool isHorizontalToGroup(Operation *op, ArrayRef<Operation *> currGroup,
@@ -114,23 +114,13 @@ bool isHorizontalToGroup(Operation *op, ArrayRef<Operation *> currGroup,
  assert(dominanceInfo.properlyDominates(seedOp, op) &&
         op->getParentRegion() == seedOp->getParentRegion());

  if (usesValuesDefinedAbove(op)) {
    return false;
  }

  BackwardSliceOptions options;
  // Limit the slice to the seed to make sure the slice is small.
  options.filter = [&](Operation *op) {
    return !dominanceInfo.properlyDominates(op, seedOp);
  };
  llvm::SetVector<Operation *> slice;
  getBackwardSlice(op, &slice, options);

  // `getBackwardSlice` doesn't track uses from within an op's region, so make
  // sure there are no values defined above.
  if (llvm::any_of(slice, usesValuesDefinedAbove)) {
    return false;
  }
  getBackwardSliceIncludingUsesFromAbove(op, &slice, options);

  return !llvm::any_of(currGroup, [&](Operation *groupedOp) {
    return slice.contains(groupedOp);
15 changes: 9 additions & 6 deletions compiler/src/iree/compiler/DispatchCreation/FusionUtils.h
@@ -29,6 +29,14 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *operand,
bool isHorizontalToGroup(Operation *op, ArrayRef<Operation *> currGroup,
                         const DominanceInfo &dominanceInfo, Operation *seedOp);

/// Wraps `filter` so that operations used in a region of an op also get
/// included in the backward slice. This breaks the topological ordering of
/// the original `getBackwardSlice`, but that ordering isn't necessary for the
/// uses here.
/// TODO: Upstream this as a part of `getBackwardSlice`.
void getBackwardSliceIncludingUsesFromAbove(
    Operation *op, SetVector<Operation *> *backwardSlice,
    const BackwardSliceOptions &options);

/// Moves the operands and transitive defs for each op in `operations` directly
/// after `insertionPoint`. Note: this does not check if it is legal to move the
/// operands.
@@ -44,16 +52,11 @@ moveOperandDefs(RewriterBase &rewriter, ArrayRef<T> operations,
    return !dominanceInfo.properlyDominates(op, insertionPoint) &&
           !ignoreOperationsSet.contains(op);
  };
  // Set inclusive to true because the slice is computed from the operand, and
  // we want to include the defining op (which is the point here).
  options.inclusive = true;

  llvm::SetVector<Operation *> slice;
  for (auto op : operations) {
    assert(insertionPoint->getBlock() == op->getBlock());
    for (auto operand : op->getOperands()) {
      getBackwardSlice(operand, &slice, options);
    }
    getBackwardSliceIncludingUsesFromAbove(op, &slice, options);
  }

  mlir::topologicalSort(slice);
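For reference only (not part of this commit): a minimal sketch of how a caller might drive the new `getBackwardSliceIncludingUsesFromAbove` helper, mirroring the dominance-based filter used in `isHorizontalToGroup` above. The include paths, the `DispatchCreation` namespace, and the `collectBoundedSlice` wrapper are assumptions for illustration, not taken from this diff.

// Illustrative sketch only; the include paths and namespace below are assumed,
// not verified against the IREE tree.
#include "iree/compiler/DispatchCreation/FusionUtils.h"

#include "llvm/ADT/SetVector.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"

// Collects the backward slice of `op`, with the walk stopped at ops that
// properly dominate `seedOp`, while also following values that ops in the
// slice capture from above inside their regions.
static llvm::SetVector<mlir::Operation *>
collectBoundedSlice(mlir::Operation *op, mlir::Operation *seedOp,
                    const mlir::DominanceInfo &dominanceInfo) {
  mlir::BackwardSliceOptions options;
  // Same shape of filter as in isHorizontalToGroup: keep the slice small by
  // not walking past the seed.
  options.filter = [&](mlir::Operation *candidate) {
    return !dominanceInfo.properlyDominates(candidate, seedOp);
  };
  llvm::SetVector<mlir::Operation *> slice;
  // Unlike a plain getBackwardSlice call, this also picks up producers that
  // are only reachable through region captures, e.g. a tensor extracted
  // inside a linalg.generic body.
  mlir::iree_compiler::DispatchCreation::getBackwardSliceIncludingUsesFromAbove(
      op, &slice, options);
  return slice;
}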
@@ -164,3 +164,58 @@ util.func public @fuse_by_moving_consumer(%arg0: tensor<5x5xf32>, %arg1: tensor<
// CHECK-LABEL: util.func public @fuse_by_moving_consumer
// CHECK: linalg.generic
// CHECK-NOT: linalg.generic

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
util.func public @dont_fuse_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<25xf32>) {
  %cst = arith.constant 1.000000e+00 : f32
  %cst_0 = arith.constant 2.000000e+00 : f32
  %cst_1 = arith.constant 3.000000e+00 : f32
  %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
  ^bb0(%in: f32, %out: f32):
    %2 = arith.addf %in, %cst : f32
    linalg.yield %2 : f32
  } -> tensor<5x5xf32>
  %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32>
  %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
  ^bb0(%in: f32, %out: f32):
    %c2 = arith.constant 2 : index
    %extracted = tensor.extract %collapsed[%c2] : tensor<25xf32>
    %2 = arith.addf %extracted, %extracted : f32
    linalg.yield %2 : f32
  } -> tensor<5x5xf32>
  util.return %1, %collapsed : tensor<5x5xf32>, tensor<25xf32>
}

// CHECK-LABEL: util.func public @dont_fuse_use_from_above
// CHECK: linalg.generic
// CHECK: linalg.generic


// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
util.func public @do_fuse_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<25xf32>) {
  %cst = arith.constant 1.000000e+00 : f32
  %cst_0 = arith.constant 2.000000e+00 : f32
  %cst_1 = arith.constant 3.000000e+00 : f32
  %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
  ^bb0(%in: f32, %out: f32):
    %2 = arith.addf %in, %cst : f32
    linalg.yield %2 : f32
  } -> tensor<5x5xf32>
  %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32>
  %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
  ^bb0(%in: f32, %out: f32):
    %c2 = arith.constant 2 : index
    %extracted = tensor.extract %arg0[%c2, %c2] : tensor<5x5xf32>
    %2 = arith.addf %extracted, %extracted : f32
    linalg.yield %2 : f32
  } -> tensor<5x5xf32>
  util.return %1, %collapsed : tensor<5x5xf32>, tensor<25xf32>
}

// CHECK-LABEL: util.func public @do_fuse_use_from_above
// CHECK: linalg.generic
// CHECK-NOT: linalg.generic
