Subgroup matmul memory optimizations #23758

Draft · wants to merge 14 commits into base: main
3 changes: 3 additions & 0 deletions .gitignore
@@ -62,6 +62,9 @@ rust/**/Cargo.lock
rust/onnxruntime/synset.txt

# Python
bin/
lib/
pyvenv.cfg

# Byte-compiled / optimized / DLL files
__pycache__/
2 changes: 1 addition & 1 deletion cmake/deps.txt
@@ -58,5 +58,5 @@ extensions;https://github.com/microsoft/onnxruntime-extensions/archive/c24b7bab0
composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/204da9c522cebec5220bba52cd3542ebcaf99e7a.zip;1827348efd47831c13074245274d41b7cae8a557
directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.7.0.zip;d0753d8d5b39947ca0729d7773cb84653a129eb1
dawn;https://github.com/google/dawn/archive/b9b4a37041dec3dd62ac92014a6cc1aece48d9f3.zip;e8b8c2ebabdedb7c57d931fc4a19ae22146d31e1
dawn;https://github.com/google/dawn/archive/40a9fa79f76e6c76cca9e2fa69ea07f202f1d2e6.zip;e224563d5ab4a8e53a517b06f721242533bce722
kleidiai;https://gitlab.arm.com/kleidi/kleidiai/-/archive/d15722976120710080ca098fe8ddabf4556cb40f/kleidiai-d15722976120710080ca098fe8ddabf4556cb40f.zip;d6c840d00c3b05aedf06e957ddaece1013d1f40b
7 changes: 7 additions & 0 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
@@ -4,6 +4,7 @@
#include <string_view>

#include "contrib_ops/webgpu/quantization/matmul_nbits.h"
#include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/providers/webgpu/shader_helper.h"
@@ -815,6 +816,12 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
uint32_t components = GetMaxComponents(N);

const bool has_zero_points = zero_points != nullptr;
// macOS - Experimental dawn support for subgroup matrix matmul on Metal.
if (M >= kMinMForTileOptimization &&
CanApplySubgroupMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, has_zero_points)) {
return ApplySubgroupMatrixMatMulNBits(a, b, scales, M, N, K, context, y);
}

const bool has_subgroup = context.Device().HasFeature(wgpu::FeatureName::Subgroups);
// macOS - Avoid using dp4a on Metal, as it does not appear to have native dp4a support.
// https://github.com/gpuweb/gpuweb/issues/2677#issuecomment-1713292226
@@ -0,0 +1,219 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
shader.AddInput("input_b", ShaderUsage::UseUniform);
shader.AddInput("scales_b", ShaderUsage::UseUniform);
shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);

// tile/subtile sizes and work distribution are inspired from metal shaders in llama.cpp (kernel_mul_mm)
// https://github.com/ggml-org/llama.cpp/blob/d04e7163c85a847bc61d58c22f2c503596db7aa8/ggml/src/ggml-metal/ggml-metal.metal#L6066
shader.AdditionalImplementation() << R"ADDNL_FN(
const tile_cols = 64;
const tile_rows = 32;
const tile_k = 32;
const subtile_cols = 32;
const subtile_rows = 16;
const quantization_block_size = 32;
alias compute_precision = output_element_t;

var<workgroup> tile_A: array<compute_precision, tile_rows * tile_k>; // 32 x 32 - RxC
var<workgroup> tile_B: array<compute_precision, tile_cols * tile_k>; // 64 x 32 - RxC
var<workgroup> scratch: array<array<array<compute_precision, 64>, 4>, 4>; // 64 * 4 * 4
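// Work distribution implied by the constants above (a sketch, assuming a subgroup size of 32):
// a 128-thread workgroup (4 subgroups) computes one 32 x 64 output tile. Each subgroup owns a
// 16 x 32 subtile, accumulated as a 2 x 4 grid of 8 x 8 subgroup matrices (matC00..matC13 below).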

fn loadSHMA(tile_base: u32, k_idx: u32, row: u32, col:u32) {
// Each thread loads one element per loop iteration at column col, stepping 4 rows each time.
// 128 threads cover the 32 x 32 tile: 32 threads per row, 4 rows in flight at a time.
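// For example (assuming a full tile), the thread with row = 1, col = 5 fills
// tile_A rows 1, 5, 9, ..., 29 at column 5, one element per iteration.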
var out_row = row;
let row_limit = min(tile_base + tile_rows, uniforms.M);
for (var a_global:u32 = tile_base + out_row; a_global < row_limit; a_global = tile_base + out_row)
{
tile_A[out_row * tile_k + col] = compute_precision(input_a[a_global*uniforms.K + k_idx + col]);
out_row+=4;
}
}

fn loadSHMB(tile_base: u32, k_idx: u32, row: u32, c_idx: u32) {
// Each thread loads 8 consecutive columns starting at c_idx * 8, then advances 32 rows and repeats.
// 128 threads cover the 64 x 32 tile: 4 threads per row, 32 rows in flight at a time.
// B is stored in column major fashion.
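// Each u32 of input_b packs eight 4-bit weights; consistent with the interleaved stores below,
// the low nibbles hold the even-indexed weights and the high nibbles the odd-indexed ones, so
// one u32 expands to 8 consecutive tile_B entries after dequantization.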
var col = c_idx * 8;
let out_row_limit = min(tile_base + tile_cols, uniforms.N);
var out_row = row;
for (var b_global:u32 = tile_base + out_row; b_global < out_row_limit; b_global = tile_base + out_row)
{
let b_idx = u32((b_global*uniforms.K + k_idx + col)/8);
let scale = compute_precision(scales_b[(b_global*uniforms.K + k_idx + col)/quantization_block_size]);
var b_value = input_b[b_idx];
var b_value_lower = (vec4<compute_precision>(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4<compute_precision>(8)) * scale;
var b_value_upper = (vec4<compute_precision>(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4<compute_precision>(8)) * scale;
let tile_b_base = out_row * tile_k + col;
tile_B[tile_b_base] = b_value_lower[0];
tile_B[tile_b_base + 1] = b_value_upper[0];
tile_B[tile_b_base + 2] = b_value_lower[1];
tile_B[tile_b_base + 3] = b_value_upper[1];
tile_B[tile_b_base + 4] = b_value_lower[2];
tile_B[tile_b_base + 5] = b_value_upper[2];
tile_B[tile_b_base + 6] = b_value_lower[3];
tile_B[tile_b_base + 7] = b_value_upper[3];
out_row += 32;
}
}

fn storeOutput(offset:u32, row: u32, col:u32, src_slot:u32, row_limit:i32) {
if (row_limit > 0 && row < u32(row_limit))
{
output[offset + row * uniforms.N + col] = output_element_t(scratch[src_slot][0][row * 8 + col]);
output[offset + row * uniforms.N + col + 8] = output_element_t(scratch[src_slot][1][row * 8 + col]);
output[offset + row * uniforms.N + col + 16] = output_element_t(scratch[src_slot][2][row * 8 + col]);
output[offset + row * uniforms.N + col + 24] = output_element_t(scratch[src_slot][3][row * 8 + col]);
let col2 = col + 1;
output[offset + row * uniforms.N + col2] = output_element_t(scratch[src_slot][0][row * 8 + col2]);
output[offset + row * uniforms.N + col2 + 8] = output_element_t(scratch[src_slot][1][row * 8 + col2]);
output[offset + row * uniforms.N + col2 + 16] = output_element_t(scratch[src_slot][2][row * 8 + col2]);
output[offset + row * uniforms.N + col2 + 24] = output_element_t(scratch[src_slot][3][row * 8 + col2]);
}
}
)ADDNL_FN";

shader.MainFunctionBody() << R"MAIN_FN(
let a_global_base = workgroup_id.y * tile_rows;
let b_global_base = workgroup_id.x * tile_cols;

let subtile_id = u32(local_idx / sg_size);
let subtile_idx = u32(subtile_id / 2);
let subtile_idy = subtile_id % 2;
let base_A = subtile_idy * subtile_rows;
let base_B = subtile_idx * subtile_cols;

var matC00: subgroup_matrix_result<compute_precision, 8, 8>;
var matC01: subgroup_matrix_result<compute_precision, 8, 8>;
var matC02: subgroup_matrix_result<compute_precision, 8, 8>;
var matC03: subgroup_matrix_result<compute_precision, 8, 8>;
var matC10: subgroup_matrix_result<compute_precision, 8, 8>;
var matC11: subgroup_matrix_result<compute_precision, 8, 8>;
var matC12: subgroup_matrix_result<compute_precision, 8, 8>;
var matC13: subgroup_matrix_result<compute_precision, 8, 8>;
for (var kidx: u32 = 0; kidx < uniforms.K; kidx += tile_k) {
// Load Phase
loadSHMA(a_global_base, kidx, local_idx/tile_k, local_idx%tile_k);
// In B, tile_k spans 4 u32 elements since each u32 holds 8 weights, so local_idx % 4 selects a group of 8 columns.
loadSHMB(b_global_base, kidx, local_idx/4, local_idx%4);
workgroupBarrier();

for (var step: u32 = 0; step < tile_k; step+=8)
{
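// Each 8-wide K step multiplies two 8x8 A blocks (subtile rows 0..7 and 8..15) against
// four 8x8 B blocks (subtile columns 0..31), updating the 2 x 4 accumulator grid.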
// Load to local memory phase
let matrix_a_offset = subtile_idy * subtile_rows * tile_k + step;
// Syntax: subgroupMatrixLoad src_ptr,src_offset,is_col_major,src_stride
var matA0: subgroup_matrix_left<compute_precision, 8, 8> = subgroupMatrixLoad<subgroup_matrix_left<compute_precision, 8, 8>>(&tile_A, matrix_a_offset, false, tile_k);
var matA1: subgroup_matrix_left<compute_precision, 8, 8> = subgroupMatrixLoad<subgroup_matrix_left<compute_precision, 8, 8>>(&tile_A, matrix_a_offset + 8 * tile_k, false, tile_k);

// tile_B is stored as column major.
// [col0-0:32][col1-0:32][col2-0:32]..[col63-0:32]
var matrix_b_offset = subtile_idx * subtile_cols * tile_k + step;
var matB0: subgroup_matrix_right<compute_precision, 8, 8> = subgroupMatrixLoad<subgroup_matrix_right<compute_precision, 8, 8>>(&tile_B, matrix_b_offset, true, tile_k);
var matB1: subgroup_matrix_right<compute_precision, 8, 8> = subgroupMatrixLoad<subgroup_matrix_right<compute_precision, 8, 8>>(&tile_B, matrix_b_offset + 8 * tile_k, true, tile_k);
var matB2: subgroup_matrix_right<compute_precision, 8, 8> = subgroupMatrixLoad<subgroup_matrix_right<compute_precision, 8, 8>>(&tile_B, matrix_b_offset + 16 * tile_k, true, tile_k);
var matB3: subgroup_matrix_right<compute_precision, 8, 8> = subgroupMatrixLoad<subgroup_matrix_right<compute_precision, 8, 8>>(&tile_B, matrix_b_offset + 24 * tile_k, true, tile_k);

// Compute Phase
// Syntax: subgroupMatrixMultiplyAccumulate left, right, accumulate -> accumulate
matC00 = subgroupMatrixMultiplyAccumulate(matA0, matB0, matC00);
matC01 = subgroupMatrixMultiplyAccumulate(matA0, matB1, matC01);
matC02 = subgroupMatrixMultiplyAccumulate(matA0, matB2, matC02);
matC03 = subgroupMatrixMultiplyAccumulate(matA0, matB3, matC03);

matC10 = subgroupMatrixMultiplyAccumulate(matA1, matB0, matC10);
matC11 = subgroupMatrixMultiplyAccumulate(matA1, matB1, matC11);
matC12 = subgroupMatrixMultiplyAccumulate(matA1, matB2, matC12);
matC13 = subgroupMatrixMultiplyAccumulate(matA1, matB3, matC13);
}

workgroupBarrier();
}

// Write out
// Write out top block
subgroupMatrixStore(&scratch[subtile_id][0], 0, matC00, false, 8);
subgroupMatrixStore(&scratch[subtile_id][1], 0, matC01, false, 8);
subgroupMatrixStore(&scratch[subtile_id][2], 0, matC02, false, 8);
subgroupMatrixStore(&scratch[subtile_id][3], 0, matC03, false, 8);
workgroupBarrier();
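// Output mapping (assuming a subgroup size of 32): sg_id / 4 selects a row (0..7) and
// (sg_id % 4) * 2 a starting column, so each invocation writes a 2-column strip in each
// of the four 8x8 blocks held in scratch, covering an 8 x 32 slice per storeOutput call.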
let row = u32(sg_id / 4);
var col = u32(sg_id % 4) * 2;
var matrix_c_offset = (a_global_base+base_A) * uniforms.N + b_global_base + base_B;
var row_limit:i32 = i32(uniforms.M) - i32(a_global_base + base_A);
storeOutput(matrix_c_offset, row, col, subtile_id, row_limit);
workgroupBarrier();

// Write out bottom block
subgroupMatrixStore(&scratch[subtile_id][0], 0, matC10, false, 8);
subgroupMatrixStore(&scratch[subtile_id][1], 0, matC11, false, 8);
subgroupMatrixStore(&scratch[subtile_id][2], 0, matC12, false, 8);
subgroupMatrixStore(&scratch[subtile_id][3], 0, matC13, false, 8);
workgroupBarrier();
matrix_c_offset = matrix_c_offset + 8 * uniforms.N;
row_limit = i32(uniforms.M) - i32(a_global_base + base_A + 8);
storeOutput(matrix_c_offset, row, col, subtile_id, row_limit);
)MAIN_FN";

return Status::OK();
}

Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
uint32_t M,
uint32_t N,
uint32_t K,
onnxruntime::webgpu::ComputeContext& context,
Tensor* y) {
constexpr uint32_t kTileSizeA = 32;
constexpr uint32_t kTileSizeB = 64;
constexpr uint32_t kU32Components = 4;
TensorShape y_shape{1, M, N};
SubgroupMatrixMatMulNBitsProgram mul_program;
mul_program.SetWorkgroupSize(128);
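// One workgroup produces a 32 x 64 (M x N) output tile, so the dispatch is
// ceil(N / 64) x ceil(M / 32) workgroups. Illustrative example: M = 1024, N = 4096
// launches 64 x 32 workgroups of 128 threads each.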
mul_program.SetDispatchGroupSize(
(N + kTileSizeB - 1) / kTileSizeB,
(M + kTileSizeA - 1) / kTileSizeA, 1);
mul_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)},
{b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kU32Components)},
{scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)}})
.AddUniformVariables({{static_cast<uint32_t>(M)},
{static_cast<uint32_t>(N)},
{static_cast<uint32_t>(K)}})
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, y_shape, gsl::narrow<int>(1)});
return context.RunProgram(mul_program);
}

bool CanApplySubgroupMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context,
uint64_t accuracy_level,
uint32_t block_size,
uint32_t batch_count,
uint32_t N,
uint32_t K,
bool has_zero_points) {
const bool has_subgroup_matrix = context.Device().HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
// For now SubgroupMatrixMatMulNBits is only supported for accuracy level 4, because with Fp16 there are
// some precision issues with subgroupMatrixMultiplyAccumulate. It is possible to support higher accuracy
// by setting compute_precision to Fp32, but that will be slower. For a 1K token prefill of Phi 3.5,
// FP16 takes around 5s and FP32 around 7s.
return context.AdapterInfo().backendType == wgpu::BackendType::Metal &&
has_subgroup_matrix &&
accuracy_level == 4 &&
block_size == 32 &&
batch_count == 1 &&
K % 32 == 0 &&
N % 64 == 0 &&
!has_zero_points;
}
} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
@@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/compute_context.h"
#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

using namespace onnxruntime::webgpu;

class SubgroupMatrixMatMulNBitsProgram final : public Program<SubgroupMatrixMatMulNBitsProgram> {
public:
SubgroupMatrixMatMulNBitsProgram() : Program{"SubgroupMatrixMatMulNBits"} {}
Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"M", ProgramUniformVariableDataType::Uint32},
{"N", ProgramUniformVariableDataType::Uint32},
{"K", ProgramUniformVariableDataType::Uint32});
};

Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
uint32_t M,
uint32_t N,
uint32_t K,
onnxruntime::webgpu::ComputeContext& context,
Tensor* y);

bool CanApplySubgroupMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context,
uint64_t accuracy_level,
uint32_t block_size,
uint32_t batch_count,
uint32_t N,
uint32_t K,
bool has_zero_points);

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/webgpu/shader_helper.cc
@@ -352,6 +352,9 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector<int>& sha
if (device_.HasFeature(wgpu::FeatureName::Subgroups)) {
ss << "enable subgroups;\n";
}
if (device_.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
ss << "enable chromium_experimental_subgroup_matrix;\n";
}

//
// Section constants
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/webgpu/webgpu_context.cc
@@ -486,6 +486,7 @@ std::vector<wgpu::FeatureName> WebGpuContext::GetAvailableRequiredFeatures(const
constexpr wgpu::FeatureName features[]{
#if !defined(__wasm__)
wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses,
wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix,
#endif
wgpu::FeatureName::TimestampQuery,
wgpu::FeatureName::ShaderF16,
@@ -738,7 +739,7 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co
// Step.2 - Create wgpu::Instance
#if !defined(__wasm__)
wgpu::InstanceDescriptor instance_desc{};
instance_desc.features.timedWaitAnyEnable = true;
instance_desc.capabilities.timedWaitAnyEnable = true;
default_instance_ = wgpu::CreateInstance(&instance_desc);
#else
default_instance_ = wgpu::CreateInstance(nullptr);
7 changes: 5 additions & 2 deletions onnxruntime/core/providers/webgpu/webgpu_context.h
@@ -94,8 +94,11 @@ class WebGpuContext final {
wgpu::ComputePassDescriptor compute_pass_desc{};

if (is_profiling_ && query_type_ == TimestampQueryType::AtPasses) {
wgpu::ComputePassTimestampWrites timestampWrites = {
query_set_, num_pending_dispatches_ * 2, num_pending_dispatches_ * 2 + 1};
wgpu::PassTimestampWrites timestampWrites = {
nullptr,
query_set_,
num_pending_dispatches_ * 2,
num_pending_dispatches_ * 2 + 1};
compute_pass_desc.timestampWrites = &timestampWrites;
}
