[AMDGPU] Use shared memory in multi_mma ukernel (#19786)
This achieves about 210 Top/s on CPX-mode MI300X, about 64% of the 327 Top/s
peak. That is roughly on par with the non-ukernel codegen path, which also
uses shared memory.

An earlier revision of this PR opted out of DistributeMmaToLanes, which was
more natural since a kernel that uses shared memory has to perform
workgroup-relative indexing in the copies from global to shared memory.

That required careful ordering of the pass pipeline, and it ended up
performing worse, at 180 Top/s vs 210 Top/s. So this PR instead stays on
DistributeMmaToLanes and compensates by applying negative thread-relative
offsets when indexing global memory in the copies to shared memory.
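
A minimal C++ sketch of that compensation (illustrative names, not the
ukernel's exact code; it assumes 64-lane subgroups laid out as
`subgroups_m x subgroups_n`): each thread's pointer is already
thread-distributed, so subtracting the thread-relative offset recovers the
workgroup-relative base needed for the cooperative global-to-shared copy.

```cpp
#include <cstdint>

// Same offset formula the ukernel uses for its A tile (illustrative).
static int a_thread_relative_offset(int tid, int subgroups_n, int intrinsics_m,
                                    int intrinsics_k) {
  int lane_id = tid % 64;
  int subgroup_m_idx = (tid / 64) / subgroups_n;
  return intrinsics_k * (lane_id + 64 * intrinsics_m * subgroup_m_idx);
}

// Cooperative copy of an A tile into shared memory. `a_distributed` already
// points at this thread's slice of A; indexing it with a negative
// thread-relative offset addresses the whole workgroup's tile.
static void copy_a_tile_to_shared(const int64_t *a_distributed,
                                  int64_t *a_shared, int tid, int num_threads,
                                  int num_shared_vecs,
                                  int thread_relative_offset) {
  for (int i = tid; i < num_shared_vecs; i += num_threads) {
    a_shared[i] = a_distributed[i - thread_relative_offset];
  }
}
```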

This relies on interpreting the ukernel bitcode to determine exactly how much
shared memory to allocate. That takes 2 ms. To avoid doing it redundantly, the
result is cached with the `DataTiledMMAAttr` value as the key, so this should
only run a few times per iree-compile invocation.
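
A hedged sketch of that caching pattern (the key struct and the interpreter
entry point below are placeholders standing in for `DataTiledMMAAttr` and the
bitcode interpreter, not actual IREE APIs):

```cpp
#include <cstdint>
#include <map>
#include <mutex>
#include <tuple>

// Placeholder key: the DataTiledMMAAttr fields that determine the ukernel's
// shared memory requirement.
struct MmaShapeKey {
  int intrinsics_m, subgroups_m, intrinsics_n, subgroups_n, intrinsics_k;
  bool operator<(const MmaShapeKey &o) const {
    return std::tie(intrinsics_m, subgroups_m, intrinsics_n, subgroups_n,
                    intrinsics_k) <
           std::tie(o.intrinsics_m, o.subgroups_m, o.intrinsics_n,
                    o.subgroups_n, o.intrinsics_k);
  }
};

// Assumed expensive step (the ~2 ms bitcode interpretation of the
// ..._query_shared_memory_bytes function). Declaration only.
int64_t interpretQuerySharedMemoryBytes(const MmaShapeKey &key);

// Memoized wrapper: each distinct key is interpreted at most once per
// compiler invocation.
int64_t getSharedMemoryBytes(const MmaShapeKey &key) {
  static std::mutex cacheMutex;
  static std::map<MmaShapeKey, int64_t> cache;
  std::lock_guard<std::mutex> lock(cacheMutex);
  auto it = cache.find(key);
  if (it != cache.end()) return it->second;
  int64_t bytes = interpretQuerySharedMemoryBytes(key);
  cache.emplace(key, bytes);
  return bytes;
}
```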

When it is determined that no shared memory should be allocated, a new
`iree_codegen.null_pointer` type is passed in lieu of an actual tensor, to
avoid creating 0-sized tensors. It lowers to a null pointer (and an offset).
It is intended for ukernels taking a tensor/memref/pointer argument that is
nullable, such as the shared memory argument here.
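
From the ukernel's side, a nullable workspace argument can be sketched like
this (illustrative only; the actual ukernel gates its shared-memory path on
each tile's `num_shared_vecs`, so a null workspace is simply never
dereferenced):

```cpp
#include <cstdint>

// Illustrative signature: when the compiler passes an
// iree_codegen.null_pointer, the (base, offset) pair lowers to a null pointer
// and a zero offset, and shared_memory_bytes is 0.
void use_optional_workspace(int8_t *shared_memory_base,
                            int64_t shared_memory_offset,
                            int32_t shared_memory_bytes) {
  if (shared_memory_bytes == 0) {
    // No workspace: skip the shared-memory path and never touch the pointer.
    return;
  }
  int8_t *workspace = shared_memory_base + shared_memory_offset;
  // ... stage A/B tiles through `workspace` ...
  (void)workspace;
}
```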

---------

Signed-off-by: Benoit Jacob <[email protected]>
bjacob authored Jan 31, 2025
1 parent 0159762 commit 4693b1c
Showing 24 changed files with 608 additions and 101 deletions.
21 changes: 3 additions & 18 deletions compiler/plugins/target/ROCM/builtins/ukernel/common.h
@@ -81,29 +81,14 @@ _Float16 __ockl_wfred_max_f16(_Float16);
int64_t __ockl_wfred_min_i64(int64_t);
int32_t __ockl_wfred_min_i32(int32_t);

#define __CLK_LOCAL_MEM_FENCE 0x01
typedef unsigned __cl_mem_fence_flags;

static inline void __threadfence_block() {
__builtin_amdgcn_fence(__ATOMIC_SEQ_CST, "workgroup");
}

static inline void __work_group_barrier(__cl_mem_fence_flags flags) {
if (flags) {
__builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");
} else {
__builtin_amdgcn_s_barrier();
}
}

static inline void __barrier(int n) {
__work_group_barrier((__cl_mem_fence_flags)n);
}

[[clang::convergent]] static inline void __syncthreads() {
__barrier(__CLK_LOCAL_MEM_FENCE);
__builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");
}

//===----------------------------------------------------------------------===//
@@ -6,46 +6,176 @@

#include "compiler/plugins/target/ROCM/builtins/ukernel/common.h"

// Very naive kernel. TODO(bjacob):
// 1. Shared memory: can't allocate it within the microkernel (which is just a
// helper device function, not the actual amdgpu_kernel). Need to get it
// passed down here as additional parameters.
// 2. Better scheduling via either barrier intrinsics or inline assemby.
[[clang::always_inline]] void iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8(
const int8_t *a_buffer, int64_t a_offset, const int8_t *b_buffer,
int64_t b_offset, int32_t *c_buffer, int64_t c_offset, int32_t k_size,
// Encodes some information about an A/B fragment tile.
typedef struct ab_tile_info_t {
// Terminology: "_vecs":
// We count in units of "vectors", meaning, for each A/B fragment, the
// corresponding operand type of this particular MFMA intrinsic.
// For A and B, that type is i64, used as <8 x i8>.

// Number of vectors in the tile.
int num_vecs;
// Number of vectors that we store in shared memory. That is typically equal
// to num_vecs if using shared memory for the tile, or 0 otherwise.
int num_shared_vecs;
} ab_tile_info_t;

static ab_tile_info_t get_ab_tile_info(int tile_intrinsics, int tile_subgroups,
int opposite_tile_subgroups) {
ab_tile_info_t info;
info.num_vecs = /*subgroup size*/ 64 * tile_intrinsics * tile_subgroups;
// Use shared memory if the opposite tile has more than 1 subgroup, so that
// the shared-memory copy amortizes loads from global memory across subgroups.
info.num_shared_vecs = opposite_tile_subgroups > 1 ? info.num_vecs : 0;
return info;
}

static int32_t get_shared_memory_bytes(ab_tile_info_t a_tile,
ab_tile_info_t b_tile) {
// For this MFMA intrinsic, the A and B vector types are 8 bytes.
return 8 * (a_tile.num_shared_vecs + b_tile.num_shared_vecs);
}

// The bitcode of this function is interpreted during IREE compilation to
// determine the exact shared_memory_bytes to pass to the ukernel.
int32_t iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_query_shared_memory_bytes(
int32_t intrinsics_m, int32_t subgroups_m, int32_t intrinsics_n,
int32_t subgroups_n, int32_t intrinsics_k) {
// Load existing accumulators. The VLA becomes a normal array after inlining.
int32x4_t c[intrinsics_m][intrinsics_n];
int32x4_t *c_global = (int32x4_t *)(c_buffer + c_offset);
ab_tile_info_t a_tile =
get_ab_tile_info(intrinsics_m * intrinsics_k, subgroups_m, subgroups_n);
ab_tile_info_t b_tile =
get_ab_tile_info(intrinsics_n * intrinsics_k, subgroups_n, subgroups_m);
return get_shared_memory_bytes(a_tile, b_tile);
}

// Microkernel for iree_gpu.multi_mma with DataTiledMMAAttr with
// intrinsic = MFMA_I32_16x16x32_I8 and a shape with outer M and N dimensions
// equal to 1 (so that this is just doing the inner loop on the K dimension).
//
// This microkernel uses a shared memory workspace buffer provided by the
// caller. It is used to copy tiles of the A and/or B matrices, depending on
// which ones are reused by multiple subgroups.
//
// Note that the A, B, C matrix pointers are all after thread-distribution.
// When the pointer before thread-distribution is needed (when copying data
// into shared memory), care is taken to subtract the thread-relative offset,
// which is computed from the thread id.
//
// As this function is always_inline, some of its parameters are actually
// constant values after inlining, so some for() loops and if() branches here
// are actually unrolled/resolved at compile time, making this microkernel
// a generic "template". This is summarized in the below table.
//
// Parameters | Constant? | Description
// --------------------------- | ---------- | -----------
// a_base, a_offset | No | A-matrix pointer (thread-distrib.)
// b_base, b_offset | No | B-matrix pointer (thread-distrib.)
// c_base, c_offset | No | C-matrix pointer (thread-distrib.)
// shared_memory_{base,offset} | No | Shared memory workspace pointer
// shared_memory_bytes | Yes | Shared memory workspace size
// k_size | From shape | Size of outer K dimension
// intrinsics_m, subgroups_m | Yes | See DataTiledMMAAttr
// intrinsics_n, subgroups_n | Yes | See DataTiledMMAAttr
// intrinsics_k | Yes | See DataTiledMMAAttr
//
// TODO(bjacob): Better scheduling via either barrier intrinsics or inline asm.
[[clang::always_inline]] void iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8(
const int8_t *a_base, int64_t a_offset, const int8_t *b_base,
int64_t b_offset, int32_t *c_base, int64_t c_offset,
int8_t *shared_memory_base, int64_t shared_memory_offset,
int32_t shared_memory_bytes, int32_t k_size, int32_t intrinsics_m,
int32_t subgroups_m, int32_t intrinsics_n, int32_t subgroups_n,
int32_t intrinsics_k) {
ab_tile_info_t a_tile =
get_ab_tile_info(intrinsics_m * intrinsics_k, subgroups_m, subgroups_n);
ab_tile_info_t b_tile =
get_ab_tile_info(intrinsics_n * intrinsics_k, subgroups_n, subgroups_m);

// shared_memory_bytes should match exactly, as the value ultimately comes
// from the ..._query_shared_memory_bytes function defined just above.
if (shared_memory_bytes != get_shared_memory_bytes(a_tile, b_tile)) {
__builtin_trap();
}
// Set up our pointers to shared memory for A and B tiles.
int64_t *restrict a_shared =
(int64_t *)(shared_memory_base + shared_memory_offset);
int64_t *restrict b_shared = a_shared + a_tile.num_shared_vecs;

// Determine our thread id and the range for it.
int tid = __builtin_amdgcn_workitem_id_x();
int numthreads = 64 * subgroups_m * subgroups_n;
__builtin_assume(tid < numthreads);

// Compute the thread-relative data offsets.
int lane_id = tid % 64;
int subgroup_id = tid / 64;
int subgroup_n_idx = subgroup_id % subgroups_n;
int subgroup_m_idx = subgroup_id / subgroups_n;
int a_thread_relative_offset =
intrinsics_k * (lane_id + 64 * intrinsics_m * subgroup_m_idx);
int b_thread_relative_offset =
intrinsics_k * (lane_id + 64 * intrinsics_n * subgroup_n_idx);

// Set up pointers to global memory.
const int64_t *restrict a_global = (const int64_t *)(a_base + a_offset);
const int64_t *restrict b_global = (const int64_t *)(b_base + b_offset);
int32x4_t *restrict c_global = ((int32x4_t *)(c_base + c_offset));

// Load existing accumulators from global memory into registers.
// The VLA becomes a normal array after inlining.
int32x4_t c_regs[intrinsics_m][intrinsics_n];
for (int m = 0; m < intrinsics_m; ++m) {
for (int n = 0; n < intrinsics_n; ++n) {
c[m][n] = c_global[64 * (m * intrinsics_n + n)];
c_regs[m][n] = c_global[64 * (m * intrinsics_n + n)];
}
}

// Arithmetic loop.
const int64_t *a_global = (const int64_t *)(a_buffer + a_offset);
const int64_t *b_global = (const int64_t *)(b_buffer + b_offset);
for (int k_outer = 0; k_outer < k_size; ++k_outer) {
// Pointers to A/B data to feed MFMA, based on whether shared memory is
// used.
const int64_t *restrict a_mfma_vecs =
a_tile.num_shared_vecs ? a_shared + a_thread_relative_offset : a_global;
const int64_t *restrict b_mfma_vecs =
b_tile.num_shared_vecs ? b_shared + b_thread_relative_offset : b_global;

// If needed, load data from global to shared memory.
if (tid < a_tile.num_shared_vecs) { // Benefits from above __builtin_assume.
for (int i = 0; i < a_tile.num_shared_vecs; i += numthreads) {
a_shared[i + tid] = a_global[i + tid - a_thread_relative_offset];
}
}
if (tid < b_tile.num_shared_vecs) { // Benefits from above __builtin_assume.
for (int i = 0; i < b_tile.num_shared_vecs; i += numthreads) {
b_shared[i + tid] = b_global[i + tid - b_thread_relative_offset];
}
}
// Thread barrier if any shared memory is used.
if (a_tile.num_shared_vecs || b_tile.num_shared_vecs) {
__syncthreads();
}
// Load data from shared memory and perform arithmetic.
for (int m = 0; m < intrinsics_m; ++m) {
for (int n = 0; n < intrinsics_n; ++n) {
for (int k = 0; k < intrinsics_k; ++k) {
c[m][n] = __builtin_amdgcn_mfma_i32_16x16x32_i8(
a_global[64 * intrinsics_k * m + k],
b_global[64 * intrinsics_k * n + k], c[m][n], 0, 0, 0);
c_regs[m][n] = __builtin_amdgcn_mfma_i32_16x16x32_i8(
a_mfma_vecs[64 * intrinsics_k * m + k],
b_mfma_vecs[64 * intrinsics_k * n + k], c_regs[m][n], 0, 0, 0);
}
}
}
a_global += 64 * intrinsics_m * subgroups_m * intrinsics_k;
b_global += 64 * intrinsics_n * subgroups_n * intrinsics_k;
a_global += a_tile.num_vecs;
b_global += b_tile.num_vecs;
// Thread barrier if any shared memory is used.
if (a_tile.num_shared_vecs || b_tile.num_shared_vecs) {
__syncthreads();
}
}

// Store accumulators.
for (int m = 0; m < intrinsics_m; ++m) {
for (int n = 0; n < intrinsics_n; ++n) {
c_global[64 * (m * intrinsics_n + n)] = c[m][n];
c_global[64 * (m * intrinsics_n + n)] = c_regs[m][n];
}
}
}
@@ -27,3 +27,4 @@ func.func @multi_mma_mfma_i32_16x16x32_i8(%a : tensor<1x2x8x4x16x2x8xi8>,
// CHECK-NOT: promote_operands
// CHECK-SAME: reduction = [0, 0, 0]
// CHECK-SAME: #iree_gpu.ukernel_config<name = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8"
// CHECK-SAME: shared_memory_bytes = 8192
@@ -7,10 +7,13 @@
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -81,6 +84,24 @@ struct LowerArgmaxToUKernelPattern : OpRewritePattern<linalg::GenericOp> {
}
};

static Value createSharedMemory(PatternRewriter &rewriter, Location loc,
int sharedMemoryBytes) {
RankedTensorType tensorType =
RankedTensorType::get({sharedMemoryBytes}, rewriter.getI8Type());
ValueRange dynSizes{};
if (!sharedMemoryBytes) {
IREE::Codegen::NullPointerType nullPointerType =
IREE::Codegen::NullPointerType::get(rewriter.getContext());
return rewriter.create<IREE::Codegen::NullPointerOp>(loc, nullPointerType);
}
auto allocOp =
rewriter.create<bufferization::AllocTensorOp>(loc, tensorType, dynSizes);
Attribute sharedAddrSpace = gpu::AddressSpaceAttr::get(
rewriter.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
allocOp.setMemorySpaceAttr(sharedAddrSpace);
return allocOp;
}

struct LowerMultiMmaToUKernelPattern : OpRewritePattern<IREE::GPU::MultiMmaOp> {
LowerMultiMmaToUKernelPattern(MLIRContext *context)
: OpRewritePattern<IREE::GPU::MultiMmaOp>(context) {}
@@ -100,14 +121,16 @@ struct LowerMultiMmaToUKernelPattern : OpRewritePattern<IREE::GPU::MultiMmaOp> {
if (!mma) {
return rewriter.notifyMatchFailure(op, "unhandled MMAInterfaceAttr");
}
Location loc = op->getLoc();
Type I32Type = rewriter.getI32Type();
auto castIndexToI32 = [&](Value val) {
return rewriter.create<arith::IndexCastOp>(op.getLoc(),
rewriter.getI32Type(), val);
return rewriter.create<arith::IndexCastOp>(loc, I32Type, val);
};
auto constI32 = [&](int val) {
return rewriter.create<arith::ConstantIntOp>(op.getLoc(), val,
rewriter.getI32Type());
return rewriter.create<arith::ConstantIntOp>(loc, val, I32Type);
};
int64_t sharedMemoryBytes = ukernelAttr.getSharedMemoryBytes();
auto sharedMemory = createSharedMemory(rewriter, loc, sharedMemoryBytes);
Value k = castIndexToI32(
rewriter.create<tensor::DimOp>(op.getLoc(), op.getLhs(), 1));
Value intrinsicsM = constI32(mma.getIntrinsicsM());
Expand All @@ -118,8 +141,8 @@ struct LowerMultiMmaToUKernelPattern : OpRewritePattern<IREE::GPU::MultiMmaOp> {
rewriter.replaceOpWithNewOp<IREE::Codegen::UKernelGenericOp>(
op, TypeRange{op.getAccType()}, ukernelAttr.getName(),
ValueRange{op.getLhs(), op.getRhs()}, op.getAcc(),
ValueRange{k, intrinsicsM, subgroupsM, intrinsicsN, subgroupsN,
intrinsicsK},
ValueRange{sharedMemory, constI32(sharedMemoryBytes), k, intrinsicsM,
subgroupsM, intrinsicsN, subgroupsN, intrinsicsK},
ukernelAttr.getDefAttrs(),
/*strided_outer_dims=*/rewriter.getIndexAttr(0));
return success();
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -120,6 +120,8 @@ def GPULowerToUKernelsPass :
"::mlir::iree_compiler::IREE::Codegen::IREECodegenDialect",
"::mlir::iree_compiler::IREE::GPU::IREEGPUDialect",
"::mlir::arith::ArithDialect",
"::mlir::bufferization::BufferizationDialect",
"::mlir::gpu::GPUDialect",
"::mlir::tensor::TensorDialect",
];
}
@@ -76,17 +76,33 @@ func.func @multi_mma_mfma_i32_16x16x32_i8(%a : tensor<1x2x8x1x1x2x8xi8>, %b : te
kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, intrinsics_m = 8, intrinsics_n = 2, subgroups_n = 4, intrinsics_k = 2>,
lowering_config = #iree_gpu.lowering_config<{
reduction = [0, 0, 0],
ukernel = #iree_gpu.ukernel_config<name = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8", def_attrs = {vm.import.module = "rocm"}>,
ukernel = #iree_gpu.ukernel_config<name = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8", def_attrs = {vm.import.module = "rocm"}, shared_memory_bytes = 16384>,
workgroup = [1, 1, 0]}>
} : tensor<1x2x8x1x1x2x8xi8>, tensor<1x2x1x2x1x1x2x8xi8> into tensor<1x1x1x8x2x1x1x4xi32>
return %d : tensor<1x1x1x8x2x1x1x4xi32>
}

// CHECK-LABEL: func @multi_mma_mfma_i32_16x16x32_i8(
// CHECK-DAG: %c2_i32 = arith.constant 2 : i32
// CHECK-DAG: %c8_i32 = arith.constant 8 : i32
// CHECK-DAG: %c1_i32 = arith.constant 1 : i32
// CHECK-DAG: %c4_i32 = arith.constant 4 : i32
// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic
// CHECK: bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<16384xi8>
// CHECK: iree_codegen.ukernel.generic
// CHECK-SAME: "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8"

// -----

func.func @multi_mma_mfma_i32_16x16x32_i8_one_subgroup_no_shared_memory(%a : tensor<1x2x8x1x1x2x8xi8>, %b : tensor<1x2x1x2x1x1x2x8xi8>, %c : tensor<1x1x1x8x2x1x1x4xi32>) -> tensor<1x1x1x8x2x1x1x4xi32> {
%d = iree_gpu.multi_mma %a, %b, %c {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, intrinsics_m = 8, intrinsics_n = 2, intrinsics_k = 2>,
lowering_config = #iree_gpu.lowering_config<{
reduction = [0, 0, 0],
ukernel = #iree_gpu.ukernel_config<name = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8", def_attrs = {vm.import.module = "rocm"}, shared_memory_bytes = 0>,
workgroup = [1, 1, 0]}>
} : tensor<1x2x8x1x1x2x8xi8>, tensor<1x2x1x2x1x1x2x8xi8> into tensor<1x1x1x8x2x1x1x4xi32>
return %d : tensor<1x1x1x8x2x1x1x4xi32>
}

// CHECK-LABEL: func @multi_mma_mfma_i32_16x16x32_i8_one_subgroup_no_shared_memory(
// CHECK: iree_codegen.null_pointer
// CHECK: iree_codegen.ukernel.generic
// CHECK-SAME: "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8"
// CHECK-SAME: (%c2_i32, %c8_i32, %c1_i32, %c2_i32, %c4_i32, %c2_i32 : i32, i32, i32, i32, i32, i32)
@@ -176,3 +176,16 @@ func.func @ukernel_generic_test_fndef_attrs(%arg0 : memref<?xf32, strided<[1], o
}
// CHECK: func.func private @test1d(memref<f32>, index, index)
// CHECK-SAME: hal.import.fields = ["processor_id", "processor_data"]

// -----

func.func @ukernel_with_null_pointer_arg() {
%0 = iree_codegen.null_pointer
iree_codegen.ukernel.generic "foo" ins(%0: !iree_codegen.null_pointer)
return
}

// CHECK-LABEL: func.func private @foo(!iree_codegen.null_pointer, index)
// CHECK-DAG: %[[NULLPTR:.+]] = iree_codegen.null_pointer
// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0 : index
// CHECK: call @foo(%[[NULLPTR]], %[[ZERO]]) : (!iree_codegen.null_pointer, index) -> ()