[Codegen][GPU] Add pass for fallback distribution (#18726)
If any operation fails to tile or fuse in earlier pipeline stages, we need a
fallback distribution step. This adds a pass for fallback distribution that
looks for all tilable ops not inside `scf.forall` regions mapped to threads
and greedily tiles + fuses them, starting from the base of the dispatch.
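As a rough orientation (not part of this commit), a minimal sketch of where such a fallback step could sit in a function pass pipeline, assuming the usual tablegen-generated `createGPUGreedilyDistributeToThreadsPass()` factory and standard upstream cleanup passes:

// Hypothetical pipeline hook -- a sketch only; the surrounding passes and the
// factory name are assumptions based on common MLIR/IREE conventions.
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

static void addFallbackThreadDistribution(mlir::OpPassManager &funcPassManager) {
  // Earlier tiling/fusion stages may have left some tilable ops untouched;
  // greedily distribute whatever remains to threads.
  funcPassManager.addPass(
      mlir::iree_compiler::createGPUGreedilyDistributeToThreadsPass());
  // Clean up the freshly created scf.forall loops, mirroring the
  // canonicalize + cse in the lit test RUN line below.
  funcPassManager.addPass(mlir::createCanonicalizerPass());
  funcPassManager.addPass(mlir::createCSEPass());
}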
qedawkins authored Oct 13, 2024
1 parent 1f3382d commit d7378bb
Showing 9 changed files with 350 additions and 1 deletion.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
@@ -61,6 +61,7 @@ iree_compiler_cc_library(
"GPUDistributeSharedMemoryCopy.cpp",
"GPUDistributionPatterns.cpp",
"GPUGeneralizeNamedOps.cpp",
"GPUGreedilyDistributeToThreads.cpp",
"GPUInferMemorySpace.cpp",
"GPULowerToUKernels.cpp",
"GPUMaterializeEncoding.cpp",
@@ -59,6 +59,7 @@ iree_cc_library(
"GPUDistributeSharedMemoryCopy.cpp"
"GPUDistributionPatterns.cpp"
"GPUGeneralizeNamedOps.cpp"
"GPUGreedilyDistributeToThreads.cpp"
"GPUInferMemorySpace.cpp"
"GPULowerToUKernels.cpp"
"GPUMaterializeEncoding.cpp"
159 changes: 159 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/GPUGreedilyDistributeToThreads.cpp
@@ -0,0 +1,159 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "llvm/ADT/STLForwardCompat.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Interfaces/TilingInterface.h"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_GPUGREEDILYDISTRIBUTETOTHREADSPASS
#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc"

namespace {

struct GPUGreedilyDistributeToThreadsPass final
    : impl::GPUGreedilyDistributeToThreadsPassBase<
          GPUGreedilyDistributeToThreadsPass> {
  void runOnOperation() override;
};
} // namespace

/// Helper to tile and greedily fuse the given operation to threads. This uses
/// the iree_gpu.derived_thread_config logic internally to determine tile sizes
/// to use. This does not yield any fused operation and only replaces the tiling
/// root.
///
/// If tiling fails this returns silently (tiling is best effort). Later
/// verification steps will throw an error if distribution does not occur.
static void tileToThreads(RewriterBase &rewriter,
                          TilingInterface tilingInterfaceOp) {
  rewriter.setInsertionPoint(tilingInterfaceOp);
  auto configAttr =
      IREE::GPU::DerivedThreadConfigAttr::get(rewriter.getContext());
  SmallVector<OpFoldResult> tileSizes = configAttr.getTilingLevelSizes(
      rewriter, llvm::to_underlying(IREE::GPU::TilingLevel::Thread),
      tilingInterfaceOp);

  // Pad the tile sizes with zero.
  auto zero = rewriter.getIndexAttr(0);
  int64_t numLoops = tilingInterfaceOp.getLoopIteratorTypes().size();
  if (tileSizes.size() > numLoops) {
    return;
  }
  while (tileSizes.size() < numLoops) {
    tileSizes.push_back(zero);
  }

  scf::SCFTilingOptions tilingOptions;
  tilingOptions.setTileSizes(tileSizes);
  tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);

  // Use a "descending" relative thread mapping (ordered from slowest to
  // fastest varying ID). For example, [linear_dim_n, ..., linear_dim_0].
  SmallVector<Attribute> mapping;
  int idx = 0;
  for (auto size : tileSizes) {
    if (!isConstantIntValue(size, 0)) {
      unsigned mappingId =
          static_cast<unsigned>(gpu::MappingId::LinearDim0) + idx++;
      mapping.push_back(gpu::GPUThreadMappingAttr::get(
          rewriter.getContext(), static_cast<gpu::MappingId>(mappingId)));
    }
  }
  tilingOptions.setMapping(llvm::to_vector(llvm::reverse(mapping)));

  scf::SCFTileAndFuseOptions tileAndFuseOptions;
  tileAndFuseOptions.setTilingOptions(tilingOptions);

  scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
      [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
          bool isDestinationOperand)
      -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
    // Always fuse but never yield a replacement.
    return scf::SCFTileAndFuseOptions::ControlFnResult{
        /*yieldProducerReplacement=*/false};
  };
  tileAndFuseOptions.setFusionControlFn(controlFn);

  FailureOr<scf::SCFTileAndFuseResult> tiledResults =
      scf::tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
                                                tileAndFuseOptions);
  if (failed(tiledResults)) {
    return;
  }

  // Perform the replacement of the tiling root.
  for (OpResult res : tilingInterfaceOp->getResults()) {
    if (auto replacement = tiledResults->replacements.lookup(res)) {
      rewriter.replaceAllUsesWith(res, replacement);
    }
  }

  if (tilingInterfaceOp->use_empty()) {
    rewriter.eraseOp(tilingInterfaceOp);
  }
}

/// Recursively process the given region and tile all tilable operations not
/// present within an `scf.forall` with thread/warp/lane mapping.
static void processRegion(RewriterBase &rewriter, Region *region) {
  // Process the region blocks in reverse.
  for (Block &block : llvm::reverse(region->getBlocks())) {
    // Save a reversed list of operations within the block. Ops will be
    // greedily tiled + fused in reverse so that if a producer can be fused
    // with a consumer we only distribute the producer once via fusion.
    SmallVector<Operation *> targetOps =
        llvm::map_to_vector(llvm::reverse(block.getOperations()),
                            [](Operation &op) { return &op; });
    // Skip all unused ops (possibly from tiling).
    for (Operation *op : targetOps) {
      if (op->use_empty()) {
        continue;
      }
      // Skip all operations contained within an `scf.forall` mapped to
      // threads, warps, or lanes. These are already distributed and fine to
      // leave as is.
      if (auto forall = dyn_cast<scf::ForallOp>(op)) {
        if (forallOpHasMappingType<gpu::GPUThreadMappingAttr,
                                   gpu::GPUWarpMappingAttr,
                                   IREE::GPU::LaneIdAttr>(forall)) {
          continue;
        }
      }

      // If an op implements the tiling interface, try to greedily tile + fuse.
      if (auto tilableOp = dyn_cast<TilingInterface>(op)) {
        tileToThreads(rewriter, tilableOp);
        continue;
      }

      // Else recursively process all nested operations.
      for (auto &region : op->getRegions()) {
        processRegion(rewriter, &region);
      }
    }
  }
}

void GPUGreedilyDistributeToThreadsPass::runOnOperation() {
  auto funcOp = getOperation();

  IRRewriter rewriter(funcOp->getContext());
  for (auto &region : funcOp->getRegions()) {
    processRegion(rewriter, &region);
  }
}

} // namespace mlir::iree_compiler
11 changes: 11 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -76,6 +76,17 @@ def GPUGeneralizeNamedOpsPass :
  let summary = "Convert named Linalg ops to linalg.generic ops";
}

def GPUGreedilyDistributeToThreadsPass :
    InterfacePass<"iree-codegen-gpu-greedily-distribute-to-threads", "mlir::FunctionOpInterface"> {
  let summary = "Greedily distributes all remaining tilable ops to threads";
  let dependentDialects = [
    "::mlir::affine::AffineDialect",
    "::mlir::gpu::GPUDialect",
    "::mlir::scf::SCFDialect",
    "::mlir::iree_compiler::IREE::GPU::IREEGPUDialect",
  ];
}

def GPUInferMemorySpacePass :
    InterfacePass<"iree-codegen-gpu-infer-memory-space", "mlir::FunctionOpInterface"> {
  let summary = "Pass to infer and set the memory space for all alloc_tensor ops.";
@@ -26,6 +26,7 @@ iree_lit_test_suite(
"gpu_distribute_scf_for.mlir",
"gpu_distribute_shared_memory.mlir",
"gpu_generalize_named_ops.mlir",
"gpu_greedily_distribute_to_threads.mlir",
"gpu_infer_memory_space.mlir",
"gpu_lower_to_ukernels.mlir",
"gpu_combine_value_barriers.mlir",
@@ -23,6 +23,7 @@ iree_lit_test_suite(
"gpu_distribute_scf_for.mlir"
"gpu_distribute_shared_memory.mlir"
"gpu_generalize_named_ops.mlir"
"gpu_greedily_distribute_to_threads.mlir"
"gpu_infer_memory_space.mlir"
"gpu_lower_to_ukernels.mlir"
"gpu_materialize_encoding.mlir"
@@ -0,0 +1,162 @@
// RUN: iree-opt %s --split-input-file --mlir-print-local-scope \
// RUN: --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-greedily-distribute-to-threads, canonicalize, cse))" | \
// RUN: FileCheck %s

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @simple_generic(%3: tensor<64x256xf32>, %4: tensor<64x256xf32>, %5: tensor<64x256xf32>) -> tensor<64x256xf32>
    attributes {
      translation_info = #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64, {}>
    } {
  %6 = linalg.generic {
    indexing_maps = [#map, #map, #map],
    iterator_types = ["parallel", "parallel"]
  } ins(%3, %4 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%5 : tensor<64x256xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %7 = arith.addf %in, %in_0 : f32
    linalg.yield %7 : f32
  } -> tensor<64x256xf32>
  return %6 : tensor<64x256xf32>
}

// CHECK-LABEL: func.func @simple_generic
// CHECK: scf.forall
// CHECK: linalg.generic {{.*}} outs({{.*}}: tensor<1x4xf32>)
// CHECK: scf.forall.in_parallel
// CHECK: mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @fuse_destination(%3: tensor<64x64xf32>, %4: tensor<64x64xf32>) -> tensor<64x64xf32>
    attributes {
      translation_info = #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64, {}>
    } {
  %empty = tensor.empty() : tensor<64x64xf32>
  %cst = arith.constant 0.0 : f32
  %5 = linalg.fill ins(%cst : f32) outs(%empty : tensor<64x64xf32>) -> tensor<64x64xf32>
  %7 = linalg.matmul ins(%3, %4 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%5 : tensor<64x64xf32>) -> tensor<64x64xf32>
  return %7 : tensor<64x64xf32>
}

// CHECK-LABEL: func.func @fuse_destination
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64x64xf32>
// CHECK: scf.forall {{.*}} shared_outs(%[[INIT:.+]] = %[[EMPTY]]
// CHECK: linalg.fill {{.*}} -> tensor<1x1xf32>

// Additionally verify that reduction dimensions do not get tiled.
// CHECK: linalg.matmul ins({{.*}}: tensor<1x64xf32>, tensor<64x1xf32>)

// -----

func.func @in_nested_region(%3: tensor<64x64xf32>, %4: tensor<64x64xf32>, %5: tensor<64x64xf32>) -> tensor<64x64xf32>
    attributes {
      translation_info = #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64, {}>
    } {
  %c8 = arith.constant 8 : index
  %c64 = arith.constant 64 : index
  %c0 = arith.constant 0 : index
  %6 = scf.for %arg0 = %c0 to %c64 step %c8 iter_args(%arg1 = %5) -> (tensor<64x64xf32>) {
    %extracted_slice = tensor.extract_slice %3[0, %arg0] [64, 8] [1, 1] : tensor<64x64xf32> to tensor<64x8xf32>
    %extracted_slice_0 = tensor.extract_slice %4[0, %arg0] [64, 8] [1, 1] : tensor<64x64xf32> to tensor<64x8xf32>
    %extracted_slice_1 = tensor.extract_slice %arg1[0, %arg0] [64, 8] [1, 1] : tensor<64x64xf32> to tensor<64x8xf32>
    %7 = linalg.add
      ins(%extracted_slice, %extracted_slice_0 : tensor<64x8xf32>, tensor<64x8xf32>)
      outs(%extracted_slice_1 : tensor<64x8xf32>) -> tensor<64x8xf32>
    %insert = tensor.insert_slice %7 into %arg1[0, %arg0] [64, 8] [1, 1] : tensor<64x8xf32> into tensor<64x64xf32>
    scf.yield %insert : tensor<64x64xf32>
  }
  return %6 : tensor<64x64xf32>
}

// CHECK-LABEL: func.func @in_nested_region
// CHECK: scf.for
// CHECK: scf.forall
// CHECK: linalg.add {{.*}} -> tensor<1x4xf32>

// -----

func.func @do_not_redistribute_in_forall(%3: tensor<64x64xf32>, %4: tensor<64x64xf32>, %5: tensor<64x64xf32>) -> tensor<64x64xf32>
    attributes {
      translation_info = #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64, {}>
    } {
  %c8 = arith.constant 8 : index
  %c64 = arith.constant 64 : index
  %c0 = arith.constant 0 : index
  %6 = scf.for %arg0 = %c0 to %c64 step %c8 iter_args(%arg1 = %5) -> (tensor<64x64xf32>) {
    %extracted_slice = tensor.extract_slice %3[0, %arg0] [64, 8] [1, 1] : tensor<64x64xf32> to tensor<64x8xf32>
    %extracted_slice_0 = tensor.extract_slice %4[0, %arg0] [64, 8] [1, 1] : tensor<64x64xf32> to tensor<64x8xf32>
    %extracted_slice_1 = tensor.extract_slice %arg1[0, %arg0] [64, 8] [1, 1] : tensor<64x64xf32> to tensor<64x8xf32>
    %1 = scf.forall (%arg5, %arg6) = (0, 0) to (64, 8) step (1, 4) shared_outs(%arg7 = %extracted_slice_1) -> (tensor<64x8xf32>) {
      %extracted_slice_2 = tensor.extract_slice %extracted_slice[%arg5, %arg6] [1, 4] [1, 1] : tensor<64x8xf32> to tensor<1x4xf32>
      %extracted_slice_3 = tensor.extract_slice %extracted_slice_0[%arg5, %arg6] [1, 4] [1, 1] : tensor<64x8xf32> to tensor<1x4xf32>
      %extracted_slice_4 = tensor.extract_slice %arg7[%arg5, %arg6] [1, 4] [1, 1] : tensor<64x8xf32> to tensor<1x4xf32>
      %2 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<1x4xf32>, tensor<1x4xf32>) outs(%extracted_slice_4 : tensor<1x4xf32>) -> tensor<1x4xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %2 into %arg7[%arg5, %arg6] [1, 4] [1, 1] : tensor<1x4xf32> into tensor<64x8xf32>
      }
    } {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
    %7 = linalg.add
      ins(%1, %1 : tensor<64x8xf32>, tensor<64x8xf32>)
      outs(%1 : tensor<64x8xf32>) -> tensor<64x8xf32>
    %insert = tensor.insert_slice %7 into %arg1[0, %arg0] [64, 8] [1, 1] : tensor<64x8xf32> into tensor<64x64xf32>
    scf.yield %insert : tensor<64x64xf32>
  }
  %8 = linalg.add
    ins(%6, %6 : tensor<64x64xf32>, tensor<64x64xf32>)
    outs(%6 : tensor<64x64xf32>) -> tensor<64x64xf32>
  return %8 : tensor<64x64xf32>
}

// CHECK-LABEL: func.func @do_not_redistribute_in_forall
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<64x64xf32>
// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor<64x64xf32>
// CHECK-SAME: %[[ARG2:[A-Za-z0-9]+]]: tensor<64x64xf32>
// CHECK: scf.for {{.*}} iter_args(%[[FOR_ITER:.+]] = %[[ARG2]])
// CHECK: %[[INIT:.+]] = tensor.extract_slice %[[FOR_ITER]]

// Verify that the existing forall stays the same.
// CHECK: scf.forall {{.*}} shared_outs(%[[ITER:.+]] = %[[INIT]])
// CHECK: %[[DEST:.+]] = tensor.extract_slice %[[ITER]]
// CHECK: linalg.add {{.*}} outs(%[[DEST]] : tensor<1x4xf32>
// CHECK: scf.forall.in_parallel

// CHECK: %[[DIST_ADD:.+]] = scf.forall
// CHECK: linalg.add
// CHECK: scf.forall.in_parallel
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[DIST_ADD]]
// CHECK: scf.yield %[[INSERT]]
// CHECK: %[[RES_ADD:.+]] = scf.forall
// CHECK: linalg.add
// CHECK: scf.forall.in_parallel
// CHECK: return %[[RES_ADD]]

// -----

func.func @multiple_use_tilable_op(%3: tensor<64x256xf32>, %4: tensor<64x256xf32>) -> (tensor<64x256xf32>, tensor<256x64xf32>)
    attributes {
      translation_info = #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64, {}>
    } {
  %add_empty = tensor.empty() : tensor<64x256xf32>
  %6 = linalg.add
    ins(%3, %4 : tensor<64x256xf32>, tensor<64x256xf32>)
    outs(%add_empty : tensor<64x256xf32>) -> tensor<64x256xf32>
  %transpose_empty = tensor.empty() : tensor<256x64xf32>
  %7 = linalg.transpose
    ins(%6 : tensor<64x256xf32>)
    outs(%transpose_empty : tensor<256x64xf32>) permutation = [1, 0]
  return %6, %7 : tensor<64x256xf32>, tensor<256x64xf32>
}

// CHECK-LABEL: func.func @multiple_use_tilable_op
// CHECK: %[[ADD_DIST:.+]] = scf.forall
// CHECK: %[[ADD:.+]] = linalg.add {{.*}} -> tensor<1x4xf32>
// CHECK: scf.forall.in_parallel
// CHECK: tensor.parallel_insert_slice %[[ADD]]
// CHECK: mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]
// CHECK: %[[T_DIST:.+]] = scf.forall
// CHECK: %[[FUSED_ADD:.+]] = linalg.add {{.*}} -> tensor<4x1xf32>
// CHECK: %[[T:.+]] = linalg.transpose ins(%[[FUSED_ADD]]
// CHECK: scf.forall.in_parallel
// CHECK: tensor.parallel_insert_slice %[[T]]
// CHECK: mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]
// CHECK: return %[[ADD_DIST]], %[[T_DIST]]
@@ -11,6 +11,7 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/TypeUtilities.h"
@@ -113,7 +114,15 @@ SmallVector<int64_t> deriveLinalgOpThreadTileSizes(linalg::LinalgOp linalgOp,
  int64_t vectorSize = kPreferredCopyNumBits /
                       getElementTypeOrSelf(linalgOp->getResultTypes()[0])
                           .getIntOrFloatBitWidth();
-  return getVectorTileSizesFromLoopRanges(loopRanges, numThreads, vectorSize);
+  SmallVector<int64_t> tileSizes =
+      getVectorTileSizesFromLoopRanges(loopRanges, numThreads, vectorSize);
+  for (auto [tileSize, iterType] :
+       llvm::zip(tileSizes, linalgOp.getIteratorTypesArray())) {
+    if (iterType == utils::IteratorType::reduction) {
+      tileSize = 0;
+    }
+  }
+  return tileSizes;
}

SmallVector<int64_t>
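For reference, a rough standalone sketch (not part of the diff) of the tile-size arithmetic the lit tests above rely on. The 128-bit preferred copy width and the simplified thread handling are assumptions; the real logic lives in getVectorTileSizesFromLoopRanges / deriveLinalgOpThreadTileSizes.

#include <cstdint>
#include <vector>

// Hypothetical approximation: the innermost dimension gets a vector-sized
// tile, other dimensions get 1, and reduction dimensions get 0 (left
// untiled), matching the deriveLinalgOpThreadTileSizes change above.
static std::vector<int64_t>
sketchThreadTileSizes(const std::vector<int64_t> &loopRanges,
                      const std::vector<bool> &isReduction,
                      int64_t elementBits) {
  constexpr int64_t kAssumedPreferredCopyNumBits = 128; // assumption
  int64_t vectorSize = kAssumedPreferredCopyNumBits / elementBits;
  std::vector<int64_t> tileSizes(loopRanges.size(), 1);
  if (!tileSizes.empty())
    tileSizes.back() = vectorSize;
  for (size_t i = 0; i < tileSizes.size(); ++i)
    if (isReduction[i])
      tileSizes[i] = 0;
  return tileSizes;
}

// e.g. the 64x256 f32 elementwise test yields {1, 4} (hence tensor<1x4xf32> in
// the CHECK lines), while the matmul's reduction dimension stays untiled so
// its operands keep their full size-64 K extent.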