From 7e2b8db5145066210e4b02986c56b0b0ad49447f Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Sun, 16 Feb 2025 16:33:13 -0800 Subject: [PATCH 01/14] Upgrade Dawn --- cmake/deps.txt | 2 +- onnxruntime/core/providers/webgpu/webgpu_context.cc | 2 +- onnxruntime/core/providers/webgpu/webgpu_context.h | 6 ++++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/cmake/deps.txt b/cmake/deps.txt index b07a3acdecd54..d2e0fd63215f4 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -58,5 +58,5 @@ extensions;https://github.com/microsoft/onnxruntime-extensions/archive/c24b7bab0 composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/204da9c522cebec5220bba52cd3542ebcaf99e7a.zip;1827348efd47831c13074245274d41b7cae8a557 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.7.0.zip;d0753d8d5b39947ca0729d7773cb84653a129eb1 -dawn;https://github.com/google/dawn/archive/b9b4a37041dec3dd62ac92014a6cc1aece48d9f3.zip;e8b8c2ebabdedb7c57d931fc4a19ae22146d31e1 +dawn;https://github.com/google/dawn/archive/40a9fa79f76e6c76cca9e2fa69ea07f202f1d2e6.zip;e224563d5ab4a8e53a517b06f721242533bce722 kleidiai;https://gitlab.arm.com/kleidi/kleidiai/-/archive/d15722976120710080ca098fe8ddabf4556cb40f/kleidiai-d15722976120710080ca098fe8ddabf4556cb40f.zip;d6c840d00c3b05aedf06e957ddaece1013d1f40b diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 50ace96524ddf..f3b025b72aa1d 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -738,7 +738,7 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co // Step.2 - Create wgpu::Instance #if !defined(__wasm__) wgpu::InstanceDescriptor instance_desc{}; - instance_desc.features.timedWaitAnyEnable = true; + instance_desc.capabilities.timedWaitAnyEnable = true; default_instance_ = wgpu::CreateInstance(&instance_desc); #else default_instance_ = wgpu::CreateInstance(nullptr); diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index a87a940f44437..41610cc251659 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -94,8 +94,10 @@ class WebGpuContext final { wgpu::ComputePassDescriptor compute_pass_desc{}; if (is_profiling_ && query_type_ == TimestampQueryType::AtPasses) { - wgpu::ComputePassTimestampWrites timestampWrites = { - query_set_, num_pending_dispatches_ * 2, num_pending_dispatches_ * 2 + 1}; + wgpu::PassTimestampWrites timestampWrites = { + .querySet = query_set_, + .beginningOfPassWriteIndex = num_pending_dispatches_ * 2, + .endOfPassWriteIndex = num_pending_dispatches_ * 2 + 1}; compute_pass_desc.timestampWrites = ×tampWrites; } From 91cda6287f10cc4650c10e0b4b8b3c126813f96f Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Sun, 16 Feb 2025 17:25:16 -0800 Subject: [PATCH 02/14] Code Builds --- .../webgpu/quantization/matmul_nbits.cc | 6 + .../subgroup_matrix_matmul_nbits.cc | 223 ++++++++++++++++++ .../subgroup_matrix_matmul_nbits.h | 44 ++++ .../core/providers/webgpu/webgpu_context.cc | 1 + 4 files changed, 274 insertions(+) create mode 100644 onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc create mode 100644 
onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 3b566d37fa979..15215082c7480 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -4,6 +4,7 @@ #include #include "contrib_ops/webgpu/quantization/matmul_nbits.h" +#include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h" #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/webgpu/shader_helper.h" @@ -815,6 +816,11 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context uint32_t components = GetMaxComponents(N); const bool has_zero_points = zero_points != nullptr; + // macOS - Experimental dawn support for subgroup matrix matmul on Metal. + if (CanApplySubgroupMatrixMatMulNBits(context, block_size, batch_count, K, N, has_zero_points)) { + return ApplySubgroupMatrixMatMulNBits(a, b, scales, M, N, K, context, y); + } + const bool has_subgroup = context.Device().HasFeature(wgpu::FeatureName::Subgroups); // macOS - Avoid using dp4a on Metal, as it does not appear to have native dp4a support. // https://github.com/gpuweb/gpuweb/issues/2677#issuecomment-1713292226 diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc new file mode 100644 index 0000000000000..d1392f91394b7 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc @@ -0,0 +1,223 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + shader.AddInput("input_b", ShaderUsage::UseUniform); + shader.AddInput("scales_b", ShaderUsage::UseUniform); + shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); + + shader.AdditionalImplementation() << R"ADDNL_FN( + const tile_cols = 64; + const tile_rows = 32; + const tile_k = 32; + const subtile_cols = 32; + const subtile_rows = 16; + const quantization_block_size = 32; + + var tile_A: array; // 32 x 32 - RxC + var tile_B: array; // 32 x 64 - RxC + var scratch: array, 4>; // 64 * 4 + + fn loadSHMA(subtile_base: u32, k_idx: u32, row: u32, col:u32) { + let a_global = subtile_base + row; + if (a_global >= uniforms.M) { + return; + } + // Each call loads 8 columns, starting at col. + // 128 threads need to load 32 x 32. 4 threads per row or 8 col per thread. + for (var col_offset:u32 = 0; col_offset < 8; col_offset++) + { + tile_A[row * tile_rows + col+ col_offset] = input_a[a_global*uniforms.K + k_idx + col + col_offset]; + } + } + + fn loadSHMB(subtile_base: u32, k_idx: u32, row: u32, col: u32) { + let b_global = subtile_base + row; + if (b_global >= uniforms.N) { + return; + } + // Each call loads 16 columns, starting at col. + // 128 threads need to load 64 x 32. 2 threads per row or 16 col per thread. + // Stored in column major fashion. 
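The loop that follows expands each packed u32 of input_b into eight dequantized weights: the low nibbles come from b & 0x0F0F0F0F, the high nibbles from (b >> 4) & 0x0F0F0F0F, and the results are interleaved after subtracting the implicit zero point of 8 and applying the block scale. A minimal CPU sketch of the same arithmetic, assuming WGSL's little-endian unpack4xU8 semantics (the helper name DequantizeU32 is an illustration, not part of the patch):

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

// Sketch of the shader's 4-bit dequantization. One u32 packs eight weights;
// slot 2*i holds the low nibble of byte i, slot 2*i + 1 the high nibble.
std::array<float, 8> DequantizeU32(uint32_t v, float scale) {
  std::array<float, 8> out{};
  for (int byte = 0; byte < 4; ++byte) {
    int lo = static_cast<int>((v >> (8 * byte)) & 0xF);      // b & 0x0F0F0F0F path
    int hi = static_cast<int>((v >> (8 * byte + 4)) & 0xF);  // (b >> 4) & 0x0F0F0F0F path
    out[2 * byte] = static_cast<float>(lo - 8) * scale;      // zero point is a fixed 8
    out[2 * byte + 1] = static_cast<float>(hi - 8) * scale;
  }
  return out;
}

int main() {
  auto w = DequantizeU32(0x00000080u, 0.5f);  // byte 0 is 0x80: nibbles 0 and 8
  std::printf("%g %g\n", w[0], w[1]);         // prints: -4 0
}
```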
+ let b_idx = u32((b_global*uniforms.K + k_idx + col)/8); + let scale = scales_b[(b_global*uniforms.K + k_idx + col)/quantization_block_size]; + for (var step:u32 = 0; step < 2; step++) + { + var b_value = input_b[b_idx+step]; + var b_value_lower = (vec4(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4(8)) * scale; + var b_value_upper = (vec4(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4(8)) * scale; + let tile_b_base = row * tile_rows + col + step * 8; + tile_B[tile_b_base] = b_value_lower[0]; + tile_B[tile_b_base + 1] = b_value_upper[0]; + tile_B[tile_b_base + 2] = b_value_lower[1]; + tile_B[tile_b_base + 3] = b_value_upper[1]; + tile_B[tile_b_base + 4] = b_value_lower[2]; + tile_B[tile_b_base + 5] = b_value_upper[2]; + tile_B[tile_b_base + 6] = b_value_lower[3]; + tile_B[tile_b_base + 7] = b_value_upper[3]; + } + } + + fn safeMatrixStore(offset: u32, mat: ptr>, rows:u32, subtile_id:u32, subtile_thread_id:u32) + { + subgroupMatrixStore(&scratch[subtile_id], 0, *mat, false, 8); + // There are 32 subtile_thread_id and we have 64 values. + let row = u32(subtile_thread_id / 4); + var col = u32(subtile_thread_id % 4) * 2; + if (row < rows) + { + output[offset + row * uniforms.N + col] = scratch[subtile_id][row * 8 + col]; + col++; + output[offset + row * uniforms.N + col] = scratch[subtile_id][row * 8 + col]; + } + } + )ADDNL_FN"; + + shader.MainFunctionBody() << R"MAIN_FN( + let a_global_base = workgroup_idy * tile_rows; + let b_global_base = workgroup_idx * tile_cols; + + let subtile_id = u32(local_idx / sg_size); + let subtile_idx = u32(subtile_id / 2); + let subtile_idy = subtile_id % 2; + let base_A = subtile_idy * subtile_rows; + let base_B = subtile_idx * subtile_cols; + + var matC00: subgroup_matrix_result; + var matC01: subgroup_matrix_result; + var matC02: subgroup_matrix_result; + var matC03: subgroup_matrix_result; + var matC10: subgroup_matrix_result; + var matC11: subgroup_matrix_result; + var matC12: subgroup_matrix_result; + var matC13: subgroup_matrix_result; + for (var kidx: u32 = 0; kidx < uniforms.K; kidx += tile_k) { + // Load Phase + loadSHMA(a_global_base+base_A, kidx, local_idx/4, local_idx%4); + loadSHMB(b_global_base+base_B, kidx, local_idx/2, local_idx%2); + workgroupBarrier(); + + for (var step: u32 = 0; step < tile_k; step+=8) + { + // Load to local memory phase + let matrix_a_offset = subtile_idy * subtile_rows * tile_k + step; + // Syntax: subgroupMatrixLoad src_ptr,src_offset,is_col_major,src_stride + var matA0: subgroup_matrix_left = subgroupMatrixLoad>(&tile_A, matrix_a_offset, false, tile_k); + var matA1: subgroup_matrix_left = subgroupMatrixLoad>(&tile_A, matrix_a_offset + 8 * tile_k, false, tile_k); + + // tile_B is stored as column major. 
+ // [col0-0:32][col1-0:32][col2-0:32]..[col63-0:32] + var matrix_b_offset = subtile_idx * subtile_cols + step; + var matB0: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset, true, tile_k); + var matB1: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 8 * tile_k, true, tile_k); + var matB2: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 16 * tile_k, true, tile_k); + var matB3: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 24 * tile_k, true, tile_k); + + // Compute Phase + // Syntax: subgroupMatrixMultiplyAccumulate left, right, accumulate -> accumulate + matC00 = subgroupMatrixMultiplyAccumulate(matA0, matB0, matC00); + matC01 = subgroupMatrixMultiplyAccumulate(matA0, matB1, matC01); + matC02 = subgroupMatrixMultiplyAccumulate(matA0, matB2, matC02); + matC03 = subgroupMatrixMultiplyAccumulate(matA0, matB3, matC03); + + matC10 = subgroupMatrixMultiplyAccumulate(matA1, matB0, matC10); + matC11 = subgroupMatrixMultiplyAccumulate(matA1, matB1, matC11); + matC12 = subgroupMatrixMultiplyAccumulate(matA1, matB2, matC12); + matC13 = subgroupMatrixMultiplyAccumulate(matA1, matB3, matC13); + } + + workgroupBarrier(); + } + + let subtile_thread_id = sg_id; + var matrix_c_offset = (a_global_base+base_A) * uniforms.N + b_global_base + base_B; + if (a_global_base+base_A+8 < uniforms.M) + { + // Syntax: subgroupMatrixStore destination, dest_offset, matrix, is_col_major, dest_stride + subgroupMatrixStore(&output, matrix_c_offset, matC00, false, uniforms.N); + subgroupMatrixStore(&output, matrix_c_offset + 8, matC01, false, uniforms.N); + subgroupMatrixStore(&output, matrix_c_offset + 16, matC02, false, uniforms.N); + subgroupMatrixStore(&output, matrix_c_offset + 24, matC03, false, uniforms.N); + } + else if (a_global_base + base_A < uniforms.M) + { + let rows = uniforms.M - (a_global_base + base_A); + safeMatrixStore(matrix_c_offset, &matC00, rows, subtile_id, subtile_thread_id); + safeMatrixStore(matrix_c_offset + 8, &matC01, rows, subtile_id, subtile_thread_id); + safeMatrixStore(matrix_c_offset + 16, &matC02, rows, subtile_id, subtile_thread_id); + safeMatrixStore(matrix_c_offset + 24, &matC03, rows, subtile_id, subtile_thread_id); + } + matrix_c_offset = matrix_c_offset + 8 * uniforms.N; + if (a_global_base+base_A+16 < uniforms.M) + { + subgroupMatrixStore(&output, matrix_c_offset, matC10, false, uniforms.N); + subgroupMatrixStore(&output, matrix_c_offset + 8, matC11, false, uniforms.N); + subgroupMatrixStore(&output, matrix_c_offset + 16, matC12, false, uniforms.N); + subgroupMatrixStore(&output, matrix_c_offset + 24, matC13, false, uniforms.N); + } + else if (a_global_base+base_A+8 < uniforms.M) + { + let rows = uniforms.M - (a_global_base + base_A + 8); + safeMatrixStore(matrix_c_offset, &matC10, rows, subtile_id, subtile_thread_id); + safeMatrixStore(matrix_c_offset + 8, &matC11, rows, subtile_id, subtile_thread_id); + safeMatrixStore(matrix_c_offset + 16, &matC12, rows, subtile_id, subtile_thread_id); + safeMatrixStore(matrix_c_offset + 24, &matC13, rows, subtile_id, subtile_thread_id); + } + )MAIN_FN"; + + return Status::OK(); +} + + +Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, + uint32_t M, + uint32_t N, + uint32_t K, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y) +{ + constexpr uint32_t kTileSizeA = 32; + constexpr uint32_t kTileSizeB = 64; + constexpr uint32_t kU32Components = 4; + TensorShape y_shape{1, M, N}; + 
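The dispatch sizing that follows rounds the output up to whole tiles: each 128-thread workgroup produces a 32-row by 64-column block of y, so the grid is a ceiling division of M and N by the tile sizes. A worked example of that arithmetic with a hypothetical shape (values here are illustrative only):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  constexpr uint32_t kTileSizeA = 32;  // output rows per workgroup
  constexpr uint32_t kTileSizeB = 64;  // output columns per workgroup
  uint32_t M = 1000, N = 4096;         // hypothetical GEMM shape
  uint32_t groups_x = (N + kTileSizeB - 1) / kTileSizeB;  // 64 column tiles
  uint32_t groups_y = (M + kTileSizeA - 1) / kTileSizeA;  // 32 row tiles; the last one is partly masked
  std::printf("dispatch: %u x %u x 1 workgroups\n", groups_x, groups_y);
}
```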
SubgroupMatrixMatMulNBitsProgram mul_program; + mul_program.SetWorkgroupSize(128); + mul_program.SetDispatchGroupSize( + (N + kTileSizeB - 1) / kTileSizeB, + (M + kTileSizeA - 1) / kTileSizeA, 1); + mul_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kU32Components)}, + {scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}}) + .AddUniformVariables({{static_cast(M)}, + {static_cast(N)}, + {static_cast(K)}}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, y_shape, gsl::narrow(1)}); + return context.RunProgram(mul_program); +} + +bool CanApplySubgroupMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context, + uint32_t block_size, + uint32_t batch_count, + uint32_t N, + uint32_t K, + bool has_zero_points) +{ + const bool has_subgroup_matrix = context.Device().HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); + return context.AdapterInfo().backendType == wgpu::BackendType::Metal && + has_subgroup_matrix && + block_size == 32 && + batch_count == 1 && + K % 32 == 0 && + N % 64 == 0 && + !has_zero_points; +} +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h new file mode 100644 index 0000000000000..65988d3970b5e --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; + +class SubgroupMatrixMatMulNBitsProgram final : public Program { + public: + SubgroupMatrixMatMulNBitsProgram() : Program{"SubgroupMatrixMatMulNBits"} {} + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"M", ProgramUniformVariableDataType::Uint32}, + {"N", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}); +}; + +Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, + uint32_t M, + uint32_t N, + uint32_t K, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y); + +bool CanApplySubgroupMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context, + uint32_t block_size, + uint32_t batch_count, + uint32_t N, + uint32_t K, + bool has_zero_points); + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index f3b025b72aa1d..163dd691b7f16 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -486,6 +486,7 @@ std::vector WebGpuContext::GetAvailableRequiredFeatures(const constexpr wgpu::FeatureName features[]{ #if !defined(__wasm__) wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses, + wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix, #endif wgpu::FeatureName::TimestampQuery, wgpu::FeatureName::ShaderF16, From 
9f838ff4c9a50912205bd6488aca8148d8c4eae4 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Sun, 16 Feb 2025 17:48:46 -0800 Subject: [PATCH 03/14] Runs but garbage outputs --- .../webgpu/quantization/subgroup_matrix_matmul_nbits.cc | 4 ++-- onnxruntime/core/providers/webgpu/shader_helper.cc | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc index d1392f91394b7..35d0c0fc54c01 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc @@ -81,8 +81,8 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader )ADDNL_FN"; shader.MainFunctionBody() << R"MAIN_FN( - let a_global_base = workgroup_idy * tile_rows; - let b_global_base = workgroup_idx * tile_cols; + let a_global_base = workgroup_id.y * tile_rows; + let b_global_base = workgroup_id.x * tile_cols; let subtile_id = u32(local_idx / sg_size); let subtile_idx = u32(subtile_id / 2); diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index 49c9a84f69551..6ff10ffb80d2c 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -352,6 +352,9 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha if (device_.HasFeature(wgpu::FeatureName::Subgroups)) { ss << "enable subgroups;\n"; } + if (device_.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { + ss << "enable chromium_experimental_subgroup_matrix;\n"; + } // // Section constants From aaeb58b5adf4a5361f591a3d8878c583fd7e99d9 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Sun, 16 Feb 2025 18:08:59 -0800 Subject: [PATCH 04/14] Add restriction to use subgroup matmul for prefill only --- onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 15215082c7480..878494d46862d 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -817,7 +817,8 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context const bool has_zero_points = zero_points != nullptr; // macOS - Experimental dawn support for subgroup matrix matmul on Metal. 
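The change below additionally requires M >= kMinMForTileOptimization, restricting the new path to prefill-shaped inputs. The intuition: each workgroup covers 32 output rows, so a decode-time matmul with M == 1 would leave almost the whole tile idle. Illustrative utilization numbers (this snippet is editorial, not part of the patch):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t ms[] = {1, 17, 256};  // decode, small prefill, large prefill
  for (uint32_t M : ms) {
    uint32_t row_tiles = (M + 31) / 32;            // 32 output rows per workgroup
    double used = 100.0 * M / (row_tiles * 32.0);  // fraction of tile rows doing useful work
    std::printf("M=%3u -> %u row tile(s), %.1f%% of tile rows used\n", M, row_tiles, used);
  }
}
```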
- if (CanApplySubgroupMatrixMatMulNBits(context, block_size, batch_count, K, N, has_zero_points)) { + if (M >= kMinMForTileOptimization && + CanApplySubgroupMatrixMatMulNBits(context, block_size, batch_count, N, K, has_zero_points)) { return ApplySubgroupMatrixMatMulNBits(a, b, scales, M, N, K, context, y); } From 2d406da6dfa5dced8da62fbd2347b89612ca982e Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Mon, 17 Feb 2025 14:04:57 -0800 Subject: [PATCH 05/14] First round bug fixes --- .../subgroup_matrix_matmul_nbits.cc | 194 +++++++++--------- 1 file changed, 98 insertions(+), 96 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc index 35d0c0fc54c01..7d44032040105 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc @@ -21,29 +21,31 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader const subtile_rows = 16; const quantization_block_size = 32; - var tile_A: array; // 32 x 32 - RxC - var tile_B: array; // 32 x 64 - RxC - var scratch: array, 4>; // 64 * 4 + var tile_A: array; // 32 x 32 - RxC + var tile_B: array; // 64 x 32 - RxC + var scratch: array, 4>; // 64 * 4 - fn loadSHMA(subtile_base: u32, k_idx: u32, row: u32, col:u32) { - let a_global = subtile_base + row; + fn loadSHMA(tile_base: u32, k_idx: u32, row: u32, c_idx:u32) { + let a_global = tile_base + row; if (a_global >= uniforms.M) { return; } // Each call loads 8 columns, starting at col. + var col = c_idx * 8; // 128 threads need to load 32 x 32. 4 threads per row or 8 col per thread. for (var col_offset:u32 = 0; col_offset < 8; col_offset++) { - tile_A[row * tile_rows + col+ col_offset] = input_a[a_global*uniforms.K + k_idx + col + col_offset]; + tile_A[row * tile_k + col + col_offset] = input_a[a_global*uniforms.K + k_idx + col + col_offset]; } } - fn loadSHMB(subtile_base: u32, k_idx: u32, row: u32, col: u32) { - let b_global = subtile_base + row; + fn loadSHMB(tile_base: u32, k_idx: u32, row: u32, c_idx: u32) { + let b_global = tile_base + row; if (b_global >= uniforms.N) { return; } // Each call loads 16 columns, starting at col. + var col = c_idx * 16; // 128 threads need to load 64 x 32. 2 threads per row or 16 col per thread. // Stored in column major fashion. 
let b_idx = u32((b_global*uniforms.K + k_idx + col)/8); @@ -53,7 +55,7 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader var b_value = input_b[b_idx+step]; var b_value_lower = (vec4(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4(8)) * scale; var b_value_upper = (vec4(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4(8)) * scale; - let tile_b_base = row * tile_rows + col + step * 8; + let tile_b_base = row * tile_k + col + step * 8; tile_B[tile_b_base] = b_value_lower[0]; tile_B[tile_b_base + 1] = b_value_upper[0]; tile_B[tile_b_base + 2] = b_value_lower[1]; @@ -81,95 +83,95 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader )ADDNL_FN"; shader.MainFunctionBody() << R"MAIN_FN( - let a_global_base = workgroup_id.y * tile_rows; - let b_global_base = workgroup_id.x * tile_cols; - - let subtile_id = u32(local_idx / sg_size); - let subtile_idx = u32(subtile_id / 2); - let subtile_idy = subtile_id % 2; - let base_A = subtile_idy * subtile_rows; - let base_B = subtile_idx * subtile_cols; - - var matC00: subgroup_matrix_result; - var matC01: subgroup_matrix_result; - var matC02: subgroup_matrix_result; - var matC03: subgroup_matrix_result; - var matC10: subgroup_matrix_result; - var matC11: subgroup_matrix_result; - var matC12: subgroup_matrix_result; - var matC13: subgroup_matrix_result; - for (var kidx: u32 = 0; kidx < uniforms.K; kidx += tile_k) { - // Load Phase - loadSHMA(a_global_base+base_A, kidx, local_idx/4, local_idx%4); - loadSHMB(b_global_base+base_B, kidx, local_idx/2, local_idx%2); - workgroupBarrier(); - - for (var step: u32 = 0; step < tile_k; step+=8) - { - // Load to local memory phase - let matrix_a_offset = subtile_idy * subtile_rows * tile_k + step; - // Syntax: subgroupMatrixLoad src_ptr,src_offset,is_col_major,src_stride - var matA0: subgroup_matrix_left = subgroupMatrixLoad>(&tile_A, matrix_a_offset, false, tile_k); - var matA1: subgroup_matrix_left = subgroupMatrixLoad>(&tile_A, matrix_a_offset + 8 * tile_k, false, tile_k); - - // tile_B is stored as column major. 
- // [col0-0:32][col1-0:32][col2-0:32]..[col63-0:32] - var matrix_b_offset = subtile_idx * subtile_cols + step; - var matB0: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset, true, tile_k); - var matB1: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 8 * tile_k, true, tile_k); - var matB2: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 16 * tile_k, true, tile_k); - var matB3: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 24 * tile_k, true, tile_k); - - // Compute Phase - // Syntax: subgroupMatrixMultiplyAccumulate left, right, accumulate -> accumulate - matC00 = subgroupMatrixMultiplyAccumulate(matA0, matB0, matC00); - matC01 = subgroupMatrixMultiplyAccumulate(matA0, matB1, matC01); - matC02 = subgroupMatrixMultiplyAccumulate(matA0, matB2, matC02); - matC03 = subgroupMatrixMultiplyAccumulate(matA0, matB3, matC03); - - matC10 = subgroupMatrixMultiplyAccumulate(matA1, matB0, matC10); - matC11 = subgroupMatrixMultiplyAccumulate(matA1, matB1, matC11); - matC12 = subgroupMatrixMultiplyAccumulate(matA1, matB2, matC12); - matC13 = subgroupMatrixMultiplyAccumulate(matA1, matB3, matC13); - } - - workgroupBarrier(); - } - - let subtile_thread_id = sg_id; - var matrix_c_offset = (a_global_base+base_A) * uniforms.N + b_global_base + base_B; - if (a_global_base+base_A+8 < uniforms.M) - { - // Syntax: subgroupMatrixStore destination, dest_offset, matrix, is_col_major, dest_stride - subgroupMatrixStore(&output, matrix_c_offset, matC00, false, uniforms.N); - subgroupMatrixStore(&output, matrix_c_offset + 8, matC01, false, uniforms.N); - subgroupMatrixStore(&output, matrix_c_offset + 16, matC02, false, uniforms.N); - subgroupMatrixStore(&output, matrix_c_offset + 24, matC03, false, uniforms.N); - } - else if (a_global_base + base_A < uniforms.M) - { - let rows = uniforms.M - (a_global_base + base_A); - safeMatrixStore(matrix_c_offset, &matC00, rows, subtile_id, subtile_thread_id); - safeMatrixStore(matrix_c_offset + 8, &matC01, rows, subtile_id, subtile_thread_id); - safeMatrixStore(matrix_c_offset + 16, &matC02, rows, subtile_id, subtile_thread_id); - safeMatrixStore(matrix_c_offset + 24, &matC03, rows, subtile_id, subtile_thread_id); - } - matrix_c_offset = matrix_c_offset + 8 * uniforms.N; - if (a_global_base+base_A+16 < uniforms.M) + let a_global_base = workgroup_id.y * tile_rows; + let b_global_base = workgroup_id.x * tile_cols; + + let subtile_id = u32(local_idx / sg_size); + let subtile_idx = u32(subtile_id / 2); + let subtile_idy = subtile_id % 2; + let base_A = subtile_idy * subtile_rows; + let base_B = subtile_idx * subtile_cols; + + var matC00: subgroup_matrix_result; + var matC01: subgroup_matrix_result; + var matC02: subgroup_matrix_result; + var matC03: subgroup_matrix_result; + var matC10: subgroup_matrix_result; + var matC11: subgroup_matrix_result; + var matC12: subgroup_matrix_result; + var matC13: subgroup_matrix_result; + for (var kidx: u32 = 0; kidx < uniforms.K; kidx += tile_k) { + // Load Phase + loadSHMA(a_global_base, kidx, local_idx/4, local_idx%4); + loadSHMB(b_global_base, kidx, local_idx/2, local_idx%2); + workgroupBarrier(); + + for (var step: u32 = 0; step < tile_k; step+=8) { - subgroupMatrixStore(&output, matrix_c_offset, matC10, false, uniforms.N); - subgroupMatrixStore(&output, matrix_c_offset + 8, matC11, false, uniforms.N); - subgroupMatrixStore(&output, matrix_c_offset + 16, matC12, false, uniforms.N); - subgroupMatrixStore(&output, matrix_c_offset + 24, matC13, false, 
uniforms.N); - } - else if (a_global_base+base_A+8 < uniforms.M) - { - let rows = uniforms.M - (a_global_base + base_A + 8); - safeMatrixStore(matrix_c_offset, &matC10, rows, subtile_id, subtile_thread_id); - safeMatrixStore(matrix_c_offset + 8, &matC11, rows, subtile_id, subtile_thread_id); - safeMatrixStore(matrix_c_offset + 16, &matC12, rows, subtile_id, subtile_thread_id); - safeMatrixStore(matrix_c_offset + 24, &matC13, rows, subtile_id, subtile_thread_id); + // Load to local memory phase + let matrix_a_offset = subtile_idy * subtile_rows * tile_k + step; + // Syntax: subgroupMatrixLoad src_ptr,src_offset,is_col_major,src_stride + var matA0: subgroup_matrix_left = subgroupMatrixLoad>(&tile_A, matrix_a_offset, false, tile_k); + var matA1: subgroup_matrix_left = subgroupMatrixLoad>(&tile_A, matrix_a_offset + 8 * tile_k, false, tile_k); + + // tile_B is stored as column major. + // [col0-0:32][col1-0:32][col2-0:32]..[col63-0:32] + var matrix_b_offset = subtile_idx * subtile_cols * tile_k + step; + var matB0: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset, true, tile_k); + var matB1: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 8 * tile_k, true, tile_k); + var matB2: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 16 * tile_k, true, tile_k); + var matB3: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 24 * tile_k, true, tile_k); + + // Compute Phase + // Syntax: subgroupMatrixMultiplyAccumulate left, right, accumulate -> accumulate + matC00 = subgroupMatrixMultiplyAccumulate(matA0, matB0, matC00); + matC01 = subgroupMatrixMultiplyAccumulate(matA0, matB1, matC01); + matC02 = subgroupMatrixMultiplyAccumulate(matA0, matB2, matC02); + matC03 = subgroupMatrixMultiplyAccumulate(matA0, matB3, matC03); + + matC10 = subgroupMatrixMultiplyAccumulate(matA1, matB0, matC10); + matC11 = subgroupMatrixMultiplyAccumulate(matA1, matB1, matC11); + matC12 = subgroupMatrixMultiplyAccumulate(matA1, matB2, matC12); + matC13 = subgroupMatrixMultiplyAccumulate(matA1, matB3, matC13); } + + workgroupBarrier(); + } + + let subtile_thread_id = sg_id; + var matrix_c_offset = (a_global_base+base_A) * uniforms.N + b_global_base + base_B; + if (a_global_base + base_A + 8 < uniforms.M) + { + // Syntax: subgroupMatrixStore destination, dest_offset, matrix, is_col_major, dest_stride + subgroupMatrixStore(&output, matrix_c_offset, matC00, false, uniforms.N); + subgroupMatrixStore(&output, matrix_c_offset + 8, matC01, false, uniforms.N); + subgroupMatrixStore(&output, matrix_c_offset + 16, matC02, false, uniforms.N); + subgroupMatrixStore(&output, matrix_c_offset + 24, matC03, false, uniforms.N); + } + else if (a_global_base + base_A < uniforms.M) + { + let rows = uniforms.M - (a_global_base + base_A); + safeMatrixStore(matrix_c_offset, &matC00, rows, subtile_id, subtile_thread_id); + safeMatrixStore(matrix_c_offset + 8, &matC01, rows, subtile_id, subtile_thread_id); + safeMatrixStore(matrix_c_offset + 16, &matC02, rows, subtile_id, subtile_thread_id); + safeMatrixStore(matrix_c_offset + 24, &matC03, rows, subtile_id, subtile_thread_id); + } + matrix_c_offset = matrix_c_offset + 8 * uniforms.N; + if (a_global_base + base_A + 16 < uniforms.M) + { + subgroupMatrixStore(&output, matrix_c_offset, matC10, false, uniforms.N); + subgroupMatrixStore(&output, matrix_c_offset + 8, matC11, false, uniforms.N); + subgroupMatrixStore(&output, matrix_c_offset + 16, matC12, false, uniforms.N); + subgroupMatrixStore(&output, matrix_c_offset + 24, 
matC13, false, uniforms.N); + } + else if (a_global_base + base_A + 8 < uniforms.M) + { + let rows = uniforms.M - (a_global_base + base_A + 8); + safeMatrixStore(matrix_c_offset, &matC10, rows, subtile_id, subtile_thread_id); + safeMatrixStore(matrix_c_offset + 8, &matC11, rows, subtile_id, subtile_thread_id); + safeMatrixStore(matrix_c_offset + 16, &matC12, rows, subtile_id, subtile_thread_id); + safeMatrixStore(matrix_c_offset + 24, &matC13, rows, subtile_id, subtile_thread_id); + } )MAIN_FN"; return Status::OK(); From a3102cd11e14e83bf9d0dbc207b756e54d17974e Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Mon, 17 Feb 2025 14:21:47 -0800 Subject: [PATCH 06/14] Remove safeMatrixStore --- .../subgroup_matrix_matmul_nbits.cc | 168 +++++++----------- 1 file changed, 69 insertions(+), 99 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc index 7d44032040105..ed787984e590a 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc @@ -23,7 +23,6 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader var tile_A: array; // 32 x 32 - RxC var tile_B: array; // 64 x 32 - RxC - var scratch: array, 4>; // 64 * 4 fn loadSHMA(tile_base: u32, k_idx: u32, row: u32, c_idx:u32) { let a_global = tile_base + row; @@ -66,112 +65,83 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader tile_B[tile_b_base + 7] = b_value_upper[3]; } } + )ADDNL_FN"; - fn safeMatrixStore(offset: u32, mat: ptr>, rows:u32, subtile_id:u32, subtile_thread_id:u32) - { - subgroupMatrixStore(&scratch[subtile_id], 0, *mat, false, 8); - // There are 32 subtile_thread_id and we have 64 values. - let row = u32(subtile_thread_id / 4); - var col = u32(subtile_thread_id % 4) * 2; - if (row < rows) + shader.MainFunctionBody() << R"MAIN_FN( + let a_global_base = workgroup_id.y * tile_rows; + let b_global_base = workgroup_id.x * tile_cols; + + let subtile_id = u32(local_idx / sg_size); + let subtile_idx = u32(subtile_id / 2); + let subtile_idy = subtile_id % 2; + let base_A = subtile_idy * subtile_rows; + let base_B = subtile_idx * subtile_cols; + + var matC00: subgroup_matrix_result; + var matC01: subgroup_matrix_result; + var matC02: subgroup_matrix_result; + var matC03: subgroup_matrix_result; + var matC10: subgroup_matrix_result; + var matC11: subgroup_matrix_result; + var matC12: subgroup_matrix_result; + var matC13: subgroup_matrix_result; + for (var kidx: u32 = 0; kidx < uniforms.K; kidx += tile_k) { + // Load Phase + loadSHMA(a_global_base, kidx, local_idx/4, local_idx%4); + loadSHMB(b_global_base, kidx, local_idx/2, local_idx%2); + workgroupBarrier(); + + for (var step: u32 = 0; step < tile_k; step+=8) { - output[offset + row * uniforms.N + col] = scratch[subtile_id][row * 8 + col]; - col++; - output[offset + row * uniforms.N + col] = scratch[subtile_id][row * 8 + col]; + // Load to local memory phase + let matrix_a_offset = subtile_idy * subtile_rows * tile_k + step; + // Syntax: subgroupMatrixLoad src_ptr,src_offset,is_col_major,src_stride + var matA0: subgroup_matrix_left = subgroupMatrixLoad>(&tile_A, matrix_a_offset, false, tile_k); + var matA1: subgroup_matrix_left = subgroupMatrixLoad>(&tile_A, matrix_a_offset + 8 * tile_k, false, tile_k); + + // tile_B is stored as column major. 
+ // [col0-0:32][col1-0:32][col2-0:32]..[col63-0:32] + var matrix_b_offset = subtile_idx * subtile_cols * tile_k + step; + var matB0: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset, true, tile_k); + var matB1: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 8 * tile_k, true, tile_k); + var matB2: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 16 * tile_k, true, tile_k); + var matB3: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 24 * tile_k, true, tile_k); + + // Compute Phase + // Syntax: subgroupMatrixMultiplyAccumulate left, right, accumulate -> accumulate + matC00 = subgroupMatrixMultiplyAccumulate(matA0, matB0, matC00); + matC01 = subgroupMatrixMultiplyAccumulate(matA0, matB1, matC01); + matC02 = subgroupMatrixMultiplyAccumulate(matA0, matB2, matC02); + matC03 = subgroupMatrixMultiplyAccumulate(matA0, matB3, matC03); + + matC10 = subgroupMatrixMultiplyAccumulate(matA1, matB0, matC10); + matC11 = subgroupMatrixMultiplyAccumulate(matA1, matB1, matC11); + matC12 = subgroupMatrixMultiplyAccumulate(matA1, matB2, matC12); + matC13 = subgroupMatrixMultiplyAccumulate(matA1, matB3, matC13); } + + workgroupBarrier(); } - )ADDNL_FN"; - shader.MainFunctionBody() << R"MAIN_FN( - let a_global_base = workgroup_id.y * tile_rows; - let b_global_base = workgroup_id.x * tile_cols; - - let subtile_id = u32(local_idx / sg_size); - let subtile_idx = u32(subtile_id / 2); - let subtile_idy = subtile_id % 2; - let base_A = subtile_idy * subtile_rows; - let base_B = subtile_idx * subtile_cols; - - var matC00: subgroup_matrix_result; - var matC01: subgroup_matrix_result; - var matC02: subgroup_matrix_result; - var matC03: subgroup_matrix_result; - var matC10: subgroup_matrix_result; - var matC11: subgroup_matrix_result; - var matC12: subgroup_matrix_result; - var matC13: subgroup_matrix_result; - for (var kidx: u32 = 0; kidx < uniforms.K; kidx += tile_k) { - // Load Phase - loadSHMA(a_global_base, kidx, local_idx/4, local_idx%4); - loadSHMB(b_global_base, kidx, local_idx/2, local_idx%2); - workgroupBarrier(); - - for (var step: u32 = 0; step < tile_k; step+=8) + let subtile_thread_id = sg_id; + var matrix_c_offset = (a_global_base+base_A) * uniforms.N + b_global_base + base_B; + if (a_global_base + base_A < uniforms.M) { - // Load to local memory phase - let matrix_a_offset = subtile_idy * subtile_rows * tile_k + step; - // Syntax: subgroupMatrixLoad src_ptr,src_offset,is_col_major,src_stride - var matA0: subgroup_matrix_left = subgroupMatrixLoad>(&tile_A, matrix_a_offset, false, tile_k); - var matA1: subgroup_matrix_left = subgroupMatrixLoad>(&tile_A, matrix_a_offset + 8 * tile_k, false, tile_k); - - // tile_B is stored as column major. 
- // [col0-0:32][col1-0:32][col2-0:32]..[col63-0:32] - var matrix_b_offset = subtile_idx * subtile_cols * tile_k + step; - var matB0: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset, true, tile_k); - var matB1: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 8 * tile_k, true, tile_k); - var matB2: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 16 * tile_k, true, tile_k); - var matB3: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 24 * tile_k, true, tile_k); - - // Compute Phase - // Syntax: subgroupMatrixMultiplyAccumulate left, right, accumulate -> accumulate - matC00 = subgroupMatrixMultiplyAccumulate(matA0, matB0, matC00); - matC01 = subgroupMatrixMultiplyAccumulate(matA0, matB1, matC01); - matC02 = subgroupMatrixMultiplyAccumulate(matA0, matB2, matC02); - matC03 = subgroupMatrixMultiplyAccumulate(matA0, matB3, matC03); - - matC10 = subgroupMatrixMultiplyAccumulate(matA1, matB0, matC10); - matC11 = subgroupMatrixMultiplyAccumulate(matA1, matB1, matC11); - matC12 = subgroupMatrixMultiplyAccumulate(matA1, matB2, matC12); - matC13 = subgroupMatrixMultiplyAccumulate(matA1, matB3, matC13); + // Syntax: subgroupMatrixStore destination, dest_offset, matrix, is_col_major, dest_stride + subgroupMatrixStore(&output, matrix_c_offset, matC00, false, uniforms.N); + subgroupMatrixStore(&output, matrix_c_offset + 8, matC01, false, uniforms.N); + subgroupMatrixStore(&output, matrix_c_offset + 16, matC02, false, uniforms.N); + subgroupMatrixStore(&output, matrix_c_offset + 24, matC03, false, uniforms.N); } - workgroupBarrier(); - } - - let subtile_thread_id = sg_id; - var matrix_c_offset = (a_global_base+base_A) * uniforms.N + b_global_base + base_B; - if (a_global_base + base_A + 8 < uniforms.M) - { - // Syntax: subgroupMatrixStore destination, dest_offset, matrix, is_col_major, dest_stride - subgroupMatrixStore(&output, matrix_c_offset, matC00, false, uniforms.N); - subgroupMatrixStore(&output, matrix_c_offset + 8, matC01, false, uniforms.N); - subgroupMatrixStore(&output, matrix_c_offset + 16, matC02, false, uniforms.N); - subgroupMatrixStore(&output, matrix_c_offset + 24, matC03, false, uniforms.N); - } - else if (a_global_base + base_A < uniforms.M) - { - let rows = uniforms.M - (a_global_base + base_A); - safeMatrixStore(matrix_c_offset, &matC00, rows, subtile_id, subtile_thread_id); - safeMatrixStore(matrix_c_offset + 8, &matC01, rows, subtile_id, subtile_thread_id); - safeMatrixStore(matrix_c_offset + 16, &matC02, rows, subtile_id, subtile_thread_id); - safeMatrixStore(matrix_c_offset + 24, &matC03, rows, subtile_id, subtile_thread_id); - } - matrix_c_offset = matrix_c_offset + 8 * uniforms.N; - if (a_global_base + base_A + 16 < uniforms.M) - { - subgroupMatrixStore(&output, matrix_c_offset, matC10, false, uniforms.N); - subgroupMatrixStore(&output, matrix_c_offset + 8, matC11, false, uniforms.N); - subgroupMatrixStore(&output, matrix_c_offset + 16, matC12, false, uniforms.N); - subgroupMatrixStore(&output, matrix_c_offset + 24, matC13, false, uniforms.N); - } - else if (a_global_base + base_A + 8 < uniforms.M) - { - let rows = uniforms.M - (a_global_base + base_A + 8); - safeMatrixStore(matrix_c_offset, &matC10, rows, subtile_id, subtile_thread_id); - safeMatrixStore(matrix_c_offset + 8, &matC11, rows, subtile_id, subtile_thread_id); - safeMatrixStore(matrix_c_offset + 16, &matC12, rows, subtile_id, subtile_thread_id); - safeMatrixStore(matrix_c_offset + 24, &matC13, rows, subtile_id, subtile_thread_id); 
- } + matrix_c_offset = matrix_c_offset + 8 * uniforms.N; + if (a_global_base + base_A + 8 < uniforms.M) + { + subgroupMatrixStore(&output, matrix_c_offset, matC10, false, uniforms.N); + subgroupMatrixStore(&output, matrix_c_offset + 8, matC11, false, uniforms.N); + subgroupMatrixStore(&output, matrix_c_offset + 16, matC12, false, uniforms.N); + subgroupMatrixStore(&output, matrix_c_offset + 24, matC13, false, uniforms.N); + } )MAIN_FN"; return Status::OK(); From 2866c269c43f7885e77ad50ed915603ec6713059 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Mon, 17 Feb 2025 21:13:58 -0800 Subject: [PATCH 07/14] Add FP32 support --- .../subgroup_matrix_matmul_nbits.cc | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc index ed787984e590a..38b42dee6e17f 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc @@ -21,8 +21,8 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader const subtile_rows = 16; const quantization_block_size = 32; - var tile_A: array; // 32 x 32 - RxC - var tile_B: array; // 64 x 32 - RxC + var tile_A: array; // 32 x 32 - RxC + var tile_B: array; // 64 x 32 - RxC fn loadSHMA(tile_base: u32, k_idx: u32, row: u32, c_idx:u32) { let a_global = tile_base + row; @@ -52,8 +52,8 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader for (var step:u32 = 0; step < 2; step++) { var b_value = input_b[b_idx+step]; - var b_value_lower = (vec4(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4(8)) * scale; - var b_value_upper = (vec4(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4(8)) * scale; + var b_value_lower = (vec4(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4(8)) * scale; + var b_value_upper = (vec4(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4(8)) * scale; let tile_b_base = row * tile_k + col + step * 8; tile_B[tile_b_base] = b_value_lower[0]; tile_B[tile_b_base + 1] = b_value_upper[0]; @@ -77,14 +77,14 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader let base_A = subtile_idy * subtile_rows; let base_B = subtile_idx * subtile_cols; - var matC00: subgroup_matrix_result; - var matC01: subgroup_matrix_result; - var matC02: subgroup_matrix_result; - var matC03: subgroup_matrix_result; - var matC10: subgroup_matrix_result; - var matC11: subgroup_matrix_result; - var matC12: subgroup_matrix_result; - var matC13: subgroup_matrix_result; + var matC00: subgroup_matrix_result; + var matC01: subgroup_matrix_result; + var matC02: subgroup_matrix_result; + var matC03: subgroup_matrix_result; + var matC10: subgroup_matrix_result; + var matC11: subgroup_matrix_result; + var matC12: subgroup_matrix_result; + var matC13: subgroup_matrix_result; for (var kidx: u32 = 0; kidx < uniforms.K; kidx += tile_k) { // Load Phase loadSHMA(a_global_base, kidx, local_idx/4, local_idx%4); @@ -96,16 +96,16 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader // Load to local memory phase let matrix_a_offset = subtile_idy * subtile_rows * tile_k + step; // Syntax: subgroupMatrixLoad src_ptr,src_offset,is_col_major,src_stride - var matA0: subgroup_matrix_left = subgroupMatrixLoad>(&tile_A, matrix_a_offset, false, tile_k); - var matA1: subgroup_matrix_left = subgroupMatrixLoad>(&tile_A, 
matrix_a_offset + 8 * tile_k, false, tile_k); + var matA0: subgroup_matrix_left = subgroupMatrixLoad>(&tile_A, matrix_a_offset, false, tile_k); + var matA1: subgroup_matrix_left = subgroupMatrixLoad>(&tile_A, matrix_a_offset + 8 * tile_k, false, tile_k); // tile_B is stored as column major. // [col0-0:32][col1-0:32][col2-0:32]..[col63-0:32] var matrix_b_offset = subtile_idx * subtile_cols * tile_k + step; - var matB0: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset, true, tile_k); - var matB1: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 8 * tile_k, true, tile_k); - var matB2: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 16 * tile_k, true, tile_k); - var matB3: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 24 * tile_k, true, tile_k); + var matB0: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset, true, tile_k); + var matB1: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 8 * tile_k, true, tile_k); + var matB2: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 16 * tile_k, true, tile_k); + var matB3: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 24 * tile_k, true, tile_k); // Compute Phase // Syntax: subgroupMatrixMultiplyAccumulate left, right, accumulate -> accumulate From acc3ab656d62895e5e40c267f9cf7cd6a9b22212 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Tue, 18 Feb 2025 14:59:29 -0800 Subject: [PATCH 08/14] Add support for compute precision --- .../subgroup_matrix_matmul_nbits.cc | 95 +++++++++++-------- 1 file changed, 58 insertions(+), 37 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc index 38b42dee6e17f..3a1abaac4902b 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc @@ -20,9 +20,11 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader const subtile_cols = 32; const subtile_rows = 16; const quantization_block_size = 32; + alias compute_precision = output_element_t; - var tile_A: array; // 32 x 32 - RxC - var tile_B: array; // 64 x 32 - RxC + var tile_A: array; // 32 x 32 - RxC + var tile_B: array; // 64 x 32 - RxC + var scratch: array, 4>, 4>; // 64 * 4 * 4 fn loadSHMA(tile_base: u32, k_idx: u32, row: u32, c_idx:u32) { let a_global = tile_base + row; @@ -34,7 +36,7 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader // 128 threads need to load 32 x 32. 4 threads per row or 8 col per thread. for (var col_offset:u32 = 0; col_offset < 8; col_offset++) { - tile_A[row * tile_k + col + col_offset] = input_a[a_global*uniforms.K + k_idx + col + col_offset]; + tile_A[row * tile_k + col + col_offset] = compute_precision(input_a[a_global*uniforms.K + k_idx + col + col_offset]); } } @@ -48,12 +50,12 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader // 128 threads need to load 64 x 32. 2 threads per row or 16 col per thread. // Stored in column major fashion. 
let b_idx = u32((b_global*uniforms.K + k_idx + col)/8); - let scale = scales_b[(b_global*uniforms.K + k_idx + col)/quantization_block_size]; + let scale = compute_precision(scales_b[(b_global*uniforms.K + k_idx + col)/quantization_block_size]); for (var step:u32 = 0; step < 2; step++) { var b_value = input_b[b_idx+step]; - var b_value_lower = (vec4(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4(8)) * scale; - var b_value_upper = (vec4(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4(8)) * scale; + var b_value_lower = (vec4(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4(8)) * scale; + var b_value_upper = (vec4(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4(8)) * scale; let tile_b_base = row * tile_k + col + step * 8; tile_B[tile_b_base] = b_value_lower[0]; tile_B[tile_b_base + 1] = b_value_upper[0]; @@ -65,6 +67,21 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader tile_B[tile_b_base + 7] = b_value_upper[3]; } } + + fn storeOutput(offset:u32, row: u32, col:u32, src_slot:u32, row_limit:i32) { + if (row_limit > 0 && row < u32(row_limit)) + { + output[offset + row * uniforms.N + col] = output_element_t(scratch[src_slot][0][row * 8 + col]); + output[offset + row * uniforms.N + col + 8] = output_element_t(scratch[src_slot][1][row * 8 + col]); + output[offset + row * uniforms.N + col + 16] = output_element_t(scratch[src_slot][2][row * 8 + col]); + output[offset + row * uniforms.N + col + 24] = output_element_t(scratch[src_slot][3][row * 8 + col]); + let col2 = col + 1; + output[offset + row * uniforms.N + col2] = output_element_t(scratch[src_slot][0][row * 8 + col2]); + output[offset + row * uniforms.N + col2 + 8] = output_element_t(scratch[src_slot][1][row * 8 + col2]); + output[offset + row * uniforms.N + col2 + 16] = output_element_t(scratch[src_slot][2][row * 8 + col2]); + output[offset + row * uniforms.N + col2 + 24] = output_element_t(scratch[src_slot][3][row * 8 + col2]); + } + } )ADDNL_FN"; shader.MainFunctionBody() << R"MAIN_FN( @@ -77,14 +94,14 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader let base_A = subtile_idy * subtile_rows; let base_B = subtile_idx * subtile_cols; - var matC00: subgroup_matrix_result; - var matC01: subgroup_matrix_result; - var matC02: subgroup_matrix_result; - var matC03: subgroup_matrix_result; - var matC10: subgroup_matrix_result; - var matC11: subgroup_matrix_result; - var matC12: subgroup_matrix_result; - var matC13: subgroup_matrix_result; + var matC00: subgroup_matrix_result; + var matC01: subgroup_matrix_result; + var matC02: subgroup_matrix_result; + var matC03: subgroup_matrix_result; + var matC10: subgroup_matrix_result; + var matC11: subgroup_matrix_result; + var matC12: subgroup_matrix_result; + var matC13: subgroup_matrix_result; for (var kidx: u32 = 0; kidx < uniforms.K; kidx += tile_k) { // Load Phase loadSHMA(a_global_base, kidx, local_idx/4, local_idx%4); @@ -96,16 +113,16 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader // Load to local memory phase let matrix_a_offset = subtile_idy * subtile_rows * tile_k + step; // Syntax: subgroupMatrixLoad src_ptr,src_offset,is_col_major,src_stride - var matA0: subgroup_matrix_left = subgroupMatrixLoad>(&tile_A, matrix_a_offset, false, tile_k); - var matA1: subgroup_matrix_left = subgroupMatrixLoad>(&tile_A, matrix_a_offset + 8 * tile_k, false, tile_k); + var matA0: subgroup_matrix_left = subgroupMatrixLoad>(&tile_A, matrix_a_offset, false, tile_k); + var matA1: subgroup_matrix_left = 
subgroupMatrixLoad>(&tile_A, matrix_a_offset + 8 * tile_k, false, tile_k); // tile_B is stored as column major. // [col0-0:32][col1-0:32][col2-0:32]..[col63-0:32] var matrix_b_offset = subtile_idx * subtile_cols * tile_k + step; - var matB0: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset, true, tile_k); - var matB1: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 8 * tile_k, true, tile_k); - var matB2: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 16 * tile_k, true, tile_k); - var matB3: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 24 * tile_k, true, tile_k); + var matB0: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset, true, tile_k); + var matB1: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 8 * tile_k, true, tile_k); + var matB2: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 16 * tile_k, true, tile_k); + var matB3: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 24 * tile_k, true, tile_k); // Compute Phase // Syntax: subgroupMatrixMultiplyAccumulate left, right, accumulate -> accumulate @@ -123,25 +140,29 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader workgroupBarrier(); } - let subtile_thread_id = sg_id; + // Write out + // Write out top block + subgroupMatrixStore(&scratch[subtile_id][0], 0, matC00, false, 8); + subgroupMatrixStore(&scratch[subtile_id][1], 0, matC01, false, 8); + subgroupMatrixStore(&scratch[subtile_id][2], 0, matC02, false, 8); + subgroupMatrixStore(&scratch[subtile_id][3], 0, matC03, false, 8); + workgroupBarrier(); + let row = u32(sg_id / 4); + var col = u32(sg_id % 4) * 2; var matrix_c_offset = (a_global_base+base_A) * uniforms.N + b_global_base + base_B; - if (a_global_base + base_A < uniforms.M) - { - // Syntax: subgroupMatrixStore destination, dest_offset, matrix, is_col_major, dest_stride - subgroupMatrixStore(&output, matrix_c_offset, matC00, false, uniforms.N); - subgroupMatrixStore(&output, matrix_c_offset + 8, matC01, false, uniforms.N); - subgroupMatrixStore(&output, matrix_c_offset + 16, matC02, false, uniforms.N); - subgroupMatrixStore(&output, matrix_c_offset + 24, matC03, false, uniforms.N); - } - + var row_limit:i32 = i32(uniforms.M) - i32(a_global_base + base_A); + storeOutput(matrix_c_offset, row, col, subtile_id, row_limit); + workgroupBarrier(); + + // Write out bottom block + subgroupMatrixStore(&scratch[subtile_id][0], 0, matC10, false, 8); + subgroupMatrixStore(&scratch[subtile_id][1], 0, matC11, false, 8); + subgroupMatrixStore(&scratch[subtile_id][2], 0, matC12, false, 8); + subgroupMatrixStore(&scratch[subtile_id][3], 0, matC13, false, 8); + workgroupBarrier(); matrix_c_offset = matrix_c_offset + 8 * uniforms.N; - if (a_global_base + base_A + 8 < uniforms.M) - { - subgroupMatrixStore(&output, matrix_c_offset, matC10, false, uniforms.N); - subgroupMatrixStore(&output, matrix_c_offset + 8, matC11, false, uniforms.N); - subgroupMatrixStore(&output, matrix_c_offset + 16, matC12, false, uniforms.N); - subgroupMatrixStore(&output, matrix_c_offset + 24, matC13, false, uniforms.N); - } + row_limit = i32(uniforms.M) - i32(a_global_base + base_A + 8); + storeOutput(matrix_c_offset, row, col, subtile_id, row_limit); )MAIN_FN"; return Status::OK(); From 2ad480cc244860bf7358281865af2260de1f827a Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Tue, 18 Feb 2025 15:18:29 -0800 Subject: [PATCH 09/14] Restrict to accuracy level 
4
---
 onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc | 2 +-
 .../webgpu/quantization/subgroup_matrix_matmul_nbits.cc     | 6 ++++++
 .../webgpu/quantization/subgroup_matrix_matmul_nbits.h      | 1 +
 3 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
index 878494d46862d..8f335d78ac6df 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
@@ -818,7 +818,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
   const bool has_zero_points = zero_points != nullptr;
   // macOS - Experimental dawn support for subgroup matrix matmul on Metal.
   if (M >= kMinMForTileOptimization &&
-      CanApplySubgroupMatrixMatMulNBits(context, block_size, batch_count, N, K, has_zero_points)) {
+      CanApplySubgroupMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, has_zero_points)) {
     return ApplySubgroupMatrixMatMulNBits(a, b, scales, M, N, K, context, y);
   }

diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc
index 3a1abaac4902b..8f7a55e5f7fdd 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc
@@ -196,6 +196,7 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te
 }

 bool CanApplySubgroupMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context,
+                                       uint64_t accuracy_level,
                                        uint32_t block_size,
                                        uint32_t batch_count,
                                        uint32_t N,
@@ -203,8 +204,13 @@ bool CanApplySubgroupMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& cont
                                        bool has_zero_points)
{
  const bool has_subgroup_matrix = context.Device().HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
+  // For now SubgroupMatrixMatMulNBits is only supported for accuracy level 4, because with FP16 there are
+  // some precision issues with subgroupMatrixMultiplyAccumulate. It is possible to support higher accuracy
+  // by setting compute_precision to FP32, but that will be slower: for a 1K-token prefill on Phi-3.5, FP16
+  // takes around 5s and FP32 around 7s.
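Combined with the prefill check from patch 04, the eligibility test that the return statement below implements can be read as a single predicate. A consolidated sketch for reference (UseSubgroupMatrixPath and the min_m parameter are hypothetical names; the real code splits these checks between MatMulNBits::ComputeInternal and CanApplySubgroupMatrixMatMulNBits):

```cpp
#include <cstdint>
#include <cstdio>

// Sketch only: accuracy_level 4 means the model accepts int8-level
// accumulation per the MatMulNBits contract, so fp16 math is tolerable.
bool UseSubgroupMatrixPath(bool metal, bool has_subgroup_matrix,
                           uint64_t accuracy_level, uint32_t block_size,
                           uint32_t batch_count, uint32_t M, uint32_t N,
                           uint32_t K, bool has_zero_points, uint32_t min_m) {
  return metal && has_subgroup_matrix &&
         accuracy_level == 4 &&
         block_size == 32 &&  // one scale per 32 weights, matching quantization_block_size
         batch_count == 1 &&
         K % 32 == 0 &&       // whole 32-wide K tiles
         N % 64 == 0 &&       // whole 64-wide N tiles
         M >= min_m &&        // prefill-shaped inputs only
         !has_zero_points;    // shader assumes the fixed zero point of 8
}

int main() {
  // min_m value here is illustrative, not the tree's kMinMForTileOptimization.
  std::printf("%d\n", UseSubgroupMatrixPath(true, true, 4, 32, 1, 1024, 4096, 3072, false, 16));  // 1
}
```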
   return context.AdapterInfo().backendType == wgpu::BackendType::Metal &&
          has_subgroup_matrix &&
+         accuracy_level == 4 &&
          block_size == 32 &&
          batch_count == 1 &&
          K % 32 == 0 &&

diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h
index 65988d3970b5e..c64d343b6d576 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h
+++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h
@@ -33,6 +33,7 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te
                                       Tensor* y);
 
 bool CanApplySubgroupMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context,
+                                       uint64_t accuracy_level,
                                        uint32_t block_size,
                                        uint32_t batch_count,
                                        uint32_t N,

From 8d7ae124b61a08b760dc158e6f18036da7259fc9 Mon Sep 17 00:00:00 2001
From: Sushanth Rajasankar
Date: Tue, 18 Feb 2025 15:21:48 -0800
Subject: [PATCH 10/14] Lintrunner

---
 .gitignore                                       | 200 +-----------------
 .../webgpu/quantization/matmul_nbits.cc          |   2 +-
 .../subgroup_matrix_matmul_nbits.cc              | 101 +++++----
 .../subgroup_matrix_matmul_nbits.h               |  24 +--
 .../core/providers/webgpu/webgpu_context.h       |   6 +-
 5 files changed, 67 insertions(+), 266 deletions(-)

diff --git a/.gitignore b/.gitignore
index 4d0a1205b7c19..f514b74c5f2fb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,198 +1,2 @@
-# build, distribute, and bins (+ python proto bindings)
-build
-build_*/
-.build_debug/*
-.build_release/*
-distribute/*
-*.testbin
-*.bin
-cmake_build
-.cmake_build
-cmake-build-debug
-gen
-*~
-.vs
-.DS_Store
-*.DS_Store
-TestResults/
-.idea/
-onnxruntime.egg-info
-nuget_root/
-.packages/
-.vscode
-*.code-workspace
-__pycache__
-onnxruntime_profile*.json
-/docs/python/inference/*.md
-/docs/python/inference/auto_examples/*
-/docs/python/inference/media/*
-/docs/python/inference/examples/*.onnx
-/docs/python/inference/examples/graph.*
-/docs/python/*_LICENSE
-/LICENSE.txt
-/csharp/**/obj/
-/csharp/**/bin/
-/csharp/Directory.Build.props
-docs/python/inference/*.onnx
-*.onnx
-onnxprofile_profile_test_*.json
-/csharp/packages
-/csharp/src/Microsoft.ML.OnnxRuntime/targets/**/*.targets
-/csharp/src/Microsoft.ML.OnnxRuntime/targets/**/*.props
-/csharp/**/*.vcxproj.user
-cmake/external/FeaturizersLibrary/
-# Java specific ignores
-java/.gradle
-java/hs_*.log
-/java/bin
-onnxruntime/python/version_info.py
-/orttraining/orttraining/eager/ort_aten.g.cpp
-/orttraining/orttraining/eager/ort_customops.g.cpp
-/csharp/**/packages
-# direnv, posh-direnv
-.envrc
-.psenvrc
-*.csproj.user
-# clangd
-.cache/
-compile_commands.json
-# Rust specific
-rust/**/target
-rust/**/Cargo.lock
-rust/onnxruntime/synset.txt
-
-# Python
-
-# Byte-compiled / optimized / DLL files
-__pycache__/
-*.py[cod]
-*$py.class
-
-# C extensions
-*.so
-
-# Distribution / packaging
-.Python
-build/
-develop-eggs/
-dist/
-downloads/
-eggs/
-.eggs/
-lib64/
-parts/
-sdist/
-var/
-wheels/
-share/python-wheels/
-*.egg-info/
-.installed.cfg
-*.egg
-MANIFEST
-
-# PyInstaller
-# Usually these files are written by a python script from a template
-# before PyInstaller builds the exe, so as to inject date/other infos into it.
-*.manifest
-*.spec
-
-# Installer logs
-pip-log.txt
-pip-delete-this-directory.txt
-
-# Unit test / coverage reports
-htmlcov/
-.tox/
-.nox/
-.coverage
-.coverage.*
-.cache
-nosetests.xml
-coverage.xml
-*.cover
-*.py,cover
-.hypothesis/
-.pytest_cache/
-cover/
-
-# Translations
-*.mo
-*.pot
-
-# Django stuff:
-*.log
-local_settings.py
-db.sqlite3
-db.sqlite3-journal
-
-# Flask stuff:
-instance/
-.webassets-cache
-
-# Scrapy stuff:
-.scrapy
-
-# Sphinx documentation
-docs/_build/
-
-# PyBuilder
-.pybuilder/
-target/
-
-# Jupyter Notebook
-.ipynb_checkpoints
-
-# IPython
-profile_default/
-ipython_config.py
-
-# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
-__pypackages__/
-
-# Celery stuff
-celerybeat-schedule
-celerybeat.pid
-
-# SageMath parsed files
-*.sage.py
-
-# Environments
-.env
-.venv
-env/
-venv/
-ENV/
-env.bak/
-venv.bak/
-
-# Spyder project settings
-.spyderproject
-.spyproject
-
-# Rope project settings
-.ropeproject
-
-# mkdocs documentation
-/site
-
-# mypy
-.mypy_cache/
-.dmypy.json
-dmypy.json
-
-# Pyre type checker
-.pyre/
-
-# pytype static type analyzer
-.pytype/
-
-# Cython debug symbols
-cython_debug/
-
-# Swift Package Manager
-Packages/
-Package.pins
-Package.resolved
-.build/
-.swiftpm/
-repros/
+# Created by venv; see https://docs.python.org/3/library/venv.html
+*

diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
index 8f335d78ac6df..28d622b2c9c33 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
@@ -818,7 +818,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
   const bool has_zero_points = zero_points != nullptr;
   // macOS - Experimental dawn support for subgroup matrix matmul on Metal.
   if (M >= kMinMForTileOptimization &&
-    CanApplySubgroupMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, has_zero_points)) {
+      CanApplySubgroupMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, has_zero_points)) {
     return ApplySubgroupMatrixMatMulNBits(a, b, scales, M, N, K, context, y);
   }

diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc
index 8f7a55e5f7fdd..2d4bdfd59b1ae 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc
@@ -8,12 +8,12 @@ namespace contrib {
 namespace webgpu {
 
 Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
-    shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
-    shader.AddInput("input_b", ShaderUsage::UseUniform);
-    shader.AddInput("scales_b", ShaderUsage::UseUniform);
-    shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);
+  shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
+  shader.AddInput("input_b", ShaderUsage::UseUniform);
+  shader.AddInput("scales_b", ShaderUsage::UseUniform);
+  shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);
 
-    shader.AdditionalImplementation() << R"ADDNL_FN(
+  shader.AdditionalImplementation() << R"ADDNL_FN(
     const tile_cols = 64;
     const tile_rows = 32;
     const tile_k = 32;
@@ -84,7 +84,7 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader
     }
 )ADDNL_FN";
 
-    shader.MainFunctionBody() << R"MAIN_FN(
+  shader.MainFunctionBody() << R"MAIN_FN(
     let a_global_base = workgroup_id.y * tile_rows;
     let b_global_base = workgroup_id.x * tile_cols;
@@ -165,58 +165,55 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader
     storeOutput(matrix_c_offset, row, col, subtile_id, row_limit);
 )MAIN_FN";
 
-    return Status::OK();
+  return Status::OK();
 }
-
 Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
-    uint32_t M,
-    uint32_t N,
-    uint32_t K,
-    onnxruntime::webgpu::ComputeContext& context,
-    Tensor* y)
-{
-    constexpr uint32_t kTileSizeA = 32;
-    constexpr uint32_t kTileSizeB = 64;
-    constexpr uint32_t kU32Components = 4;
-    TensorShape y_shape{1, M, N};
-    SubgroupMatrixMatMulNBitsProgram mul_program;
-    mul_program.SetWorkgroupSize(128);
-    mul_program.SetDispatchGroupSize(
-        (N + kTileSizeB - 1) / kTileSizeB,
-        (M + kTileSizeA - 1) / kTileSizeA, 1);
-    mul_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)},
-                           {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kU32Components)},
-                           {scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}})
-        .AddUniformVariables({{static_cast(M)},
+                                      uint32_t M,
+                                      uint32_t N,
+                                      uint32_t K,
+                                      onnxruntime::webgpu::ComputeContext& context,
+                                      Tensor* y) {
+  constexpr uint32_t kTileSizeA = 32;
+  constexpr uint32_t kTileSizeB = 64;
+  constexpr uint32_t kU32Components = 4;
+  TensorShape y_shape{1, M, N};
+  SubgroupMatrixMatMulNBitsProgram mul_program;
+  mul_program.SetWorkgroupSize(128);
+  mul_program.SetDispatchGroupSize(
+      (N + kTileSizeB - 1) / kTileSizeB,
+      (M + kTileSizeA - 1) / kTileSizeA, 1);
+  mul_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)},
+                         {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kU32Components)},
+                         {scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}})
+      .AddUniformVariables({{static_cast(M)},
                             {static_cast(N)},
                             {static_cast(K)}})
-        .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, y_shape, gsl::narrow(1)});
-    return context.RunProgram(mul_program);
+      .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, y_shape, gsl::narrow(1)});
+  return context.RunProgram(mul_program);
 }
 
 bool CanApplySubgroupMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context,
-    uint64_t accuracy_level,
-    uint32_t block_size,
-    uint32_t batch_count,
-    uint32_t N,
-    uint32_t K,
-    bool has_zero_points)
-{
-    const bool has_subgroup_matrix = context.Device().HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
-    // For now SubgroupMatrixMatMulNBits is only supported for accuracy level 4, because with Fp16 there are
-    // some precision issues with subgroupMatrixMultiplyAccumulate. It is possible to support higher accuracy
-    // by setting compute_precision to Fp32, but that will be slower. For 1K token prefill FP16 Phi 3.5 is around 5s,
-    // FP32 is around 7s.
-    return context.AdapterInfo().backendType == wgpu::BackendType::Metal &&
-           has_subgroup_matrix &&
-           accuracy_level == 4 &&
-           block_size == 32 &&
-           batch_count == 1 &&
-           K % 32 == 0 &&
-           N % 64 == 0 &&
-           !has_zero_points;
+                                       uint64_t accuracy_level,
+                                       uint32_t block_size,
+                                       uint32_t batch_count,
+                                       uint32_t N,
+                                       uint32_t K,
+                                       bool has_zero_points) {
+  const bool has_subgroup_matrix = context.Device().HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
+  // For now SubgroupMatrixMatMulNBits is only supported for accuracy level 4, because with Fp16 there are
+  // some precision issues with subgroupMatrixMultiplyAccumulate. It is possible to support higher accuracy
+  // by setting compute_precision to Fp32, but that will be slower. For 1K token prefill FP16 Phi 3.5 is around 5s,
+  // FP32 is around 7s.
+  return context.AdapterInfo().backendType == wgpu::BackendType::Metal &&
+         has_subgroup_matrix &&
+         accuracy_level == 4 &&
+         block_size == 32 &&
+         batch_count == 1 &&
+         K % 32 == 0 &&
+         N % 64 == 0 &&
+         !has_zero_points;
 }
 
-} // namespace webgpu
-} // namespace contrib
-} // namespace onnxruntime
+}  // namespace webgpu
+}  // namespace contrib
+}  // namespace onnxruntime

diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h
index c64d343b6d576..57a0b1066326a 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h
+++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h
@@ -17,7 +17,7 @@ using namespace onnxruntime::webgpu;
 
 class SubgroupMatrixMatMulNBitsProgram final : public Program {
  public:
-    SubgroupMatrixMatMulNBitsProgram() : Program{"SubgroupMatrixMatMulNBits"} {}
+  SubgroupMatrixMatMulNBitsProgram() : Program{"SubgroupMatrixMatMulNBits"} {}
   Status GenerateShaderCode(ShaderHelper& sh) const override;
   WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
       {"M", ProgramUniformVariableDataType::Uint32},
@@ -26,19 +26,19 @@ class SubgroupMatrixMatMulNBitsProgram final : public Program
Date: Tue, 18 Feb 2025 16:25:47 -0800
Subject: [PATCH 11/14] Revert change to .gitignore

---
 .gitignore | 200 ++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 198 insertions(+), 2 deletions(-)

diff --git a/.gitignore b/.gitignore
index f514b74c5f2fb..4d0a1205b7c19 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,198 @@
-# Created by venv; see https://docs.python.org/3/library/venv.html
-*
+# build, distribute, and bins (+ python proto bindings)
+build
+build_*/
+.build_debug/*
+.build_release/*
+distribute/*
+*.testbin
+*.bin
+cmake_build
+.cmake_build
+cmake-build-debug
+gen
+*~
+.vs
+.DS_Store
+*.DS_Store
+TestResults/
+.idea/
+onnxruntime.egg-info
+nuget_root/
+.packages/
+.vscode
+*.code-workspace
+__pycache__
+onnxruntime_profile*.json
+/docs/python/inference/*.md
+/docs/python/inference/auto_examples/*
+/docs/python/inference/media/*
+/docs/python/inference/examples/*.onnx
+/docs/python/inference/examples/graph.*
+/docs/python/*_LICENSE
+/LICENSE.txt
+/csharp/**/obj/
+/csharp/**/bin/
+/csharp/Directory.Build.props
+docs/python/inference/*.onnx
+*.onnx
+onnxprofile_profile_test_*.json
+/csharp/packages
+/csharp/src/Microsoft.ML.OnnxRuntime/targets/**/*.targets
+/csharp/src/Microsoft.ML.OnnxRuntime/targets/**/*.props
+/csharp/**/*.vcxproj.user
+cmake/external/FeaturizersLibrary/
+# Java specific ignores
+java/.gradle
+java/hs_*.log
+/java/bin
+onnxruntime/python/version_info.py
+/orttraining/orttraining/eager/ort_aten.g.cpp
+/orttraining/orttraining/eager/ort_customops.g.cpp
+/csharp/**/packages
+# direnv, posh-direnv
+.envrc
+.psenvrc
+*.csproj.user
+# clangd
+.cache/
+compile_commands.json
+# Rust specific
+rust/**/target
+rust/**/Cargo.lock
+rust/onnxruntime/synset.txt
+
+# Python
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# Swift Package Manager
+Packages/
+Package.pins
+Package.resolved
+.build/
+.swiftpm/
+repros/

From aa6e8c0792276815c15467154fb409f18f5d55d0 Mon Sep 17 00:00:00 2001
From: Sushanth Rajasankar
Date: Tue, 18 Feb 2025 17:53:28 -0800
Subject: [PATCH 12/14] Fix mac build break

---
 onnxruntime/core/providers/webgpu/webgpu_context.h | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h
index 7b68af6d620da..cb0e14f82610b 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_context.h
+++ b/onnxruntime/core/providers/webgpu/webgpu_context.h
@@ -95,9 +95,10 @@ class WebGpuContext final {
 
     if (is_profiling_ && query_type_ == TimestampQueryType::AtPasses) {
       wgpu::PassTimestampWrites timestampWrites = {
-          .querySet = query_set_,
-          .beginningOfPassWriteIndex = num_pending_dispatches_ * 2,
-          .endOfPassWriteIndex = num_pending_dispatches_ * 2 + 1};
+          nullptr,
+          query_set_,
+          num_pending_dispatches_ * 2,
+          num_pending_dispatches_ * 2 + 1};
       compute_pass_desc.timestampWrites = &timestampWrites;
     }

From 09e30beb5086bfcab51c20c329aa033106da6fd0 Mon Sep 17 00:00:00 2001
From: Sushanth Rajasankar
Date: Wed, 19 Feb 2025 12:37:08 -0800
Subject: [PATCH 13/14] add comments

---
 .../webgpu/quantization/subgroup_matrix_matmul_nbits.cc | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc
index 2d4bdfd59b1ae..9c9990713e97d 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc
@@ -13,6 +13,8 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader
   shader.AddInput("scales_b", ShaderUsage::UseUniform);
   shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);
 
+  // tile/subtile sizes and work distribution are inspired by metal shaders in llama.cpp (kernel_mul_mm)
+  // https://github.com/ggml-org/llama.cpp/blob/d04e7163c85a847bc61d58c22f2c503596db7aa8/ggml/src/ggml-metal/ggml-metal.metal#L6066
   shader.AdditionalImplementation() << R"ADDNL_FN(
     const tile_cols = 64;
     const tile_rows = 32;
     const tile_k = 32;
From 91b7eb64a06e0a823618344485b65a88bf082bd8 Mon Sep 17 00:00:00 2001
From: Sushanth Rajasankar
Date: Wed, 19 Feb 2025 13:54:49 -0800
Subject: [PATCH 14/14] Subgroup matmul memory optimizations

---
 .gitignore                                       |  3 ++
 .../subgroup_matrix_matmul_nbits.cc              | 48 +++++++++----------
 2 files changed, 26 insertions(+), 25 deletions(-)

diff --git a/.gitignore b/.gitignore
index 4d0a1205b7c19..a3d0cd5745e65 100644
--- a/.gitignore
+++ b/.gitignore
@@ -62,6 +62,9 @@ rust/**/Cargo.lock
 rust/onnxruntime/synset.txt
 
 # Python
+bin/
+lib/
+pyvenv.cfg
 
 # Byte-compiled / optimized / DLL files
 __pycache__/

diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc
index 9c9990713e97d..5d89a511e000e 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc
@@ -28,37 +28,33 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader
     var tile_B: array; // 64 x 32 - RxC
     var scratch: array, 4>, 4>; // 64 * 4 * 4
 
-    fn loadSHMA(tile_base: u32, k_idx: u32, row: u32, c_idx:u32) {
-      let a_global = tile_base + row;
-      if (a_global >= uniforms.M) {
-        return;
-      }
-      // Each call loads 8 columns, starting at col.
-      var col = c_idx * 8;
-      // 128 threads need to load 32 x 32. 4 threads per row or 8 col per thread.
-      for (var col_offset:u32 = 0; col_offset < 8; col_offset++)
+    fn loadSHMA(tile_base: u32, k_idx: u32, row: u32, col:u32) {
+      // Each call loads 4 rows, at col.
+      // 128 threads need to load 32 x 32. Organized as 32 threads load a row, 4 rows at a time.
+      var out_row = row;
+      let row_limit = min(tile_base + tile_rows, uniforms.M);
+      for (var a_global:u32 = tile_base + out_row; a_global < row_limit; a_global = tile_base + out_row)
       {
-        tile_A[row * tile_k + col + col_offset] = compute_precision(input_a[a_global*uniforms.K + k_idx + col + col_offset]);
+        tile_A[out_row * tile_k + col] = compute_precision(input_a[a_global*uniforms.K + k_idx + col]);
+        out_row+=4;
       }
     }
 
     fn loadSHMB(tile_base: u32, k_idx: u32, row: u32, c_idx: u32) {
-      let b_global = tile_base + row;
-      if (b_global >= uniforms.N) {
-        return;
-      }
-      // Each call loads 16 columns, starting at col.
-      var col = c_idx * 16;
-      // 128 threads need to load 64 x 32. 2 threads per row or 16 col per thread.
-      // Stored in column major fashion.
-      let b_idx = u32((b_global*uniforms.K + k_idx + col)/8);
-      let scale = compute_precision(scales_b[(b_global*uniforms.K + k_idx + col)/quantization_block_size]);
-      for (var step:u32 = 0; step < 2; step++)
+      // Each call loads 8 columns, starting at col, and then advances to further rows.
+      // 128 threads need to load 64 x 32. Organized as 4 threads load a row, 32 rows at a time.
+      // B is stored in column major fashion.
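Before the new loader body continues below, the work distribution these comments describe can be double-checked in scalar form. The following is a hedged, host-side C++ sketch — visitTileA/visitTileB and the kTile* constants mirror the shader's names and sizes but are illustrative only, not code from this patch:

#include <algorithm>
#include <cstdint>

constexpr uint32_t kTileRows = 32;  // tile_rows: A-tile rows
constexpr uint32_t kTileCols = 64;  // tile_cols: B-tile rows (B is column-major)
constexpr uint32_t kTileK = 32;     // tile_k: shared K extent

// tile_A: 128 threads cover 32x32. Thread t owns column t % 32 and rows
// t / 32, t / 32 + 4, ... — i.e. 32 threads per row, stepping 4 rows at a time.
void visitTileA(uint32_t local_idx, uint32_t tile_base, uint32_t M) {
  const uint32_t col = local_idx % kTileK;
  const uint32_t row_limit = std::min(tile_base + kTileRows, M);
  for (uint32_t out_row = local_idx / kTileK; tile_base + out_row < row_limit; out_row += 4) {
    // shader: tile_A[out_row * tile_k + col] = input_a[(tile_base + out_row) * K + k_idx + col]
    (void)col;
  }
}

// tile_B: 128 threads cover 64x32, but each u32 of input_b packs 8 weights, so
// thread t owns the 8-column strip starting at (t % 4) * 8 and rows t / 4,
// t / 4 + 32 — i.e. 4 threads per row, stepping 32 rows at a time.
void visitTileB(uint32_t local_idx, uint32_t tile_base, uint32_t N) {
  const uint32_t col = (local_idx % 4) * 8;
  const uint32_t row_limit = std::min(tile_base + kTileCols, N);
  for (uint32_t out_row = local_idx / 4; tile_base + out_row < row_limit; out_row += 32) {
    // shader: dequantize one u32 of input_b into tile_B[out_row * tile_k + col .. col + 7]
    (void)col;
  }
}

The shared row_limit replaces the per-thread early returns of the old loaders, which is where the memory-access savings of this patch come from.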
+      var col = c_idx * 8;
+      let out_row_limit = min(tile_base + tile_cols, uniforms.N);
+      var out_row = row;
+      for (var b_global:u32 = tile_base + out_row; b_global < out_row_limit; b_global = tile_base + out_row)
       {
-        var b_value = input_b[b_idx+step];
+        let b_idx = u32((b_global*uniforms.K + k_idx + col)/8);
+        let scale = compute_precision(scales_b[(b_global*uniforms.K + k_idx + col)/quantization_block_size]);
+        var b_value = input_b[b_idx];
         var b_value_lower = (vec4(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4(8)) * scale;
         var b_value_upper = (vec4(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4(8)) * scale;
-        let tile_b_base = row * tile_k + col + step * 8;
+        let tile_b_base = out_row * tile_k + col;
         tile_B[tile_b_base] = b_value_lower[0];
         tile_B[tile_b_base + 1] = b_value_upper[0];
         tile_B[tile_b_base + 2] = b_value_lower[1];
@@ -67,6 +63,7 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader
         tile_B[tile_b_base + 5] = b_value_upper[2];
         tile_B[tile_b_base + 6] = b_value_lower[3];
         tile_B[tile_b_base + 7] = b_value_upper[3];
+        out_row += 32;
       }
     }
 
@@ -106,8 +103,9 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader
     var matC13: subgroup_matrix_result;
     for (var kidx: u32 = 0; kidx < uniforms.K; kidx += tile_k) {
         // Load Phase
-        loadSHMA(a_global_base, kidx, local_idx/4, local_idx%4);
-        loadSHMB(b_global_base, kidx, local_idx/2, local_idx%2);
+        loadSHMA(a_global_base, kidx, local_idx/tile_k, local_idx%tile_k);
+        // tile_k in B's vectorization is 4, since each u32 holds 8 weights.
+        loadSHMB(b_global_base, kidx, local_idx/4, local_idx%4);
         workgroupBarrier();
 
         for (var step: u32 = 0; step < tile_k; step+=8)
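The section cuts off mid-hunk here, but the dequantization idiom loadSHMB relies on is worth spelling out: each u32 of input_b packs eight 4-bit weights, the low nibble of each byte coming before the high nibble, with an implicit zero point of 8 and a per-quantization-block scale. A self-contained C++ sketch of that unpack — an illustration of the idiom under those assumptions, not code from the patch:

#include <cstdint>
#include <cstdio>

// Unpack one u32 (eight 4-bit weights) the way the shader's loadSHMB does:
// lower nibbles via (v & 0x0F0F0F0F), upper nibbles via ((v >> 4) & 0x0F0F0F0F),
// interleaved lower/upper per byte, minus the implicit zero point 8, times scale.
static void DequantizeU32(uint32_t v, float scale, float out[8]) {
  for (int byte = 0; byte < 4; ++byte) {
    const uint32_t b = (v >> (8 * byte)) & 0xFFu;
    out[2 * byte + 0] = (static_cast<float>(b & 0x0Fu) - 8.0f) * scale;         // b_value_lower[byte]
    out[2 * byte + 1] = (static_cast<float>((b >> 4) & 0x0Fu) - 8.0f) * scale;  // b_value_upper[byte]
  }
}

int main() {
  float w[8];
  DequantizeU32(0x89ABCDEFu, 0.5f, w);  // arbitrary packed example value
  for (float x : w) std::printf("%g ", x);
  std::printf("\n");
  return 0;
}

The interleaved store order (lower[0], upper[0], lower[1], upper[1], ...) matches the eight consecutive tile_B writes in the hunk above, so one u32 fills one 8-column strip of a B-tile row.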