Skip to content

Commit

Permalink
[Codegen][GPU] Add kernel config for LLVMGPUTileAndFuse (#17791)
Browse files Browse the repository at this point in the history
This adds kernel configuration logic for targeting simple thread
distribution of linalg-based dispatches on LLVMGPU. The configuration
logic is primarily copied from the same logic on the SPIR-V side due to
the already well tested heuristics there for the kinds of varied target
descriptions that are present on the SPIR-V side.

Currently this is locked behind a flag `iree-codegen-llvmgpu-test-tile-and-fuse-vectorize`.
Future patches will add specialized logic for matmul.
  • Loading branch information
qedawkins committed Aug 17, 2024
1 parent 7cf3fc6 commit 10ba28d
Show file tree
Hide file tree
Showing 10 changed files with 435 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Codegen/Common/GPU:GPUHeuristics",
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
"//compiler/src/iree/compiler/Codegen/Utils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:LinalgUtils",
"@llvm-project//mlir:Support",
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@ iree_cc_library(
MLIRFunctionInterfaces
MLIRIR
MLIRLinalgDialect
MLIRLinalgUtils
MLIRSupport
iree::compiler::Codegen::Common::GPU::GPUHeuristics
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
iree::compiler::Codegen::Utils
PUBLIC
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
Expand Down Expand Up @@ -201,4 +203,269 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
workgroupSize, targetSubgroupSize);
}

LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
mlir::FunctionOpInterface entryPoint,
Operation *op) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
// Bail out on multi result cases as consumer fusion currently does not
// support multi result ops.
if (!linalgOp || linalgOp.getNumDpsInits() != 1) {
return failure();
}

// This pipeline requires tensor semantics. Also fail for gather semantics
// for now to simplify tile + fuse.
if (!linalgOp.hasPureTensorSemantics() || linalgOp.hasIndexSemantics()) {
return failure();
}

SmallVector<unsigned int> partitionableLoops;
linalgOp.getParallelDims(partitionableLoops);

// Bail out if op is not tilable.
if (partitionableLoops.empty()) {
return failure();
}

const int subgroupSize = target.getPreferredSubgroupSize();
const unsigned loopDepth = linalgOp.getNumLoops();

// Configurations we need to decide.
std::array<int64_t, 3> workgroupSize;
SmallVector<int64_t> workgroupTileSizes;
SmallVector<int64_t> threadTileSizes;

// Initialize the configuration.
auto initConfiguration = [&]() {
workgroupSize = {subgroupSize, 1, 1};
workgroupTileSizes.resize(loopDepth, 0);
threadTileSizes.resize(loopDepth, 0);

// Initialize tiling along all partitioned loops with size 1.
for (int64_t loopIndex : partitionableLoops) {
workgroupTileSizes[loopIndex] = threadTileSizes[loopIndex] = 1;
}
// Override the innermost dimension to distribute to threads in a subgroup.
workgroupTileSizes[partitionableLoops.back()] = subgroupSize;
};

// Common case for all linalg ops.

// The core idea is to distribute the partitioned loops to the workgroup
// dimensions. The goal is to fill up the GPU as much as possible, which means
// 1) distributing to as many threads as possible, and 2) avoid assigning too
// many threads to handle out-of-bound elements (thus idle).

auto elementHasPowerOfTwoBitwidth = [](Value operand) {
Type elementType = getElementTypeOrSelf(operand.getType());
return isa<IntegerType, FloatType>(elementType) &&
llvm::isPowerOf2_64(IREE::Util::getTypeBitWidth(elementType));
};

// Whether we can try to use the vectorization pipeline.
SmallVector<int64_t> loopBounds = linalgOp.getStaticLoopRanges();
bool projPerm =
llvm::all_of(linalgOp.getIndexingMapsArray(),
[](AffineMap map) { return map.isProjectedPermutation(); });
bool powTwo =
llvm::all_of(linalgOp->getOperands(), elementHasPowerOfTwoBitwidth);
bool staticShape = llvm::none_of(loopBounds, ShapedType::isDynamic);

// Require all affine maps to be projected permutation so that we can
// generate vector transfer ops.
bool vectorizable = projPerm && powTwo && staticShape;

const unsigned minBitwidth = getMinElementBitwidth(linalgOp);
// Make sure we use a tile size that results in some integral number of bytes.
const unsigned scaleToByte =
std::max(8 / minBitwidth, static_cast<unsigned>(1));

// Distribute workload to the given `numThreads` by allowing a potental loss.
auto distributeToThreads = [&](int64_t numThreads,
std::optional<int64_t> lossFactor =
std::nullopt) {
LDBG("Loss factor: " << lossFactor << "\n");
initConfiguration();
// If there are more than 3 parallel dim try to tile the extra higher level
// dimensions to 1 for extra dimensions.
if (isa<linalg::GenericOp>(linalgOp.getOperation())) {
for (auto [i, tileSize] : llvm::enumerate(workgroupTileSizes)) {
if (tileSize != 0)
break;
if (loopBounds[i] != 1)
tileSize = 1;
}
}
// Scan from the innermost shape dimension and try to deduce the
// configuration for the corresponding GPU workgroup dimension.
int64_t wgDim = 0;
for (auto shapeDim : llvm::reverse(partitionableLoops)) {
int64_t loopBound = loopBounds[shapeDim];
// Skip dynamic dimensions.
if (ShapedType::isDynamic(loopBound))
continue;

// Try to find some power of two that can devide the current shape dim
// size. This vector keeps the candidate tile sizes.
SmallVector<int64_t, 8> candidates;

// For the inner most workgroup dim, try to see if we can have 4
// elements per thread. This enables vectorization.
if (vectorizable && wgDim == 0 && !lossFactor) {
candidates.push_back(4 * numThreads);
}
// Try all power of two numbers up to the subgroup size.
for (unsigned i = numThreads; i >= 1; i >>= 1) {
candidates.push_back(i);
}
LLVM_DEBUG({
llvm::dbgs() << "Base candidate tile sizes: [";
llvm::interleaveComma(candidates, llvm::dbgs());
llvm::dbgs() << "]\n";
});

for (int64_t candidate : candidates) {
int64_t scaledTileSize = candidate * scaleToByte;
if (loopBound % scaledTileSize != 0) {
if (!lossFactor)
continue;
// Skip this candidate if it causes many threads to be idle.
int64_t idleThreads = candidate - (loopBound % scaledTileSize);
if (idleThreads > candidate / *lossFactor)
continue;
}
// If the workload is too small and we cannot distribute to more than 2
// workgroups, try a smaller tile size to increase parallelism.
if (partitionableLoops.size() == 1 && candidate > subgroupSize &&
llvm::divideCeil(loopBound, scaledTileSize) <= 2) {
continue;
}

// Found a suitable candidate. Try to let each thread handle 4
// elements if this is the workgroup x dimension.
// TODO: Try to take into account element type bit width to get
// 4xdword reads instead of 4x{elements}.
workgroupTileSizes[shapeDim] = scaledTileSize;
LLVM_DEBUG(llvm::dbgs()
<< "Chosen workgroup tile size: " << scaledTileSize << "\n");
if (vectorizable && wgDim == 0 && !lossFactor && candidate % 4 == 0) {
// Use size-1 vectors to increase parallelism if larger ones causes
// idle threads in the subgroup.
bool hasIdleThreads =
partitionableLoops.size() == 1 && candidate <= subgroupSize;
int vectorSize = hasIdleThreads ? 1 : 4;
LLVM_DEBUG(llvm::dbgs() << "Use vector size: " << vectorSize << "\n");
threadTileSizes[shapeDim] = vectorSize * scaleToByte;
workgroupSize[wgDim] = candidate / vectorSize;
assert(numThreads % (candidate / vectorSize) == 0);
numThreads /= candidate / vectorSize;
} else {
if (wgDim == 0)
vectorizable = false;
threadTileSizes[shapeDim] = scaleToByte;
workgroupSize[wgDim] = candidate;
assert(numThreads % candidate == 0);
numThreads /= candidate;
}
assert(numThreads >= 1);
break;
}

// Stop if we have distributed all threads.
if (numThreads == 1)
break;
wgDim++;
}
return numThreads;
};

// First try to see if we can use up all threads without any loss.
if (distributeToThreads(subgroupSize) != 1) {
// Otherwise, allow larger and larger loss factor.

// Threads for distribution. Use 32 at least.
int64_t numThreads = std::max(subgroupSize, 32);
// We can tolerate (1 / lossFactor) of threads in the workgroup to be idle.
int64_t lossFactor = 32;

for (; lossFactor >= 1; lossFactor >>= 1) {
if (distributeToThreads(numThreads, lossFactor) == 1)
break;
}
}

// TODO(qedawkins): Currently scf.forall resolution only supports static
// trip counts, meaning the workgroup tile size must perfectly divide the
// loop bound (and thread tile size must perfectly divide the workgroup tile)
// so that the trip count won't be static. Remove this check once proper
// dynamic trip count resolution support is added.
for (auto [loopId, threadTile] : llvm::enumerate(threadTileSizes)) {
if (threadTile == 0) {
continue;
}
int64_t bound = loopBounds[loopId];
int64_t wkgpTile = workgroupTileSizes[loopId];
if (bound % wkgpTile != 0 || wkgpTile % threadTile != 0) {
return failure();
}
}

TileSizesListType tileSizes;
tileSizes.push_back(workgroupTileSizes);
tileSizes.push_back(threadTileSizes);

// Attach the MMA schedule as an attribute to the entry point export function
// for later access in the pipeline.
MLIRContext *context = linalgOp.getContext();
SmallVector<NamedAttribute, 1> attrs;
Builder b(context);
attrs.emplace_back(StringAttr::get(context, "workgroup"),
b.getIndexArrayAttr(workgroupTileSizes));

attrs.emplace_back(StringAttr::get(context, "thread"),
b.getIndexArrayAttr(threadTileSizes));

// Heuristic value chosen to limit maximum vector sizes when tiling below.
const unsigned maxVectorSize = 32;

// Try to tile all reductions by some small factor, preferrably 4, when
// possible. This gives us a chance to perform vector4 load if an input has
// its innnermost dimension being reduction. It also avoids generating too
// many instructions when unrolling vector later. We limit the expected
// vector size by estimating it from the size of the iteration space tile and
// limit it to a reasonable value. We process the loops from inner most to
// outer most to try to align loads along inner dimensions.
int64_t vectorSize = 1;
int64_t numLoops = linalgOp.getNumLoops();
SmallVector<utils::IteratorType> iterTypes = linalgOp.getIteratorTypesArray();
SmallVector<int64_t> loopTileSizes(numLoops, 0);
for (auto [reverseIdx, iter] : llvm::enumerate(llvm::reverse(iterTypes))) {
unsigned i = numLoops - reverseIdx - 1;
if (linalg::isReductionIterator(iter) || i >= workgroupTileSizes.size() ||
workgroupTileSizes[i] == 0) {
int64_t tileSize = getReductionTilingFactor(loopBounds[i]);
if (vectorSize * tileSize > maxVectorSize) {
tileSize = 1;
}
vectorSize *= tileSize;
loopTileSizes[i] = tileSize;
}
}
if (llvm::any_of(loopTileSizes, [](int64_t s) { return s != 0; })) {
attrs.emplace_back(StringAttr::get(context, "reduction"),
b.getIndexArrayAttr(loopTileSizes));
}

auto configDict = DictionaryAttr::get(context, attrs);
auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);

LDBG("Selected tile and fuse lowering config: " << loweringConfig << "\n");

// TODO(qedawkins): Use a shared pipeline identifier here.
return setOpConfigAndEntryPointFnTranslation(
entryPoint, op, loweringConfig,
IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUTileAndFuse,
workgroupSize, subgroupSize, DictionaryAttr());
}

} // namespace mlir::iree_compiler::IREE::GPU
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
mlir::FunctionOpInterface entryPoint,
Operation *op);

/// Helper for setting up a default tile and fuse config for targeting
/// simple thread distribution. Currently restricted to linalg ops.
LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
mlir::FunctionOpInterface entryPoint,
Operation *op);

} // namespace mlir::iree_compiler::IREE::GPU

#endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_TARGETUTILS_CONFIGUTILS_H_
30 changes: 23 additions & 7 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,15 @@
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
namespace mlir::iree_compiler {

llvm::cl::opt<bool> clGPUEnableTileAndFuse(
"iree-codegen-llvmgpu-use-tile-and-fuse",
llvm::cl::desc("enable the usage of the tile and fuse pipeline"),
llvm::cl::opt<bool> clGPUTestTileAndFuseMatmul(
"iree-codegen-llvmgpu-test-tile-and-fuse-matmul",
llvm::cl::desc("test the the tile and fuse pipeline for matmul"),
llvm::cl::init(false));

llvm::cl::opt<bool> clGPUTestTileAndFuseVectorize(
"iree-codegen-llvmgpu-test-tile-and-fuse-vectorize",
llvm::cl::desc(
"test the tile and fuse pipeline for all supported operations"),
llvm::cl::init(false));

llvm::cl::opt<bool> clGPUEnableVectorDistribution(
Expand Down Expand Up @@ -1946,10 +1952,19 @@ static LogicalResult setRootConfig(IREE::GPU::TargetAttr target,
LDBG("Transform Dialect Config");
return success();
}
if (clGPUEnableTileAndFuse && succeeded(IREE::GPU::setMatmulLoweringConfig(
target, entryPointFn, computeOp))) {
LDBG("Tile and fuse matmul config");
return success();
if (clGPUTestTileAndFuseMatmul) {
if (succeeded(IREE::GPU::setMatmulLoweringConfig(target, entryPointFn,
computeOp))) {
LDBG("Tile and fuse matmul config");
return success();
}
}
if (clGPUTestTileAndFuseVectorize) {
if (succeeded(IREE::GPU::setTileAndFuseLoweringConfig(target, entryPointFn,
computeOp))) {
LDBG("Tile and fuse default config");
return success();
}
}
if (succeeded(setVectorDistributionConfig(target, entryPointFn, computeOp))) {
return success();
Expand Down Expand Up @@ -2070,6 +2085,7 @@ LogicalResult initGPULaunchConfig(FunctionOpInterface funcOp) {
}
}
}
// Translation info (lowering pipeline) is already set.
return success();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,11 @@ static LogicalResult setRootConfig(IREE::GPU::TargetAttr target,
if (succeeded(setWarpReductionConfig(target, entryPointFn, linalgOp))) {
return success();
}
// TODO: Add configurations for matmul here too.
if (succeeded(IREE::GPU::setTileAndFuseLoweringConfig(target, entryPointFn,
computeOp))) {
return success();
}
}

return failure();
Expand Down Expand Up @@ -386,7 +391,10 @@ LogicalResult initROCDLLaunchConfig(FunctionOpInterface funcOp) {
if (failed(setRootConfig(target, funcOp, rootOp)))
return failure();

propagateLoweringConfig(rootOp, computeOps);
if (getTranslationInfo(funcOp).getDispatchLoweringPassPipeline() !=
IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUTileAndFuse) {
propagateLoweringConfig(rootOp, computeOps);
}
return success();
}

Expand Down
Loading

0 comments on commit 10ba28d

Please sign in to comment.