Skip to content

Commit 47432c6

Browse files
authored
[Codegen] Add pass to verify workgroup distribution (#19186)
While general verification is not possible, when using `scf.forall` for workgroup distribution we have the opportunity for basic verification that all writes are located within the distributed loop. In particular, if we have any workgroup level loops, any write to global memory outside is assumed to be illegal. This happens after bufferization because it is impossible to do this verification before determining the memory space of every tensor in the dispatch. This pass is relatively lightweight (two walks, both of which should be short) and so is on by default for every CPU and GPU pipeline.
1 parent 35b495b commit 47432c6

File tree

10 files changed

+152
-4
lines changed

10 files changed

+152
-4
lines changed

compiler/src/iree/compiler/Codegen/Common/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ iree_compiler_cc_library(
154154
"UnrollAnnotatedLoops.cpp",
155155
"UserConfig.cpp",
156156
"VectorizeMemrefCopy.cpp",
157+
"VerifyWorkgroupDistribution.cpp",
157158
],
158159
hdrs = [
159160
"BufferizationAnalysis.h",

compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ iree_cc_library(
146146
"UnrollAnnotatedLoops.cpp"
147147
"UserConfig.cpp"
148148
"VectorizeMemrefCopy.cpp"
149+
"VerifyWorkgroupDistribution.cpp"
149150
DEPS
150151
::PassHeaders
151152
::PassesIncGen

compiler/src/iree/compiler/Codegen/Common/Passes.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,4 +633,15 @@ def VectorizeMemrefCopyPass :
633633
let summary = "Vectorizes memref copy operations.";
634634
}
635635

636+
def VerifyWorkgroupDistributionPass :
637+
InterfacePass<"iree-codegen-verify-workgroup-distribution", "mlir::FunctionOpInterface"> {
638+
let summary = "Pass to verify proper distribution to workgroups.";
639+
let description = [{
640+
Pass to verify that all writes to global memory are explicitly mapped to
641+
workgroups. This means that in cases where we use loops (scf.forall) to
642+
manage distribution to workgroups, we require that all ops with write
643+
side effects are contained within a workgroup distributed loop.
644+
}];
645+
}
646+
636647
#endif // IREE_CODEGEN_COMMON_PASSES
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
// Copyright 2024 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include "iree/compiler/Codegen/Common/Passes.h"
8+
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
9+
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
10+
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
11+
#include "mlir/IR/Visitors.h"
12+
#include "mlir/Interfaces/FunctionInterfaces.h"
13+
#include "mlir/Interfaces/SideEffectInterfaces.h"
14+
15+
namespace mlir::iree_compiler {
16+
17+
#define GEN_PASS_DEF_VERIFYWORKGROUPDISTRIBUTIONPASS
18+
#include "iree/compiler/Codegen/Common/Passes.h.inc"
19+
20+
namespace {
21+
22+
struct VerifyWorkgroupDistributionPass final
23+
: impl::VerifyWorkgroupDistributionPassBase<
24+
VerifyWorkgroupDistributionPass> {
25+
26+
void runOnOperation() override {
27+
FunctionOpInterface funcOp = getOperation();
28+
29+
WalkResult hasForall = funcOp.walk([&](scf::ForallOp forallOp) {
30+
if (forallOpHasMappingType<IREE::Codegen::WorkgroupMappingAttr>(
31+
forallOp)) {
32+
return WalkResult::interrupt();
33+
}
34+
return WalkResult::advance();
35+
});
36+
37+
// Without a workgroup level forall, either this is a single workgroup
38+
// dispatch, in which case no verification is needed, or this is already
39+
// distributed in which case verification is no longer possible.
40+
if (!hasForall.wasInterrupted()) {
41+
return;
42+
}
43+
44+
auto globalAddressSpace = IREE::HAL::DescriptorTypeAttr::get(
45+
&getContext(), IREE::HAL::DescriptorType::StorageBuffer);
46+
47+
// Walk in PreOrder so that parent operations are visited before children,
48+
// thus allowing all operations contained within workgroup foralls to be
49+
// skipped.
50+
WalkResult res = funcOp.walk<WalkOrder::PreOrder>([&](Operation *op) {
51+
if (auto forallOp = dyn_cast<scf::ForallOp>(op)) {
52+
// Skip ops contained within forall ops with workgroup mappings.
53+
if (forallOpHasMappingType<IREE::Codegen::WorkgroupMappingAttr>(
54+
forallOp)) {
55+
return WalkResult::skip();
56+
}
57+
}
58+
if (auto memoryEffectOp = dyn_cast<MemoryEffectOpInterface>(op)) {
59+
for (Value operand : memoryEffectOp->getOperands()) {
60+
auto type = dyn_cast<MemRefType>(operand.getType());
61+
if (!type ||
62+
!memoryEffectOp.getEffectOnValue<MemoryEffects::Write>(operand)) {
63+
continue;
64+
}
65+
66+
// Writes to non-global memory are fine.
67+
if (type.getMemorySpace() != globalAddressSpace) {
68+
continue;
69+
}
70+
71+
op->emitOpError(
72+
"write affecting operations on global resources are restricted "
73+
"to workgroup distributed contexts.");
74+
return WalkResult::interrupt();
75+
}
76+
}
77+
return WalkResult::advance();
78+
});
79+
80+
if (res.wasInterrupted()) {
81+
funcOp.emitOpError("failed on workgroup distribution verification");
82+
return signalPassFailure();
83+
}
84+
}
85+
};
86+
87+
} // namespace
88+
89+
} // namespace mlir::iree_compiler

compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ iree_lit_test_suite(
8686
"vectorize_memref_copy.mlir",
8787
"vectorize_tensor_pad.mlir",
8888
"vector_layout_analysis.mlir",
89+
"verify_workgroup_distribution.mlir",
8990
],
9091
include = ["*.mlir"],
9192
exclude = [

compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ iree_lit_test_suite(
8282
"vector_layout_analysis.mlir"
8383
"vectorize_memref_copy.mlir"
8484
"vectorize_tensor_pad.mlir"
85+
"verify_workgroup_distribution.mlir"
8586
TOOLS
8687
FileCheck
8788
iree-opt
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// RUN: iree-opt %s --split-input-file --verify-diagnostics \
2+
// RUN: --pass-pipeline="builtin.module(func.func(iree-codegen-verify-workgroup-distribution))" \
3+
// RUN: | FileCheck %s
4+
5+
// expected-error@+1 {{op failed on workgroup distribution verification}}
6+
func.func @write_outside_workgroup_forall(%i: i32, %out: memref<32xi32, #hal.descriptor_type<storage_buffer>>) {
7+
scf.forall (%arg0) in (32) {
8+
} {mapping = [#iree_codegen.workgroup_mapping<x>]}
9+
%c0 = arith.constant 0 : index
10+
// expected-error@+1 {{write affecting operations on global resources are restricted to workgroup distributed contexts.}}
11+
memref.store %i, %out[%c0] : memref<32xi32, #hal.descriptor_type<storage_buffer>>
12+
return
13+
}
14+
15+
// -----
16+
17+
// CHECK: func @non_workgroup_write_outside_workgroup_forall
18+
func.func @non_workgroup_write_outside_workgroup_forall(
19+
%i: i32, %out: memref<32xi32, #hal.descriptor_type<storage_buffer>>, %out2: memref<32xi32>) {
20+
scf.forall (%arg0) in (32) {
21+
memref.store %i, %out[%arg0] : memref<32xi32, #hal.descriptor_type<storage_buffer>>
22+
} {mapping = [#iree_codegen.workgroup_mapping<x>]}
23+
%c0 = arith.constant 0 : index
24+
memref.store %i, %out2[%c0] : memref<32xi32>
25+
return
26+
}
27+
28+
// -----
29+
30+
// expected-error@+1 {{op failed on workgroup distribution verification}}
31+
func.func @write_nested_in_other_forall(%i: i32, %out: memref<32xi32, #hal.descriptor_type<storage_buffer>>) {
32+
scf.forall (%arg0) in (32) {
33+
} {mapping = [#iree_codegen.workgroup_mapping<x>]}
34+
%c0 = arith.constant 0 : index
35+
scf.forall (%arg1) in (32) {
36+
// expected-error@+1 {{write affecting operations on global resources are restricted to workgroup distributed contexts.}}
37+
memref.store %i, %out[%arg1] : memref<32xi32, #hal.descriptor_type<storage_buffer>>
38+
}
39+
return
40+
}

compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -813,7 +813,8 @@ void buildLLVMCPUCodegenPassPipeline(OpPassManager &variantPassManager,
813813
OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
814814
modulePassManager.addPass(createLowerExecutableUsingTransformDialectPass());
815815
FunctionLikeNest(modulePassManager)
816-
.addPass(createLLVMCPULowerExecutableTargetPass);
816+
.addPass(createLLVMCPULowerExecutableTargetPass)
817+
.addPass(createVerifyWorkgroupDistributionPass);
817818
}
818819

819820
variantPassManager.addPass(createReconcileTranslationInfoPass());

compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,7 +1184,8 @@ void buildLLVMGPUCodegenPassPipeline(OpPassManager &variantPassManager,
11841184
OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
11851185
modulePassManager.addPass(createLowerExecutableUsingTransformDialectPass());
11861186
FunctionLikeNest(modulePassManager)
1187-
.addPass(createLLVMGPULowerExecutableTargetPass);
1187+
.addPass(createLLVMGPULowerExecutableTargetPass)
1188+
.addPass(createVerifyWorkgroupDistributionPass);
11881189
}
11891190
variantPassManager.addPass(createReconcileTranslationInfoPass());
11901191

@@ -1250,7 +1251,8 @@ void buildROCDLCodegenPassPipeline(OpPassManager &variantPassManager) {
12501251
OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
12511252
modulePassManager.addPass(createLowerExecutableUsingTransformDialectPass());
12521253
FunctionLikeNest(modulePassManager)
1253-
.addPass(createROCDLLowerExecutableTargetPass);
1254+
.addPass(createROCDLLowerExecutableTargetPass)
1255+
.addPass(createVerifyWorkgroupDistributionPass);
12541256
}
12551257
variantPassManager.addPass(createReconcileTranslationInfoPass());
12561258
variantPassManager.addPass(IREE::Util::createDropCompilerHintsPass());

compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,8 @@ void buildSPIRVCodegenPassPipeline(OpPassManager &variantPassManager) {
631631
modulePassManager.addPass(
632632
createSPIRVLowerExecutableUsingTransformDialectPass());
633633
FunctionLikeNest(modulePassManager)
634-
.addPass(createSPIRVLowerExecutableTargetPass);
634+
.addPass(createSPIRVLowerExecutableTargetPass)
635+
.addPass(createVerifyWorkgroupDistributionPass);
635636
addMemRefLoweringPasses(modulePassManager);
636637
}
637638
variantPassManager.addPass(createReconcileTranslationInfoPass());

0 commit comments

Comments
 (0)