diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel index 4b5957a50186..12f68a3ff8c0 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel @@ -61,6 +61,7 @@ iree_compiler_cc_library( "GPUDistributeSharedMemoryCopy.cpp", "GPUDistributionPatterns.cpp", "GPUGeneralizeNamedOps.cpp", + "GPUGreedilyDistributeToThreads.cpp", "GPUInferMemorySpace.cpp", "GPULowerToUKernels.cpp", "GPUMaterializeEncoding.cpp", diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt index b48869ffcabd..4de9d3f5fcf0 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt @@ -59,6 +59,7 @@ iree_cc_library( "GPUDistributeSharedMemoryCopy.cpp" "GPUDistributionPatterns.cpp" "GPUGeneralizeNamedOps.cpp" + "GPUGreedilyDistributeToThreads.cpp" "GPUInferMemorySpace.cpp" "GPULowerToUKernels.cpp" "GPUMaterializeEncoding.cpp" diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUGreedilyDistributeToThreads.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUGreedilyDistributeToThreads.cpp new file mode 100644 index 000000000000..f39f09165752 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUGreedilyDistributeToThreads.cpp @@ -0,0 +1,159 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/GPU/Passes.h" +#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" +#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h" +#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h" +#include "iree/compiler/Codegen/Utils/GPUUtils.h" +#include "llvm/ADT/STLForwardCompat.h" +#include "llvm/ADT/SmallVectorExtras.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Interfaces/TilingInterface.h" + +namespace mlir::iree_compiler { + +#define GEN_PASS_DEF_GPUGREEDILYDISTRIBUTETOTHREADSPASS +#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc" + +namespace { + +struct GPUGreedilyDistributeToThreadsPass final + : impl::GPUGreedilyDistributeToThreadsPassBase< + GPUGreedilyDistributeToThreadsPass> { + void runOnOperation() override; +}; +} // namespace + +/// Helper to tile and greedily fuse the given operation to threads. This uses +/// the iree_gpu.derived_thread_config logic internally to determine tile sizes +/// to use. This does not yield any fused operation and only replaces the tiling +/// root. +/// +/// If tiling fails this returns silently (tiling is best effort). Later +/// verification steps will throw an error if distribution does not occur. +static void tileToThreads(RewriterBase &rewriter, + TilingInterface tilingInterfaceOp) { + rewriter.setInsertionPoint(tilingInterfaceOp); + auto configAttr = + IREE::GPU::DerivedThreadConfigAttr::get(rewriter.getContext()); + SmallVector tileSizes = configAttr.getTilingLevelSizes( + rewriter, llvm::to_underlying(IREE::GPU::TilingLevel::Thread), + tilingInterfaceOp); + + // Pad the tile sizes with zero. + auto zero = rewriter.getIndexAttr(0); + int64_t numLoops = tilingInterfaceOp.getLoopIteratorTypes().size(); + if (tileSizes.size() > numLoops) { + return; + } + while (tileSizes.size() < numLoops) { + tileSizes.push_back(zero); + } + + scf::SCFTilingOptions tilingOptions; + tilingOptions.setTileSizes(tileSizes); + tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); + + // Use a "descending" relative thread mapping (ordered from slowest to + // fastest varying ID). For example, [linear_dim_n, ..., linear_dim_0]. + SmallVector mapping; + int idx = 0; + for (auto size : tileSizes) { + if (!isConstantIntValue(size, 0)) { + unsigned mappingId = + static_cast(gpu::MappingId::LinearDim0) + idx++; + mapping.push_back(gpu::GPUThreadMappingAttr::get( + rewriter.getContext(), static_cast(mappingId))); + } + } + tilingOptions.setMapping(llvm::to_vector(llvm::reverse(mapping))); + + scf::SCFTileAndFuseOptions tileAndFuseOptions; + tileAndFuseOptions.setTilingOptions(tilingOptions); + + scf::SCFTileAndFuseOptions::ControlFnTy controlFn = + [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer, + bool isDestinationOperand) + -> std::optional { + // Always fuse but never yield a replacement. + return scf::SCFTileAndFuseOptions::ControlFnResult{ + /*yieldProducerReplacement=*/false}; + }; + tileAndFuseOptions.setFusionControlFn(controlFn); + + FailureOr tiledResults = + scf::tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp, + tileAndFuseOptions); + if (failed(tiledResults)) { + return; + } + + // Perform the replacement of the tiling root. + for (OpResult res : tilingInterfaceOp->getResults()) { + if (auto replacement = tiledResults->replacements.lookup(res)) { + rewriter.replaceAllUsesWith(res, replacement); + } + } + + if (tilingInterfaceOp->use_empty()) { + rewriter.eraseOp(tilingInterfaceOp); + } +} + +/// Recursively process the given region and tile all tilable operations not +/// present within an `scf.forall` with thread/warp/lane mapping. +static void processRegion(RewriterBase &rewriter, Region *region) { + // Process the region blocks in reverse. + for (Block &block : llvm::reverse(region->getBlocks())) { + // Save a reversed list of operations within the block. Ops will be + // greedily tiled + fused in reverse so that if a producer can be fused + // with a consumer we only distribute the producer once via fusion. + SmallVector targetOps = + llvm::map_to_vector(llvm::reverse(block.getOperations()), + [](Operation &op) { return &op; }); + // Skip all unused ops (possibly from tiling). + for (Operation *op : targetOps) { + if (op->use_empty()) { + continue; + } + // Skip all operations contained within an `scf.forall` mapped to threads + // warps or lanes. These are already distributed and fine to leave as is. + if (auto forall = dyn_cast(op)) { + if (forallOpHasMappingType(forall)) { + continue; + } + } + + // If an op implements the tiling interface, try to greedily tile + fuse. + if (auto tilableOp = dyn_cast(op)) { + tileToThreads(rewriter, tilableOp); + continue; + } + + // Else recursively process all nested operations. + for (auto ®ion : op->getRegions()) { + processRegion(rewriter, ®ion); + } + } + } +} + +void GPUGreedilyDistributeToThreadsPass::runOnOperation() { + auto funcOp = getOperation(); + + IRRewriter rewriter(funcOp->getContext()); + for (auto ®ion : funcOp->getRegions()) { + processRegion(rewriter, ®ion); + } +} + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td index 1f22bad8ea77..d339adf95a26 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td @@ -76,6 +76,17 @@ def GPUGeneralizeNamedOpsPass : let summary = "Convert named Linalg ops to linalg.generic ops"; } +def GPUGreedilyDistributeToThreadsPass : + InterfacePass<"iree-codegen-gpu-greedily-distribute-to-threads", "mlir::FunctionOpInterface"> { + let summary = "Greedily distributes all remaining tilable ops to threads"; + let dependentDialects = [ + "::mlir::affine::AffineDialect", + "::mlir::gpu::GPUDialect", + "::mlir::scf::SCFDialect", + "::mlir::iree_compiler::IREE::GPU::IREEGPUDialect", + ]; +} + def GPUInferMemorySpacePass : InterfacePass<"iree-codegen-gpu-infer-memory-space", "mlir::FunctionOpInterface"> { let summary = "Pass to infer and set the memory space for all alloc_tensor ops."; diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel index 595e0eb90a74..1c17589b3691 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel @@ -26,6 +26,7 @@ iree_lit_test_suite( "gpu_distribute_scf_for.mlir", "gpu_distribute_shared_memory.mlir", "gpu_generalize_named_ops.mlir", + "gpu_greedily_distribute_to_threads.mlir", "gpu_infer_memory_space.mlir", "gpu_lower_to_ukernels.mlir", "gpu_combine_value_barriers.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt index 7c214e310a87..ac241ba35b85 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt @@ -23,6 +23,7 @@ iree_lit_test_suite( "gpu_distribute_scf_for.mlir" "gpu_distribute_shared_memory.mlir" "gpu_generalize_named_ops.mlir" + "gpu_greedily_distribute_to_threads.mlir" "gpu_infer_memory_space.mlir" "gpu_lower_to_ukernels.mlir" "gpu_materialize_encoding.mlir" diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_greedily_distribute_to_threads.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_greedily_distribute_to_threads.mlir new file mode 100644 index 000000000000..355e1a77f0f2 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_greedily_distribute_to_threads.mlir @@ -0,0 +1,162 @@ +// RUN: iree-opt %s --split-input-file --mlir-print-local-scope \ +// RUN: --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-greedily-distribute-to-threads, canonicalize, cse))" | \ +// RUN: FileCheck %s + +#map = affine_map<(d0, d1) -> (d0, d1)> +func.func @simple_generic(%3: tensor<64x256xf32>, %4: tensor<64x256xf32>, %5: tensor<64x256xf32>) -> tensor<64x256xf32> + attributes { + translation_info = #iree_codegen.translation_info + } { + %6 = linalg.generic { + indexing_maps = [#map, #map, #map], + iterator_types = ["parallel", "parallel"] + } ins(%3, %4 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%5 : tensor<64x256xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %7 = arith.addf %in, %in_0 : f32 + linalg.yield %7 : f32 + } -> tensor<64x256xf32> + return %6 : tensor<64x256xf32> +} + +// CHECK-LABEL: func.func @simple_generic +// CHECK: scf.forall +// CHECK: linalg.generic {{.*}} outs({{.*}}: tensor<1x4xf32>) +// CHECK: scf.forall.in_parallel +// CHECK: mapping = [#gpu.thread, #gpu.thread] + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +func.func @fuse_destination(%3: tensor<64x64xf32>, %4: tensor<64x64xf32>) -> tensor<64x64xf32> + attributes { + translation_info = #iree_codegen.translation_info + } { + %empty = tensor.empty() : tensor<64x64xf32> + %cst = arith.constant 0.0 : f32 + %5 = linalg.fill ins(%cst : f32) outs(%empty : tensor<64x64xf32>) -> tensor<64x64xf32> + %7 = linalg.matmul ins(%3, %4 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%5 : tensor<64x64xf32>) -> tensor<64x64xf32> + return %7 : tensor<64x64xf32> +} + +// CHECK-LABEL: func.func @fuse_destination +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64x64xf32> +// CHECK: scf.forall {{.*}} shared_outs(%[[INIT:.+]] = %[[EMPTY]] +// CHECK: linalg.fill {{.*}} -> tensor<1x1xf32> + +// Additionally verify that reduction dimensions do not get tiled. +// CHECK: linalg.matmul ins({{.*}}: tensor<1x64xf32>, tensor<64x1xf32>) + +// ----- + +func.func @in_nested_region(%3: tensor<64x64xf32>, %4: tensor<64x64xf32>, %5: tensor<64x64xf32>) -> tensor<64x64xf32> + attributes { + translation_info = #iree_codegen.translation_info + } { + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %6 = scf.for %arg0 = %c0 to %c64 step %c8 iter_args(%arg1 = %5) -> (tensor<64x64xf32>) { + %extracted_slice = tensor.extract_slice %3[0, %arg0] [64, 8] [1, 1] : tensor<64x64xf32> to tensor<64x8xf32> + %extracted_slice_0 = tensor.extract_slice %4[0, %arg0] [64, 8] [1, 1] : tensor<64x64xf32> to tensor<64x8xf32> + %extracted_slice_1 = tensor.extract_slice %arg1[0, %arg0] [64, 8] [1, 1] : tensor<64x64xf32> to tensor<64x8xf32> + %7 = linalg.add + ins(%extracted_slice, %extracted_slice_0 : tensor<64x8xf32>, tensor<64x8xf32>) + outs(%extracted_slice_1 : tensor<64x8xf32>) -> tensor<64x8xf32> + %insert = tensor.insert_slice %7 into %arg1[0, %arg0] [64, 8] [1, 1] : tensor<64x8xf32> into tensor<64x64xf32> + scf.yield %insert : tensor<64x64xf32> + } + return %6 : tensor<64x64xf32> +} + +// CHECK-LABEL: func.func @in_nested_region +// CHECK: scf.for +// CHECK: scf.forall +// CHECK: linalg.add {{.*}} -> tensor<1x4xf32> + +// ----- + +func.func @do_not_redistribute_in_forall(%3: tensor<64x64xf32>, %4: tensor<64x64xf32>, %5: tensor<64x64xf32>) -> tensor<64x64xf32> + attributes { + translation_info = #iree_codegen.translation_info + } { + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %6 = scf.for %arg0 = %c0 to %c64 step %c8 iter_args(%arg1 = %5) -> (tensor<64x64xf32>) { + %extracted_slice = tensor.extract_slice %3[0, %arg0] [64, 8] [1, 1] : tensor<64x64xf32> to tensor<64x8xf32> + %extracted_slice_0 = tensor.extract_slice %4[0, %arg0] [64, 8] [1, 1] : tensor<64x64xf32> to tensor<64x8xf32> + %extracted_slice_1 = tensor.extract_slice %arg1[0, %arg0] [64, 8] [1, 1] : tensor<64x64xf32> to tensor<64x8xf32> + %1 = scf.forall (%arg5, %arg6) = (0, 0) to (64, 8) step (1, 4) shared_outs(%arg7 = %extracted_slice_1) -> (tensor<64x8xf32>) { + %extracted_slice_2 = tensor.extract_slice %extracted_slice[%arg5, %arg6] [1, 4] [1, 1] : tensor<64x8xf32> to tensor<1x4xf32> + %extracted_slice_3 = tensor.extract_slice %extracted_slice_0[%arg5, %arg6] [1, 4] [1, 1] : tensor<64x8xf32> to tensor<1x4xf32> + %extracted_slice_4 = tensor.extract_slice %arg7[%arg5, %arg6] [1, 4] [1, 1] : tensor<64x8xf32> to tensor<1x4xf32> + %2 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<1x4xf32>, tensor<1x4xf32>) outs(%extracted_slice_4 : tensor<1x4xf32>) -> tensor<1x4xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %2 into %arg7[%arg5, %arg6] [1, 4] [1, 1] : tensor<1x4xf32> into tensor<64x8xf32> + } + } {mapping = [#gpu.thread, #gpu.thread]} + %7 = linalg.add + ins(%1, %1 : tensor<64x8xf32>, tensor<64x8xf32>) + outs(%1 : tensor<64x8xf32>) -> tensor<64x8xf32> + %insert = tensor.insert_slice %7 into %arg1[0, %arg0] [64, 8] [1, 1] : tensor<64x8xf32> into tensor<64x64xf32> + scf.yield %insert : tensor<64x64xf32> + } + %8 = linalg.add + ins(%6, %6 : tensor<64x64xf32>, tensor<64x64xf32>) + outs(%6 : tensor<64x64xf32>) -> tensor<64x64xf32> + return %8 : tensor<64x64xf32> +} + +// CHECK-LABEL: func.func @do_not_redistribute_in_forall +// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<64x64xf32> +// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor<64x64xf32> +// CHECK-SAME: %[[ARG2:[A-Za-z0-9]+]]: tensor<64x64xf32> +// CHECK: scf.for {{.*}} iter_args(%[[FOR_ITER:.+]] = %[[ARG2]]) +// CHECK: %[[INIT:.+]] = tensor.extract_slice %[[FOR_ITER]] + +// Verify that the existing forall stays the same. +// CHECK: scf.forall {{.*}} shared_outs(%[[ITER:.+]] = %[[INIT]]) +// CHECK: %[[DEST:.+]] = tensor.extract_slice %[[ITER]] +// CHECK: linalg.add {{.*}} outs(%[[DEST]] : tensor<1x4xf32> +// CHECK: scf.forall.in_parallel + +// CHECK: %[[DIST_ADD:.+]] = scf.forall +// CHECK: linalg.add +// CHECK: scf.forall.in_parallel +// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[DIST_ADD]] +// CHECK: scf.yield %[[INSERT]] +// CHECK: %[[RES_ADD:.+]] = scf.forall +// CHECK: linalg.add +// CHECK: scf.forall.in_parallel +// CHECK: return %[[RES_ADD]] + +// ----- + +func.func @multiple_use_tilable_op(%3: tensor<64x256xf32>, %4: tensor<64x256xf32>) -> (tensor<64x256xf32>, tensor<256x64xf32>) + attributes { + translation_info = #iree_codegen.translation_info + } { + %add_empty = tensor.empty() : tensor<64x256xf32> + %6 = linalg.add + ins(%3, %4 : tensor<64x256xf32>, tensor<64x256xf32>) + outs(%add_empty : tensor<64x256xf32>) -> tensor<64x256xf32> + %transpose_empty = tensor.empty() : tensor<256x64xf32> + %7 = linalg.transpose + ins(%6 : tensor<64x256xf32>) + outs(%transpose_empty : tensor<256x64xf32>) permutation = [1, 0] + return %6, %7 : tensor<64x256xf32>, tensor<256x64xf32> +} + +// CHECK-LABEL: func.func @multiple_use_tilable_op +// CHECK: %[[ADD_DIST:.+]] = scf.forall +// CHECK: %[[ADD:.+]] = linalg.add {{.*}} -> tensor<1x4xf32> +// CHECK: scf.forall.in_parallel +// CHECK: tensor.parallel_insert_slice %[[ADD]] +// CHECK: mapping = [#gpu.thread, #gpu.thread] +// CHECK: %[[T_DIST:.+]] = scf.forall +// CHECK: %[[FUSED_ADD:.+]] = linalg.add {{.*}} -> tensor<4x1xf32> +// CHECK: %[[T:.+]] = linalg.transpose ins(%[[FUSED_ADD]] +// CHECK: scf.forall.in_parallel +// CHECK: tensor.parallel_insert_slice %[[T]] +// CHECK: mapping = [#gpu.thread, #gpu.thread] +// CHECK: return %[[ADD_DIST]], %[[T_DIST]] diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.cpp index cea77d54a809..4e534d47b27b 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.cpp @@ -11,6 +11,7 @@ #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "llvm/ADT/STLExtras.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/TypeUtilities.h" @@ -113,7 +114,15 @@ SmallVector deriveLinalgOpThreadTileSizes(linalg::LinalgOp linalgOp, int64_t vectorSize = kPreferredCopyNumBits / getElementTypeOrSelf(linalgOp->getResultTypes()[0]) .getIntOrFloatBitWidth(); - return getVectorTileSizesFromLoopRanges(loopRanges, numThreads, vectorSize); + SmallVector tileSizes = + getVectorTileSizesFromLoopRanges(loopRanges, numThreads, vectorSize); + for (auto [tileSize, iterType] : + llvm::zip(tileSizes, linalgOp.getIteratorTypesArray())) { + if (iterType == utils::IteratorType::reduction) { + tileSize = 0; + } + } + return tileSizes; } SmallVector diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index f998865ae4ec..7caed6d4afa5 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -405,6 +405,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager, // Step 5. Greedily fuse parallel loops and hoist from serial loops. funcPassManager.addPass(IREE::GPU::createFuseAndHoistParallelLoopsPass()); + funcPassManager.addPass(createGPUGreedilyDistributeToThreadsPass()); funcPassManager.addPass(createCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); funcPassManager.addPass(createLoopInvariantCodeMotionPass()); @@ -420,6 +421,9 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager, addGPUBufferizePasses(funcPassManager); // Step 8. Resolve remaining parallel loops. + funcPassManager.addPass(iree_compiler::createNormalizeLoopBoundsPass( + NormalizeLoopBoundsPassOptions{/*normalizeFor=*/false, + /*normalizeForall=*/true})); funcPassManager.addPass(createGPUVerifyDistributionPass()); funcPassManager.addPass(createGPUDistributeForallPass());