diff --git a/.gitignore b/.gitignore
index 4d0a1205b7c19..a3d0cd5745e65 100644
--- a/.gitignore
+++ b/.gitignore
@@ -62,6 +62,9 @@ rust/**/Cargo.lock
 rust/onnxruntime/synset.txt
 
 # Python
+bin/
+lib/
+pyvenv.cfg
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/cmake/deps.txt b/cmake/deps.txt
index b07a3acdecd54..d2e0fd63215f4 100644
--- a/cmake/deps.txt
+++ b/cmake/deps.txt
@@ -58,5 +58,5 @@ extensions;https://github.com/microsoft/onnxruntime-extensions/archive/c24b7bab0
 composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/204da9c522cebec5220bba52cd3542ebcaf99e7a.zip;1827348efd47831c13074245274d41b7cae8a557
 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
 cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.7.0.zip;d0753d8d5b39947ca0729d7773cb84653a129eb1
-dawn;https://github.com/google/dawn/archive/b9b4a37041dec3dd62ac92014a6cc1aece48d9f3.zip;e8b8c2ebabdedb7c57d931fc4a19ae22146d31e1
+dawn;https://github.com/google/dawn/archive/40a9fa79f76e6c76cca9e2fa69ea07f202f1d2e6.zip;e224563d5ab4a8e53a517b06f721242533bce722
 kleidiai;https://gitlab.arm.com/kleidi/kleidiai/-/archive/d15722976120710080ca098fe8ddabf4556cb40f/kleidiai-d15722976120710080ca098fe8ddabf4556cb40f.zip;d6c840d00c3b05aedf06e957ddaece1013d1f40b
diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
index 3b566d37fa979..28d622b2c9c33 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
@@ -4,6 +4,7 @@
 #include <string_view>
 
 #include "contrib_ops/webgpu/quantization/matmul_nbits.h"
+#include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h"
 #include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
 #include "core/providers/cpu/math/matmul_helper.h"
 #include "core/providers/webgpu/shader_helper.h"
@@ -815,6 +816,12 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
   uint32_t components = GetMaxComponents(N);
   const bool has_zero_points = zero_points != nullptr;
 
+  // macOS - Experimental dawn support for subgroup matrix matmul on Metal.
+  if (M >= kMinMForTileOptimization &&
+      CanApplySubgroupMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, has_zero_points)) {
+    return ApplySubgroupMatrixMatMulNBits(a, b, scales, M, N, K, context, y);
+  }
+
   const bool has_subgroup = context.Device().HasFeature(wgpu::FeatureName::Subgroups);
   // macOS - Avoid using dp4a on Metal, as it does not appear to have native dp4a support.
   // https://github.com/gpuweb/gpuweb/issues/2677#issuecomment-1713292226
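For orientation, the hunk above makes the subgroup-matrix kernel the first choice in MatMulNBits::ComputeInternal: it is tried before the existing dp4a and generic shaders, and only for prefill-sized inputs that pass the eligibility check. A minimal C++ sketch of that selection order; the enum and helper below are illustrative only and not part of this change:

#include <cstdint>

// Sketch only: illustrates the order in which ComputeInternal now picks a kernel.
// SelectMatMulNBitsPath is a hypothetical helper, not part of the PR.
enum class MatMulNBitsPath { kSubgroupMatrix, kDP4A, kGeneric };

MatMulNBitsPath SelectMatMulNBitsPath(bool can_use_subgroup_matrix, bool can_use_dp4a,
                                      uint32_t M, uint32_t min_m_for_tile_optimization) {
  if (M >= min_m_for_tile_optimization && can_use_subgroup_matrix) {
    return MatMulNBitsPath::kSubgroupMatrix;  // new Metal-only path added by this change
  }
  if (can_use_dp4a) {
    return MatMulNBitsPath::kDP4A;  // existing packed-dot-product path
  }
  return MatMulNBitsPath::kGeneric;  // existing default shader
}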
diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc
new file mode 100644
index 0000000000000..5d89a511e000e
--- /dev/null
+++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc
@@ -0,0 +1,219 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace webgpu {
+
+Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
+  shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
+  shader.AddInput("input_b", ShaderUsage::UseUniform);
+  shader.AddInput("scales_b", ShaderUsage::UseUniform);
+  shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);
+
+  // Tile/subtile sizes and work distribution are inspired by the Metal shaders in llama.cpp (kernel_mul_mm).
+  // https://github.com/ggml-org/llama.cpp/blob/d04e7163c85a847bc61d58c22f2c503596db7aa8/ggml/src/ggml-metal/ggml-metal.metal#L6066
+  shader.AdditionalImplementation() << R"ADDNL_FN(
+    const tile_cols = 64;
+    const tile_rows = 32;
+    const tile_k = 32;
+    const subtile_cols = 32;
+    const subtile_rows = 16;
+    const quantization_block_size = 32;
+    alias compute_precision = output_element_t;
+
+    var<workgroup> tile_A: array<compute_precision, tile_rows * tile_k>;       // 32 x 32 - RxC
+    var<workgroup> tile_B: array<compute_precision, tile_cols * tile_k>;       // 64 x 32 - RxC
+    var<workgroup> scratch: array<array<array<compute_precision, 64>, 4>, 4>;  // 64 * 4 * 4
+
+    fn loadSHMA(tile_base: u32, k_idx: u32, row: u32, col: u32) {
+      // Each call loads 4 rows, at col.
+      // 128 threads need to load 32 x 32. Organized as 32 threads load a row, 4 rows at a time.
+      var out_row = row;
+      let row_limit = min(tile_base + tile_rows, uniforms.M);
+      for (var a_global: u32 = tile_base + out_row; a_global < row_limit; a_global = tile_base + out_row)
+      {
+        tile_A[out_row * tile_k + col] = compute_precision(input_a[a_global * uniforms.K + k_idx + col]);
+        out_row += 4;
+      }
+    }
+
+    fn loadSHMB(tile_base: u32, k_idx: u32, row: u32, c_idx: u32) {
+      // Each call loads 8 columns starting at col, then advances 32 rows and repeats.
+      // 128 threads need to load 64 x 32. Organized as 4 threads load a row, 32 rows at a time.
+      // B is stored in column-major fashion.
+      var col = c_idx * 8;
+      let out_row_limit = min(tile_base + tile_cols, uniforms.N);
+      var out_row = row;
+      for (var b_global: u32 = tile_base + out_row; b_global < out_row_limit; b_global = tile_base + out_row)
+      {
+        let b_idx = u32((b_global * uniforms.K + k_idx + col) / 8);
+        let scale = compute_precision(scales_b[(b_global * uniforms.K + k_idx + col) / quantization_block_size]);
+        var b_value = input_b[b_idx];
+        var b_value_lower = (vec4<compute_precision>(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4<compute_precision>(8)) * scale;
+        var b_value_upper = (vec4<compute_precision>(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4<compute_precision>(8)) * scale;
+        let tile_b_base = out_row * tile_k + col;
+        tile_B[tile_b_base] = b_value_lower[0];
+        tile_B[tile_b_base + 1] = b_value_upper[0];
+        tile_B[tile_b_base + 2] = b_value_lower[1];
+        tile_B[tile_b_base + 3] = b_value_upper[1];
+        tile_B[tile_b_base + 4] = b_value_lower[2];
+        tile_B[tile_b_base + 5] = b_value_upper[2];
+        tile_B[tile_b_base + 6] = b_value_lower[3];
+        tile_B[tile_b_base + 7] = b_value_upper[3];
+        out_row += 32;
+      }
+    }
+
+    fn storeOutput(offset: u32, row: u32, col: u32, src_slot: u32, row_limit: i32) {
+      if (row_limit > 0 && row < u32(row_limit))
+      {
+        output[offset + row * uniforms.N + col] = output_element_t(scratch[src_slot][0][row * 8 + col]);
+        output[offset + row * uniforms.N + col + 8] = output_element_t(scratch[src_slot][1][row * 8 + col]);
+        output[offset + row * uniforms.N + col + 16] = output_element_t(scratch[src_slot][2][row * 8 + col]);
+        output[offset + row * uniforms.N + col + 24] = output_element_t(scratch[src_slot][3][row * 8 + col]);
+        let col2 = col + 1;
+        output[offset + row * uniforms.N + col2] = output_element_t(scratch[src_slot][0][row * 8 + col2]);
+        output[offset + row * uniforms.N + col2 + 8] = output_element_t(scratch[src_slot][1][row * 8 + col2]);
+        output[offset + row * uniforms.N + col2 + 16] = output_element_t(scratch[src_slot][2][row * 8 + col2]);
+        output[offset + row * uniforms.N + col2 + 24] = output_element_t(scratch[src_slot][3][row * 8 + col2]);
+      }
+    }
+  )ADDNL_FN";
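loadSHMB above dequantizes eight 4-bit weights per u32 while filling the shared tile: the low nibbles of the four packed bytes become the even columns, the high nibbles the odd columns, each biased by 8 (the default zero point, since this path rejects explicit zero_points) and multiplied by the per-32-element block scale. A host-side C++ sketch of the same unpacking, for reference only (names are illustrative):

#include <array>
#include <cstdint>

// Reference unpacking of one u32 that holds 8 int4 weights, mirroring loadSHMB:
// low nibbles of the 4 bytes -> even columns, high nibbles -> odd columns.
std::array<float, 8> DequantizeU32Block(uint32_t packed, float scale) {
  std::array<float, 8> out{};
  for (int byte = 0; byte < 4; ++byte) {
    const uint32_t b = (packed >> (8 * byte)) & 0xFFu;
    const int lower = static_cast<int>(b & 0x0Fu) - 8;         // matches (b_value & 0x0F0F0F0F) - 8
    const int upper = static_cast<int>((b >> 4) & 0x0Fu) - 8;  // matches ((b_value >> 4) & 0x0F0F0F0F) - 8
    out[2 * byte] = lower * scale;      // tile_B[tile_b_base + 2*byte]
    out[2 * byte + 1] = upper * scale;  // tile_B[tile_b_base + 2*byte + 1]
  }
  return out;
}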
+
+  shader.MainFunctionBody() << R"MAIN_FN(
+    let a_global_base = workgroup_id.y * tile_rows;
+    let b_global_base = workgroup_id.x * tile_cols;
+
+    let subtile_id = u32(local_idx / sg_size);
+    let subtile_idx = u32(subtile_id / 2);
+    let subtile_idy = subtile_id % 2;
+    let base_A = subtile_idy * subtile_rows;
+    let base_B = subtile_idx * subtile_cols;
+
+    var matC00: subgroup_matrix_result<compute_precision, 8, 8>;
+    var matC01: subgroup_matrix_result<compute_precision, 8, 8>;
+    var matC02: subgroup_matrix_result<compute_precision, 8, 8>;
+    var matC03: subgroup_matrix_result<compute_precision, 8, 8>;
+    var matC10: subgroup_matrix_result<compute_precision, 8, 8>;
+    var matC11: subgroup_matrix_result<compute_precision, 8, 8>;
+    var matC12: subgroup_matrix_result<compute_precision, 8, 8>;
+    var matC13: subgroup_matrix_result<compute_precision, 8, 8>;
+    for (var kidx: u32 = 0; kidx < uniforms.K; kidx += tile_k) {
+      // Load Phase
+      loadSHMA(a_global_base, kidx, local_idx / tile_k, local_idx % tile_k);
+      // tile_k in B's vectorization is 4, since each u32 holds 8 weights.
+      loadSHMB(b_global_base, kidx, local_idx / 4, local_idx % 4);
+      workgroupBarrier();
+
+      for (var step: u32 = 0; step < tile_k; step += 8)
+      {
+        // Load to local memory phase
+        let matrix_a_offset = subtile_idy * subtile_rows * tile_k + step;
+        // Syntax: subgroupMatrixLoad src_ptr, src_offset, is_col_major, src_stride
+        var matA0: subgroup_matrix_left<compute_precision, 8, 8> = subgroupMatrixLoad<subgroup_matrix_left<compute_precision, 8, 8>>(&tile_A, matrix_a_offset, false, tile_k);
+        var matA1: subgroup_matrix_left<compute_precision, 8, 8> = subgroupMatrixLoad<subgroup_matrix_left<compute_precision, 8, 8>>(&tile_A, matrix_a_offset + 8 * tile_k, false, tile_k);
+
+        // tile_B is stored as column major.
+        // [col0-0:32][col1-0:32][col2-0:32]..[col63-0:32]
+        var matrix_b_offset = subtile_idx * subtile_cols * tile_k + step;
+        var matB0: subgroup_matrix_right<compute_precision, 8, 8> = subgroupMatrixLoad<subgroup_matrix_right<compute_precision, 8, 8>>(&tile_B, matrix_b_offset, true, tile_k);
+        var matB1: subgroup_matrix_right<compute_precision, 8, 8> = subgroupMatrixLoad<subgroup_matrix_right<compute_precision, 8, 8>>(&tile_B, matrix_b_offset + 8 * tile_k, true, tile_k);
+        var matB2: subgroup_matrix_right<compute_precision, 8, 8> = subgroupMatrixLoad<subgroup_matrix_right<compute_precision, 8, 8>>(&tile_B, matrix_b_offset + 16 * tile_k, true, tile_k);
+        var matB3: subgroup_matrix_right<compute_precision, 8, 8> = subgroupMatrixLoad<subgroup_matrix_right<compute_precision, 8, 8>>(&tile_B, matrix_b_offset + 24 * tile_k, true, tile_k);
+
+        // Compute Phase
+        // Syntax: subgroupMatrixMultiplyAccumulate left, right, accumulate -> accumulate
+        matC00 = subgroupMatrixMultiplyAccumulate(matA0, matB0, matC00);
+        matC01 = subgroupMatrixMultiplyAccumulate(matA0, matB1, matC01);
+        matC02 = subgroupMatrixMultiplyAccumulate(matA0, matB2, matC02);
+        matC03 = subgroupMatrixMultiplyAccumulate(matA0, matB3, matC03);
+
+        matC10 = subgroupMatrixMultiplyAccumulate(matA1, matB0, matC10);
+        matC11 = subgroupMatrixMultiplyAccumulate(matA1, matB1, matC11);
+        matC12 = subgroupMatrixMultiplyAccumulate(matA1, matB2, matC12);
+        matC13 = subgroupMatrixMultiplyAccumulate(matA1, matB3, matC13);
+      }
+
+      workgroupBarrier();
+    }
+
+    // Write out
+    // Write out top block
+    subgroupMatrixStore(&scratch[subtile_id][0], 0, matC00, false, 8);
+    subgroupMatrixStore(&scratch[subtile_id][1], 0, matC01, false, 8);
+    subgroupMatrixStore(&scratch[subtile_id][2], 0, matC02, false, 8);
+    subgroupMatrixStore(&scratch[subtile_id][3], 0, matC03, false, 8);
+    workgroupBarrier();
+    let row = u32(sg_id / 4);
+    var col = u32(sg_id % 4) * 2;
+    var matrix_c_offset = (a_global_base + base_A) * uniforms.N + b_global_base + base_B;
+    var row_limit: i32 = i32(uniforms.M) - i32(a_global_base + base_A);
+    storeOutput(matrix_c_offset, row, col, subtile_id, row_limit);
+    workgroupBarrier();
+
+    // Write out bottom block
+    subgroupMatrixStore(&scratch[subtile_id][0], 0, matC10, false, 8);
+    subgroupMatrixStore(&scratch[subtile_id][1], 0, matC11, false, 8);
+    subgroupMatrixStore(&scratch[subtile_id][2], 0, matC12, false, 8);
+    subgroupMatrixStore(&scratch[subtile_id][3], 0, matC13, false, 8);
+    workgroupBarrier();
+    matrix_c_offset = matrix_c_offset + 8 * uniforms.N;
+    row_limit = i32(uniforms.M) - i32(a_global_base + base_A + 8);
+    storeOutput(matrix_c_offset, row, col, subtile_id, row_limit);
+  )MAIN_FN";
+
+  return Status::OK();
+}
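To spell out the decomposition used by the shader above: a 128-thread workgroup (four 32-wide subgroups on Apple GPUs) produces a 32x64 output tile; each subgroup owns a 16x32 subtile and accumulates it as a 2x4 grid of 8x8 subgroup matrices (matC00..matC13), stepping over K in tile_k = 32 chunks. A small compile-time sanity check of that arithmetic, using the constants from the shader:

#include <cstdint>

// Sanity check of the tiling used by SubgroupMatrixMatMulNBitsProgram (values from the shader).
int main() {
  constexpr uint32_t kTileRows = 32, kTileCols = 64;        // per-workgroup output tile
  constexpr uint32_t kSubtileRows = 16, kSubtileCols = 32;  // per-subgroup output subtile
  constexpr uint32_t kMatDim = 8;                           // 8x8 subgroup matrices

  // 128 threads / 32-wide subgroups -> 4 subgroups arranged as 2 (rows) x 2 (cols).
  static_assert(2 * kSubtileRows == kTileRows && 2 * kSubtileCols == kTileCols, "subtile grid");
  // Each subgroup holds 2 x 4 = 8 accumulators (matC00..matC13) of 8x8 elements.
  static_assert((kSubtileRows / kMatDim) * (kSubtileCols / kMatDim) == 8, "accumulator count");
  return 0;
}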
+
+Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
+                                      uint32_t M,
+                                      uint32_t N,
+                                      uint32_t K,
+                                      onnxruntime::webgpu::ComputeContext& context,
+                                      Tensor* y) {
+  constexpr uint32_t kTileSizeA = 32;
+  constexpr uint32_t kTileSizeB = 64;
+  constexpr uint32_t kU32Components = 4;
+  TensorShape y_shape{1, M, N};
+  SubgroupMatrixMatMulNBitsProgram mul_program;
+  mul_program.SetWorkgroupSize(128);
+  mul_program.SetDispatchGroupSize(
+      (N + kTileSizeB - 1) / kTileSizeB,
+      (M + kTileSizeA - 1) / kTileSizeA, 1);
+  mul_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)},
+                         {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kU32Components)},
+                         {scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)}})
+      .AddUniformVariables({{static_cast<uint32_t>(M)},
+                            {static_cast<uint32_t>(N)},
+                            {static_cast<uint32_t>(K)}})
+      .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, y_shape, gsl::narrow<int>(1)});
+  return context.RunProgram(mul_program);
+}
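The dispatch grid follows directly from the tile sizes set above: one 128-thread workgroup per 64 output columns and 32 output rows. A standalone sketch of the same ceiling-division math, with an example shape that is purely illustrative:

#include <cstdint>
#include <cstdio>

// Illustration of the dispatch grid used by ApplySubgroupMatrixMatMulNBits.
int main() {
  constexpr uint32_t kTileSizeA = 32;  // output rows per workgroup
  constexpr uint32_t kTileSizeB = 64;  // output columns per workgroup
  // Example shape for a hypothetical prefill step (not taken from the PR): M=1024, N=4096.
  const uint32_t M = 1024, N = 4096;
  const uint32_t groups_x = (N + kTileSizeB - 1) / kTileSizeB;  // 64
  const uint32_t groups_y = (M + kTileSizeA - 1) / kTileSizeA;  // 32
  std::printf("dispatch: %u x %u x 1 workgroups, 128 threads each\n", groups_x, groups_y);
  return 0;
}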
+
+bool CanApplySubgroupMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context,
+                                       uint64_t accuracy_level,
+                                       uint32_t block_size,
+                                       uint32_t batch_count,
+                                       uint32_t N,
+                                       uint32_t K,
+                                       bool has_zero_points) {
+  const bool has_subgroup_matrix = context.Device().HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
+  // For now, SubgroupMatrixMatMulNBits is only supported for accuracy level 4, because with FP16 there are
+  // some precision issues with subgroupMatrixMultiplyAccumulate. It is possible to support higher accuracy
+  // by setting compute_precision to FP32, but that will be slower. For a 1K-token prefill with Phi 3.5,
+  // FP16 takes around 5s and FP32 around 7s.
+  return context.AdapterInfo().backendType == wgpu::BackendType::Metal &&
+         has_subgroup_matrix &&
+         accuracy_level == 4 &&
+         block_size == 32 &&
+         batch_count == 1 &&
+         K % 32 == 0 &&
+         N % 64 == 0 &&
+         !has_zero_points;
+}
+
+}  // namespace webgpu
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h
new file mode 100644
index 0000000000000..57a0b1066326a
--- /dev/null
+++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h
@@ -0,0 +1,45 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/providers/webgpu/program.h"
+#include "core/providers/webgpu/compute_context.h"
+#include "core/providers/webgpu/program.h"
+#include "core/providers/webgpu/shader_helper.h"
+#include "core/providers/webgpu/webgpu_kernel.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace webgpu {
+
+using namespace onnxruntime::webgpu;
+
+class SubgroupMatrixMatMulNBitsProgram final : public Program<SubgroupMatrixMatMulNBitsProgram> {
+ public:
+  SubgroupMatrixMatMulNBitsProgram() : Program{"SubgroupMatrixMatMulNBits"} {}
+  Status GenerateShaderCode(ShaderHelper& sh) const override;
+  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
+      {"M", ProgramUniformVariableDataType::Uint32},
+      {"N", ProgramUniformVariableDataType::Uint32},
+      {"K", ProgramUniformVariableDataType::Uint32});
+};
+
+Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
+                                      uint32_t M,
+                                      uint32_t N,
+                                      uint32_t K,
+                                      onnxruntime::webgpu::ComputeContext& context,
+                                      Tensor* y);
+
+bool CanApplySubgroupMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context,
+                                       uint64_t accuracy_level,
+                                       uint32_t block_size,
+                                       uint32_t batch_count,
+                                       uint32_t N,
+                                       uint32_t K,
+                                       bool has_zero_points);
+
+}  // namespace webgpu
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc
index 49c9a84f69551..6ff10ffb80d2c 100644
--- a/onnxruntime/core/providers/webgpu/shader_helper.cc
+++ b/onnxruntime/core/providers/webgpu/shader_helper.cc
@@ -352,6 +352,9 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha
   if (device_.HasFeature(wgpu::FeatureName::Subgroups)) {
     ss << "enable subgroups;\n";
   }
+  if (device_.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
+    ss << "enable chromium_experimental_subgroup_matrix;\n";
+  }
 
   //
   // Section constants
diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc
index 50ace96524ddf..163dd691b7f16 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_context.cc
+++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc
@@ -486,6 +486,7 @@ std::vector<wgpu::FeatureName> WebGpuContext::GetAvailableRequiredFeatures(const
 constexpr wgpu::FeatureName features[]{
 #if !defined(__wasm__)
     wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses,
+    wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix,
 #endif
     wgpu::FeatureName::TimestampQuery,
     wgpu::FeatureName::ShaderF16,
@@ -738,7 +739,7 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co
   // Step.2 - Create wgpu::Instance
 #if !defined(__wasm__)
   wgpu::InstanceDescriptor instance_desc{};
-  instance_desc.features.timedWaitAnyEnable = true;
+  instance_desc.capabilities.timedWaitAnyEnable = true;
   default_instance_ = wgpu::CreateInstance(&instance_desc);
 #else
   default_instance_ = wgpu::CreateInstance(nullptr);
diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h
index a87a940f44437..cb0e14f82610b 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_context.h
+++ b/onnxruntime/core/providers/webgpu/webgpu_context.h
@@ -94,8 +94,11 @@ class WebGpuContext final {
     wgpu::ComputePassDescriptor compute_pass_desc{};
 
     if (is_profiling_ && query_type_ == TimestampQueryType::AtPasses) {
-      wgpu::ComputePassTimestampWrites timestampWrites = {
-          query_set_, num_pending_dispatches_ * 2, num_pending_dispatches_ * 2 + 1};
+      wgpu::PassTimestampWrites timestampWrites = {
+          nullptr,
+          query_set_,
+          num_pending_dispatches_ * 2,
+          num_pending_dispatches_ * 2 + 1};
       compute_pass_desc.timestampWrites = &timestampWrites;
     }
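The timestamp-writes change tracks the newer Dawn headers, where the compute-pass-specific struct was replaced by wgpu::PassTimestampWrites and, judging by the extra leading nullptr, gained a chain pointer as its first member. A hedged alternative that assigns the members by name instead of relying on aggregate order (assuming the member names beginningOfPassWriteIndex and endOfPassWriteIndex, as the diff implies):

// Sketch only, in the same context as the code above (query_set_ and
// num_pending_dispatches_ are members of WebGpuContext). Explicit member
// assignment avoids depending on the struct's field order if it changes again.
wgpu::PassTimestampWrites timestamp_writes{};
timestamp_writes.querySet = query_set_;
timestamp_writes.beginningOfPassWriteIndex = num_pending_dispatches_ * 2;
timestamp_writes.endOfPassWriteIndex = num_pending_dispatches_ * 2 + 1;
compute_pass_desc.timestampWrites = &timestamp_writes;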