diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
index 0bdee151d2173..4cc5a4228dc8c 100644
--- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
@@ -11,18 +11,19 @@ namespace onnxruntime {
 namespace contrib {
 namespace group_query_attention_helper {
 
-Status CheckInputs(const Tensor* query,
-                   const Tensor* key,
-                   const Tensor* value,
-                   const Tensor* past_key,
-                   const Tensor* past_value,
-                   const Tensor* cos_cache,
-                   const Tensor* sin_cache,
+template <typename T>
+Status CheckInputs(const T* query,
+                   const T* key,
+                   const T* value,
+                   const T* past_key,
+                   const T* past_value,
+                   const T* cos_cache,
+                   const T* sin_cache,
                    void* parameters,
                    int num_heads,
                    int kv_num_heads,
-                   const Tensor* seqlens_k,
-                   const Tensor* total_seqlen,
+                   const T* seqlens_k,
+                   const T* total_seqlen,
                    float scale,
                    float softcap) {
   // Note: Here S* is seqlen_past_kv_cache, S+ is seqlen_present_kv_cache
@@ -265,18 +266,19 @@ Status CheckInputs(const Tensor* query,
   return Status::OK();
 }
 
-Status CheckInputs(const Tensor* query,
-                   const Tensor* key,
-                   const Tensor* value,
-                   const Tensor* past_key,
-                   const Tensor* past_value,
-                   const Tensor* cos_cache,
-                   const Tensor* sin_cache,
+template <typename T>
+Status CheckInputs(const T* query,
+                   const T* key,
+                   const T* value,
+                   const T* past_key,
+                   const T* past_value,
+                   const T* cos_cache,
+                   const T* sin_cache,
                    void* parameters,
                    int num_heads,
                    int kv_num_heads,
-                   const Tensor* seqlens_k,
-                   const Tensor* total_seqlen,
+                   const T* seqlens_k,
+                   const T* total_seqlen,
                    float scale,
                    float softcap,
                    int max_threads_per_block) {
diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
new file mode 100644
index 0000000000000..ea8aa95614b40
--- /dev/null
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
@@ -0,0 +1,500 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
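The hunks above only template the existing validation helper; current callers keep working because the template argument is deduced from the pointer arguments. A minimal sketch of an instantiation with onnxruntime::Tensor (the wrapper name ValidateGqaInputs is illustrative, not part of the change):

// Hedged sketch: T is deduced as Tensor here; another provider may substitute any
// type that exposes the members CheckInputs reads.
Status ValidateGqaInputs(const Tensor* query, const Tensor* key, const Tensor* value,
                         const Tensor* past_key, const Tensor* past_value,
                         const Tensor* cos_cache, const Tensor* sin_cache,
                         const Tensor* seqlens_k, const Tensor* total_seqlen,
                         GroupQueryAttentionParameters& params,
                         int num_heads, int kv_num_heads, float scale, float softcap) {
  // Uses the overload without max_threads_per_block.
  return group_query_attention_helper::CheckInputs(query, key, value, past_key, past_value,
                                                   cos_cache, sin_cache, &params,
                                                   num_heads, kv_num_heads,
                                                   seqlens_k, total_seqlen, scale, softcap);
}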
+ +#include "contrib_ops/cpu/bert/multihead_attention_helper.h" +#include "contrib_ops/webgpu/bert/attention.h" +#include "contrib_ops/webgpu/bert/multihead_attention.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +using namespace onnxruntime::webgpu; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::contrib::multihead_attention_helper; + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +Status TransferBSDToBNSHProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("qkv_input", ShaderUsage::UseUniform); + const auto& qkv_output = shader.AddOutput("qkv_output", ShaderUsage::UseUniform | ShaderUsage::UseOffsetToIndices); + + if (has_bias_) { + shader.AddInput("bias", ShaderUsage::UseUniform); + } + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size") + << "let output_indices = " << qkv_output.OffsetToIndices("global_idx") << ";\n" + << "let input_offset_idx = output_indices[0] * uniforms.batch_offset + output_indices[1] *" + << " uniforms.head_offset + output_indices[2] * uniforms.sequence_offset + output_indices[3];\n"; + if (has_bias_) { + shader.MainFunctionBody() << "let bias_offset_idx = (input_offset_idx % uniforms.sequence_offset) + uniforms.bias_offset;\n"; + } + shader.MainFunctionBody() << "qkv_output[global_idx] = qkv_input[input_offset_idx]"; + if (has_bias_) { + shader.MainFunctionBody() << " + bias[bias_offset_idx];\n"; + } else { + shader.MainFunctionBody() << ";\n"; + } + + return Status::OK(); +} + +Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_heads, int sequence_length, + int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor) { + ORT_ENFORCE(input_tensor->Shape().GetDims().size() == 3); + ORT_ENFORCE(output_tensor->Shape().GetDims().size() == 4); + + uint32_t data_size = SafeInt(output_tensor->Shape().Size()); + const int batch_offset = num_heads * sequence_length * head_size; + const int sequence_offset = num_heads * head_size; + const int head_offset = head_size; + bool has_bias = bias != nullptr; + + TransferBSDToBNSHProgram program{has_bias}; + program.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({{data_size}, + {static_cast(batch_offset)}, + {static_cast(sequence_offset)}, + {static_cast(head_offset)}, + {static_cast(bias_offset)}}); + + if (has_bias) { + program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank}); + } + + return context.RunProgram(program); +}; + +void InitVarStub(std::ostringstream& ss, const Tensor* seqlen_k, bool is_first_prompt) { + if (seqlen_k != nullptr) { + ss << "total_sequence_length = u32(seqlen_k[batch_idx]) + 1;\n"; + ss << "var past_sequence_length: u32 = " << (is_first_prompt ? 
"0" : "total_sequence_length - sequence_length") << ";\n"; + } else { + ss << "let past_sequence_length = uniforms.past_sequence_length;\n"; + } +} + +Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + if (feed_past_key_) { + shader.AddInput("past_key", ShaderUsage::UseUniform); + } + if (has_attention_bias_) { + shader.AddInput("attention_bias", ShaderUsage::UseUniform); + } + if (seqlen_k_ != nullptr) { + shader.AddInput("seqlen_k", ShaderUsage::UseUniform); + } + shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + if (has_present_key_) { + shader.AddOutput("present_key", ShaderUsage::UseUniform); + } + + shader.AdditionalImplementation() << "var tileQ: array;\n" + << "var tileK: array;\n" + << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32")) << ";\n"; + shader.MainFunctionBody() << "// x holds the N and y holds the M\n" + << "let m = workgroup_id.y * TILE_SIZE;\n" + << "let n = workgroup_id.x * TILE_SIZE;\n" + << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n" + << "let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;\n" + << "let sequence_length = uniforms.M;\n" + << "var total_sequence_length = uniforms.N;\n"; + std::ostringstream oss; + InitVarStub(oss, seqlen_k_, is_first_prompt_); + shader.MainFunctionBody() << oss.str(); + if (n_reps_ > 1) { + shader.MainFunctionBody() << "let head_idx = workgroup_id.z % uniforms.num_heads;\n" + << "let kv_head_idx = head_idx / uniforms.n_reps;\n" + << "let kv_num_heads = uniforms.num_heads / uniforms.n_reps;\n" + << "let abs_kv_head_idx = batch_idx * kv_num_heads + kv_head_idx;\n" + << "let kOffset = abs_kv_head_idx * uniforms.kv_sequence_length * uniforms.K;\n"; + if (feed_past_key_ && has_present_key_) { + shader.MainFunctionBody() << "let pastKeyOffset = abs_kv_head_idx * uniforms.past_sequence_length * uniforms.K;\n"; + } else if (past_present_share_buffer_) { + shader.MainFunctionBody() << "let pastKeyOffset = abs_kv_head_idx * uniforms.present_sequence_length * uniforms.K;\n"; + } + if (has_present_key_) { + shader.MainFunctionBody() << "let presentKeyOffset = abs_kv_head_idx * uniforms.present_sequence_length * uniforms.K;\n"; + } + } else { + shader.MainFunctionBody() << "let kOffset = workgroup_id.z * uniforms.kv_sequence_length * uniforms.K;\n"; + if (feed_past_key_ && has_present_key_) { + shader.MainFunctionBody() << "let pastKeyOffset = workgroup_id.z * uniforms.past_sequence_length * uniforms.K;\n"; + } else if (past_present_share_buffer_) { + shader.MainFunctionBody() << "let pastKeyOffset = workgroup_id.z * uniforms.present_sequence_length * uniforms.K;\n"; + } + if (has_present_key_) { + shader.MainFunctionBody() << "let presentKeyOffset = workgroup_id.z * uniforms.present_sequence_length * uniforms.K;\n"; + } + } + + shader.MainFunctionBody() << "var value = f32_val_t(0);\n" + "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" + " if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n" + " tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n" + " }\n" + " if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n" + " var idx = TILE_SIZE * local_id.y + local_id.x;\n"; + + if ((feed_past_key_ && has_present_key_) || past_present_share_buffer_) { + 
shader.MainFunctionBody() << " if (n + local_id.y < past_sequence_length) {\n" + << " tileK[idx] = " << (past_present_share_buffer_ ? "present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" + << " } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n" + << " tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n" + << " }\n"; + } else { + shader.MainFunctionBody() << " if (n + local_id.y < uniforms.kv_sequence_length) {\n" + " tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" + " }\n"; + } + + if (has_present_key_) { + if (past_present_share_buffer_) { + shader.MainFunctionBody() << " if (n + local_id.y >= past_sequence_length && n + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n"; + } else { + shader.MainFunctionBody() << " if (n + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n"; + } + shader.MainFunctionBody() << " present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];\n" + << " }\n"; + } + + shader.MainFunctionBody() << " }\n" + << " workgroupBarrier();\n" + << " for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {\n" + << " value += f32_val_t(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]);\n" + << " }\n" + << " workgroupBarrier();\n" + << "}\n"; + + shader.MainFunctionBody() << "if (global_id.y < uniforms.M && global_id.x < total_sequence_length) {\n" + << " let headOffset = workgroup_id.z * uniforms.M * uniforms.N;\n" + << " let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n" + << " var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? "value.x + value.y" : "value")) << ";\n"; + + shader.MainFunctionBody() << " output[outputIdx] = output_value_t(sum * uniforms.alpha)"; + if (has_attention_bias_) { + shader.MainFunctionBody() << " + attention_bias[outputIdx]"; + } + shader.MainFunctionBody() << ";\n" + << "}\n"; + + return Status::OK(); +} + +Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int output_count, const Tensor* Q, + const Tensor* K, const Tensor* past_key, const Tensor* attention_bias, Tensor* probs, Tensor* present_key, + WebgpuAttentionParameters& parameters, int past_sequence_length, int total_sequence_length, + const Tensor* seqlen_k) { + const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size_)) + : parameters.scale_; + + const bool feed_past_key = present_key != nullptr && past_key != nullptr && past_key->SizeInBytes() > 0 && !parameters.past_present_share_buffer_; + const bool has_present_key = output_count > 1 && past_key; + const bool has_attention_bias = attention_bias != nullptr; + constexpr int tile_size = 12; + const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 
2 : 1); + + AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size, + components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_}; + program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, + {K, ProgramTensorMetadataDependency::TypeAndRank, components}}); + if (feed_past_key) { + program.AddInput({past_key, ProgramTensorMetadataDependency::TypeAndRank, components}); + } + if (has_attention_bias) { + program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank}); + } + if (seqlen_k != nullptr) { + program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}); + } + program.AddOutputs({{probs, ProgramTensorMetadataDependency::Rank}}); + if (has_present_key) { + program.AddOutput({present_key, ProgramTensorMetadataDependency::Rank, components}); + } + + const uint32_t vectorized_head_size = (parameters.head_size_ + components - 1) / components; + program.SetDispatchGroupSize((total_sequence_length + tile_size - 1) / tile_size, + (parameters.sequence_length_ + tile_size - 1) / tile_size, + parameters.batch_size_ * parameters.num_heads_) + .SetWorkgroupSize(tile_size, tile_size) + .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr, components, parameters.is_first_prompt_) + .AddUniformVariables({{static_cast(parameters.sequence_length_)}, + {static_cast(vectorized_head_size)}, + {static_cast(total_sequence_length)}, + {static_cast(parameters.num_heads_)}, + {static_cast(parameters.head_size_)}, + {static_cast(alpha)}, + {static_cast(past_sequence_length)}, + {static_cast(parameters.kv_sequence_length_)}, + {static_cast(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)}, + {static_cast(parameters.n_reps)}}) + .SetOverridableConstants({{static_cast(tile_size)}}); + + return context.RunProgram(program); +} + +Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { + if (seqlen_k_) { + shader.AddInput("seqlen_k", ShaderUsage::UseUniform); + } + shader.AddOutput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AdditionalImplementation() << "var thread_max: array;\n" + << "var thread_sum: array;\n" + << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32")) << ";\n"; + shader.MainFunctionBody() << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n" + << "let head_idx = workgroup_id.z % uniforms.num_heads;\n" + << "let sequence_length = uniforms.sequence_length;\n" + << "var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << ";\n"; + std::ostringstream oss; + InitVarStub(oss, seqlen_k_, is_first_prompt_); + shader.MainFunctionBody() << oss.str() + << "let local_offset = local_idx * uniforms.elements_per_thread;\n" + << "let offset = (global_idx / " << work_group_size_ << ") * uniforms.total_sequence_length_comp + local_offset;\n" + << "let seq_causal_length = " << (seqlen_k_ ? 
"past_sequence_length + workgroup_id.y + 1" : "uniforms.total_sequence_length_comp") << ";\n" + << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n" + << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" + << " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n" + << "}\n" + << "thread_max[local_idx] = " << (components_ == 4 ? "max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))" : (components_ == 2 ? "max(thread_max_vector.x, thread_max_vector.y)" : "thread_max_vector")) << ";\n" + << "workgroupBarrier();\n" + << "var max_value = f32(-3.402823e+38f);\n" + << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" + << " max_value = max(thread_max[i], max_value);\n" + << "}\n" + << "var sum_vector = f32_val_t(0);\n" + << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" + << " sum_vector += exp(f32_val_t(x[offset + i]) - max_value);\n" + << "}\n" + << "thread_sum[local_idx] = " << (components_ == 4 ? "sum_vector.x + sum_vector.y + sum_vector.z + sum_vector.w" : (components_ == 2 ? "sum_vector.x + sum_vector.y" : "sum_vector")) << ";\n" + << "workgroupBarrier();\n" + << "var sum: f32 = 0;\n" + << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" + << " sum += thread_sum[i]\n;" + << "}\n" + << "if (sum == 0) {\n" + << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" + << " x[offset + i] = x_value_t(x_element_t(1.0)/x_element_t(seq_causal_length));\n" + << " }\n" + << "} else {\n" + << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" + << " var f32input = f32_val_t(x[offset + i]);\n" + << " x[offset + i] = x_value_t(exp(f32input - max_value) / sum);\n" + << " }\n" + << "}\n"; + if (seqlen_k_) { + shader.MainFunctionBody() << "for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length_comp; total_seq_id++) {\n" + << " x[offset + total_seq_id] = x_value_t(x_element_t(0));\n" + << "}\n"; + } + + return Status::OK(); +} + +Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tensor* probs, int32_t batch_size, int32_t num_heads, int32_t past_sequence_length, int32_t sequence_length, int32_t total_sequence_length, + const Tensor* seqlen_k, bool is_first_prompt) { + const int components = seqlen_k != nullptr ? 1 : (total_sequence_length % 4 == 0 ? 4 : (total_sequence_length % 2 == 0 ? 
2 : 1)); + int work_group_size = 64; + const int total_sequence_length_comp = (total_sequence_length + components - 1) / components; + if (total_sequence_length_comp < work_group_size) { + work_group_size = 32; + } + const int elementsPerThread = (total_sequence_length_comp + work_group_size - 1) / work_group_size; + + InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components, is_first_prompt, seqlen_k}; + if (seqlen_k != nullptr) { + program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}); + } + program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}}) + .CacheHint(work_group_size, is_first_prompt) + .SetDispatchGroupSize(1, sequence_length, batch_size * num_heads) + .SetWorkgroupSize(work_group_size) + .AddUniformVariables({{static_cast(batch_size)}, + {static_cast(num_heads)}, + {static_cast(past_sequence_length)}, + {static_cast(sequence_length)}, + {static_cast(total_sequence_length_comp)}, + {static_cast(elementsPerThread)}}); + + return context.RunProgram(program); +} + +Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("probs", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddInput("v", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + if (feed_past_value_) { + shader.AddInput("past_value", ShaderUsage::UseUniform); + } + if (seqlen_k_) { + shader.AddInput("seqlen_k", ShaderUsage::UseUniform); + } + shader.AddOutput("output", ShaderUsage::UseUniform); + if (has_present_value_) { + shader.AddOutput("present_value", ShaderUsage::UseUniform); + } + + shader.AdditionalImplementation() << "var tileQ: array;\n" + << "var tileK: array;\n"; + shader.MainFunctionBody() << "let head_idx = workgroup_id.z % uniforms.num_heads;\n" + << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n" + << "let m = global_id.y;\n" + << "let n = global_id.x;\n" + << "let offsetA = workgroup_id.z * (uniforms.M * uniforms.K) + m * uniforms.K;\n" + << "let sequence_length = uniforms.M;\n" + << "var total_sequence_length = uniforms.K;\n"; + std::ostringstream oss; + InitVarStub(oss, seqlen_k_, is_first_prompt_); + shader.MainFunctionBody() << oss.str(); + if (n_reps_ > 1) { + shader.MainFunctionBody() << "let kv_head_idx = head_idx / uniforms.n_reps;\n" + << "let kv_num_heads = uniforms.num_heads / uniforms.n_reps;\n" + << "let abs_kv_head_idx = batch_idx * kv_num_heads + kv_head_idx;\n" + << "let vOffset = abs_kv_head_idx * uniforms.N * uniforms.kv_sequence_length + n;\n"; + if (feed_past_value_ && has_present_value_) { + shader.MainFunctionBody() << "let pastValueOffset = abs_kv_head_idx * uniforms.N * uniforms.past_sequence_length + n;\n"; + } else if (past_present_share_buffer_) { + shader.MainFunctionBody() << "let pastValueOffset = abs_kv_head_idx * uniforms.N * uniforms.present_sequence_length + n;\n"; + } + + if (has_present_value_) { + shader.MainFunctionBody() << "let presentValueOffset = abs_kv_head_idx * uniforms.N * uniforms.present_sequence_length + n;\n"; + } + } else { + shader.MainFunctionBody() << "let vOffset = workgroup_id.z * uniforms.N * uniforms.kv_sequence_length + n;\n"; + if (feed_past_value_ && has_present_value_) { + shader.MainFunctionBody() << "let pastValueOffset = workgroup_id.z * uniforms.N * uniforms.past_sequence_length + n;\n"; + } else if (past_present_share_buffer_) { + shader.MainFunctionBody() << "let pastValueOffset = workgroup_id.z * uniforms.N * uniforms.present_sequence_length + n;\n"; 
+ } + + if (has_present_value_) { + shader.MainFunctionBody() << "let presentValueOffset = workgroup_id.z * uniforms.N * uniforms.present_sequence_length + n;\n"; + } + } + + shader.MainFunctionBody() << "var value = probs_element_t(0);\n" + << "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" + << " if (m < uniforms.M && w + local_id.x < uniforms.K) {\n" + << " tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];\n" + << " }\n" + << " if (n < uniforms.N && w + local_id.y < uniforms.K) {\n" + << " var idx = TILE_SIZE * local_id.y + local_id.x;\n"; + + if ((feed_past_value_ && has_present_value_) || past_present_share_buffer_) { + shader.MainFunctionBody() << " if (w + local_id.y < past_sequence_length) {\n" + << " tileK[idx] = " << (past_present_share_buffer_ ? "present_value" : "past_value") << "[pastValueOffset + (w + local_id.y) * uniforms.N];\n" + << " } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n" + << " tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N];\n" + << " }\n"; + } else { + shader.MainFunctionBody() << " if (w + local_id.y < uniforms.kv_sequence_length) {\n" + << " tileK[idx] = v[vOffset + (w + local_id.y) * uniforms.N];\n" + << " }\n"; + } + + if (has_present_value_) { + if (past_present_share_buffer_) { + shader.MainFunctionBody() << " if (w + local_id.y >= past_sequence_length && w + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n"; + } else { + shader.MainFunctionBody() << " if (w + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n"; + } + shader.MainFunctionBody() << " present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n" + << " }\n"; + } + + shader.MainFunctionBody() << " }\n" + << " workgroupBarrier();\n" + << " for (var k: u32 = 0u; k < TILE_SIZE && w+k < total_sequence_length; k++) {\n" + << " value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];\n" + << " }\n" + << " workgroupBarrier();\n" + << "}\n"; + + shader.MainFunctionBody() << "// we need to transpose output from BNSH_v to BSND_v\n" + << "if (m < uniforms.M && n < uniforms.N) {\n" + << " let outputIdx = batch_idx * uniforms.M * uniforms.v_hidden_size + " + << " m * uniforms.v_hidden_size + head_idx * uniforms.N + n;\n" + << " output[outputIdx] = value;\n" + << "}\n"; + + return Status::OK(); +} + +Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int output_count, + const Tensor* probs, + const Tensor* V, + const Tensor* past_value, + Tensor* output, + Tensor* present_value, + WebgpuAttentionParameters& parameters, + int past_sequence_length, + int total_sequence_length, + const Tensor* seqlen_k) { + const bool feed_past_value = present_value != nullptr && past_value != nullptr && past_value->SizeInBytes() > 0 && !parameters.past_present_share_buffer_; + const bool has_present_value = output_count > 1 && past_value != nullptr; + const int tile_size = 12; + + VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_}; + program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank}, + {V, ProgramTensorMetadataDependency::TypeAndRank}}); + if (feed_past_value) { + program.AddInput({past_value, ProgramTensorMetadataDependency::TypeAndRank}); + } + if (seqlen_k != nullptr) { + program.AddInput({seqlen_k, 
ProgramTensorMetadataDependency::TypeAndRank}); + } + program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}}); + if (has_present_value) { + program.AddOutput({present_value, ProgramTensorMetadataDependency::TypeAndRank}); + } + + program.SetDispatchGroupSize((parameters.v_head_size_ + tile_size - 1) / tile_size, + (parameters.sequence_length_ + tile_size - 1) / tile_size, + parameters.batch_size_ * parameters.num_heads_) + .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_value, has_present_value, seqlen_k != nullptr, parameters.is_first_prompt_) + .SetWorkgroupSize(tile_size, tile_size) + .AddUniformVariables({{static_cast(parameters.sequence_length_)}, + {static_cast(total_sequence_length)}, + {static_cast(parameters.v_head_size_)}, + {static_cast(parameters.num_heads_)}, + {static_cast(parameters.head_size_)}, + {static_cast(parameters.v_hidden_size_)}, + {static_cast(past_sequence_length)}, + {static_cast(parameters.kv_sequence_length_)}, + {static_cast(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)}, + {static_cast(parameters.n_reps)}}) + .SetOverridableConstants({{static_cast(tile_size)}}); + + return context.RunProgram(program); +} + +Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, + const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, + WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) { + const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)}); + const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0; + const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length_; + + const TensorShapeVector probs_dims({parameters.batch_size_, parameters.num_heads_, + parameters.sequence_length_, total_sequence_length}); + const TensorShape probs_shape(probs_dims); + Tensor probs = context.CreateGPUTensor(Q->DataType(), probs_shape); + ORT_RETURN_IF_ERROR(ComputeAttentionProbs(context, output_count, Q, K, past_key, attention_bias, &probs, present_key, + parameters, past_sequence_length, total_sequence_length, seqlen_k)); + + ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs, + parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_)); + + ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value, + parameters, past_sequence_length, total_sequence_length, seqlen_k)); + + return Status::OK(); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h new file mode 100644 index 0000000000000..03279fffbc3ef --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h @@ -0,0 +1,123 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
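For reference, the three programs dispatched by ApplyAttention (AttentionProbsProgram, InPlaceSoftmaxProgram, VxAttentionScoreProgram) compute the same math as the following scalar, single-head sketch. It is illustrative only, not the WebGPU implementation; it omits past/present KV, GQA head sharing, attention bias, and vectorization.

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Hedged reference sketch of the probs -> in-place softmax -> V pipeline for one head.
void NaiveSingleHeadAttention(const float* Q,   // [seq_len, head_size]
                              const float* K,   // [total_seq_len, head_size]
                              const float* V,   // [total_seq_len, v_head_size]
                              float* output,    // [seq_len, v_head_size]
                              int seq_len, int total_seq_len,
                              int head_size, int v_head_size, float scale) {
  const float alpha = (scale == 0.0f) ? 1.0f / std::sqrt(static_cast<float>(head_size)) : scale;
  std::vector<float> probs(static_cast<size_t>(seq_len) * total_seq_len);

  for (int m = 0; m < seq_len; ++m) {
    float* row = probs.data() + static_cast<size_t>(m) * total_seq_len;
    // Stage 1: scaled Q*K^T (AttentionProbsProgram).
    for (int n = 0; n < total_seq_len; ++n) {
      float dot = 0.0f;
      for (int k = 0; k < head_size; ++k) {
        dot += Q[m * head_size + k] * K[n * head_size + k];
      }
      row[n] = dot * alpha;
    }
    // Stage 2: numerically stable row-wise softmax (InPlaceSoftmaxProgram), including
    // the uniform 1/length fallback the shader uses when the row sum is zero.
    float max_value = -std::numeric_limits<float>::infinity();
    for (int n = 0; n < total_seq_len; ++n) max_value = std::max(max_value, row[n]);
    float sum = 0.0f;
    for (int n = 0; n < total_seq_len; ++n) sum += std::exp(row[n] - max_value);
    for (int n = 0; n < total_seq_len; ++n) {
      row[n] = (sum == 0.0f) ? 1.0f / total_seq_len : std::exp(row[n] - max_value) / sum;
    }
    // Stage 3: probs x V (VxAttentionScoreProgram); the real shader additionally
    // writes the result transposed from BNSH to BSND.
    for (int h = 0; h < v_head_size; ++h) {
      float acc = 0.0f;
      for (int n = 0; n < total_seq_len; ++n) {
        acc += row[n] * V[n * v_head_size + h];
      }
      output[m * v_head_size + h] = acc;
    }
  }
}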
+
+#pragma once
+
+#include "core/providers/webgpu/compute_context.h"
+#include "core/providers/webgpu/program.h"
+#include "core/providers/webgpu/shader_helper.h"
+#include "core/providers/webgpu/webgpu_kernel.h"
+#include "contrib_ops/webgpu/bert/attention_common.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace webgpu {
+
+using namespace onnxruntime::webgpu;
+
+class TransferBSDToBNSHProgram final : public Program<TransferBSDToBNSHProgram> {
+ public:
+  TransferBSDToBNSHProgram(bool has_bias) : Program{"TransferBSDToBNSH"}, has_bias_(has_bias) {}
+
+  Status GenerateShaderCode(ShaderHelper& sh) const override;
+
+  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32},
+                                          {"batch_offset", ProgramUniformVariableDataType::Uint32},
+                                          {"sequence_offset", ProgramUniformVariableDataType::Uint32},
+                                          {"head_offset", ProgramUniformVariableDataType::Uint32},
+                                          {"bias_offset", ProgramUniformVariableDataType::Uint32});
+
+ private:
+  bool has_bias_;
+};
+
+class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
+ public:
+  AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key,
+                        bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
+      : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
+  }
+
+  Status GenerateShaderCode(ShaderHelper& sh) const override;
+
+  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32},
+                                          {"K", ProgramUniformVariableDataType::Uint32},
+                                          {"N", ProgramUniformVariableDataType::Uint32},
+                                          {"num_heads", ProgramUniformVariableDataType::Uint32},
+                                          {"head_size", ProgramUniformVariableDataType::Uint32},
+                                          {"alpha", ProgramUniformVariableDataType::Float32},
+                                          {"past_sequence_length", ProgramUniformVariableDataType::Uint32},
+                                          {"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
+                                          {"present_sequence_length", ProgramUniformVariableDataType::Uint32},
+                                          {"n_reps", ProgramUniformVariableDataType::Uint32});
+
+  WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});
+
+ private:
+  bool feed_past_key_;
+  bool has_present_key_;
+  bool has_attention_bias_;
+  int tile_size_;
+  int components_;
+  int n_reps_;
+  const Tensor* seqlen_k_;
+  bool past_present_share_buffer_;
+  bool is_first_prompt_;
+};
+
+class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
+ public:
+  InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components, bool is_first_prompt, const Tensor* seqlen_k = nullptr)
+      : Program{kernel_name}, work_group_size_(work_group_size), components_(components), seqlen_k_(seqlen_k), is_first_prompt_(is_first_prompt) {
+  }
+
+  Status GenerateShaderCode(ShaderHelper& sh) const override;
+
+  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"batch_size", ProgramUniformVariableDataType::Uint32},
+                                          {"num_heads", ProgramUniformVariableDataType::Uint32},
+                                          {"past_sequence_length", ProgramUniformVariableDataType::Uint32},
+                                          {"sequence_length", ProgramUniformVariableDataType::Uint32},
+                                          {"total_sequence_length_comp", ProgramUniformVariableDataType::Uint32},
+                                          {"elements_per_thread", ProgramUniformVariableDataType::Uint32});
+
+ private:
+  int work_group_size_;
+  int components_;
+  const Tensor* seqlen_k_;
+  bool is_first_prompt_;
+};
+
+class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
+ public:
+  VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
+      : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
+  }
+
+  Status GenerateShaderCode(ShaderHelper& sh) const override;
+
+  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32},
+                                          {"K", ProgramUniformVariableDataType::Uint32},
+                                          {"N", ProgramUniformVariableDataType::Uint32},
+                                          {"num_heads", ProgramUniformVariableDataType::Uint32},
+                                          {"head_size", ProgramUniformVariableDataType::Uint32},
+                                          {"v_hidden_size", ProgramUniformVariableDataType::Uint32},
+                                          {"past_sequence_length", ProgramUniformVariableDataType::Uint32},
+                                          {"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
+                                          {"present_sequence_length", ProgramUniformVariableDataType::Uint32},
+                                          {"n_reps", ProgramUniformVariableDataType::Uint32});
+
+  WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});
+
+ private:
+  bool feed_past_value_;
+  bool has_present_value_;
+  int tile_size_;
+  int n_reps_;
+  const Tensor* seqlen_k_;
+  bool past_present_share_buffer_;
+  bool is_first_prompt_;
+};
+
+}  // namespace webgpu
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h
new file mode 100644
index 0000000000000..b7137ef0aec3a
--- /dev/null
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h
@@ -0,0 +1,130 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
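When n_reps > 1, the AttentionProbsProgram and VxAttentionScoreProgram shaders above map each query head onto the KV head it shares in grouped-query attention. A small sketch of the same index math (type and function names are illustrative, not part of the change):

// Hedged sketch mirroring the WGSL: kv_head_idx = head_idx / n_reps, with
// n_reps = num_heads / kv_num_heads; assumes num_heads is divisible by kv_num_heads,
// as GQA requires.
struct KvHeadMapping {
  int kv_head_idx;      // which KV head this query head reads
  int abs_kv_head_idx;  // flattened (batch, kv_head) index used for buffer offsets
};

inline KvHeadMapping MapQueryHeadToKvHead(int batch_idx, int head_idx,
                                          int num_heads, int kv_num_heads) {
  const int n_reps = num_heads / kv_num_heads;  // uniforms.n_reps
  const int kv_head_idx = head_idx / n_reps;    // head_idx / uniforms.n_reps
  return {kv_head_idx, batch_idx * kv_num_heads + kv_head_idx};
}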
+ +#pragma once + +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "contrib_ops/webgpu/bert/attention_common.h" + +#include "contrib_ops/cpu/bert/attention_common.h" +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +struct WebgpuAttentionParameters { + WebgpuAttentionParameters(AttentionParameters parameters) : is_gqa_(false), + batch_size_(parameters.batch_size), + sequence_length_(parameters.sequence_length), + kv_sequence_length_(parameters.kv_sequence_length), + past_sequence_length_(parameters.past_sequence_length), + total_sequence_length_(parameters.total_sequence_length), + max_sequence_length_(parameters.max_sequence_length), + input_hidden_size_(parameters.input_hidden_size), + hidden_size_(parameters.hidden_size), + head_size_(parameters.head_size), + v_hidden_size_(parameters.v_hidden_size), + v_head_size_(parameters.v_head_size), + num_heads_(parameters.num_heads), + is_unidirectional_(parameters.is_unidirectional), + past_present_share_buffer_(parameters.past_present_share_buffer), + do_rotary_(parameters.do_rotary), + broadcast_attn_bias_dim_0_(parameters.broadcast_attn_bias_dim_0), + broadcast_attn_bias_dim_1_(parameters.broadcast_attn_bias_dim_1), + mask_filter_value_(parameters.mask_filter_value), + scale_(parameters.scale), + mask_type_(parameters.mask_type), + qkv_format_(parameters.qkv_format) { + } + + WebgpuAttentionParameters(onnxruntime::contrib::GroupQueryAttentionParameters parameters) : is_gqa_(true), + batch_size_(parameters.batch_size), + sequence_length_(parameters.sequence_length), + kv_sequence_length_(parameters.sequence_length), + past_sequence_length_(parameters.seqlen_past_kv_cache), + total_sequence_length_(parameters.total_sequence_length), + hidden_size_(parameters.hidden_size), + head_size_(parameters.head_size), + v_hidden_size_(parameters.kv_hidden_size), + v_head_size_(parameters.kv_hidden_size / parameters.kv_num_heads), + num_heads_(parameters.num_heads), + do_rotary_(parameters.do_rotary), + scale_(parameters.scale), + seqlen_past_kv_cache_(parameters.seqlen_past_kv_cache), + seqlen_present_kv_cache_(parameters.seqlen_present_kv_cache), + kv_hidden_size_(parameters.kv_hidden_size), + kv_num_heads_(parameters.kv_num_heads), + num_splits_(parameters.num_splits), + rotary_dim_(parameters.rotary_dim), + is_packed_qkv_(parameters.is_packed_qkv), + is_subsequent_prompt_(parameters.is_subsequent_prompt), + is_first_prompt_(parameters.is_first_prompt), + rotary_interleaved_(parameters.rotary_interleaved), + use_smooth_softmax_(parameters.use_smooth_softmax), + softcap_(parameters.softcap), + zeros_count_(parameters.zeros_count), + zero_ptr_(parameters.zero_ptr), + n_reps(parameters.num_heads / parameters.kv_num_heads), + qkv_format_(parameters.qkv_format) { + } + + bool is_gqa_; + int batch_size_ = 0; + int sequence_length_ = 0; + int kv_sequence_length_ = 0; // input sequence length of K or V + int past_sequence_length_ = 0; // sequence length in past state of K or V + int total_sequence_length_ = 0; // total sequence length of K or V + int max_sequence_length_ = 0; // max sequence length from 4D mask + int input_hidden_size_ = 0; // first dimension of weights for input projection + int hidden_size_ = 0; // hidden size of Q or K + int head_size_ = 0; // hidden size per head of Q or K + int v_hidden_size_ = 0; // hidden size of V + int v_head_size_ = 0; // hidden size per head 
of V + int num_heads_ = 0; + int rotary_embedding_ = 0; + bool is_unidirectional_ = false; + bool past_present_share_buffer_ = false; + bool do_rotary_ = false; + bool broadcast_attn_bias_dim_0_ = false; + bool broadcast_attn_bias_dim_1_ = false; + float mask_filter_value_ = -10000.0f; + float scale_ = 0.0f; + bool use_tf32_ = false; + ; + // The following members are in onnxruntime::contrib::GroupQueryAttentionParameters + // and not in onnxruntime::contrib::AttentionParameters + int seqlen_past_kv_cache_ = 0; // sequence length of past kv tensor + int seqlen_present_kv_cache_ = 0; // sequence length of present kv tensor + int kv_hidden_size_ = 0; + int kv_num_heads_ = 0; + int num_splits_ = 0; // number of splits for splitkv + int rotary_dim_ = 0; // rotary embedding dimension + int local_window_size_ = 0; + bool kv_share_buffer_ = false; + bool is_packed_qkv_ = false; + bool is_subsequent_prompt_ = false; // indicates whether we have past context and seqlen > 1 + bool is_first_prompt_ = false; // indicates whether this is first decoding step + bool rotary_interleaved_ = false; + bool use_smooth_softmax_ = false; + float softcap_ = 0.0; + int zeros_count_ = 0; + ; + int* zero_ptr_ = nullptr; + // Computed values + int n_reps = 1; + AttentionMaskType mask_type_ = MASK_NONE; + AttentionQkvFormat qkv_format_ = UNKNOWN; +}; + +Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_heads, int sequence_length, + int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor); + +Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, + const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, + WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr); + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc new file mode 100644 index 0000000000000..31c8af9b4f922 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -0,0 +1,107 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
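In the GQA path, seqlens_k drives the per-batch lengths that InitVarStub and InPlaceSoftmaxProgram bake into the generated WGSL. A hedged sketch of the same arithmetic on the host side (struct and function names are illustrative; int32 elements are assumed, matching the u32(seqlen_k[batch_idx]) cast in the shader):

#include <cstdint>

// seqlen_k[b] is interpreted as total_sequence_length - 1 for batch entry b, and
// past_sequence_length is zero only on the first prompt.
struct GqaLengths {
  int total_sequence_length;  // u32(seqlen_k[batch_idx]) + 1 in the shader
  int past_sequence_length;   // 0 for the first prompt, otherwise total minus new tokens
};

inline GqaLengths ComputeGqaLengths(const int32_t* seqlens_k, int batch_idx,
                                    int sequence_length, bool is_first_prompt) {
  GqaLengths lengths;
  lengths.total_sequence_length = static_cast<int>(seqlens_k[batch_idx]) + 1;
  lengths.past_sequence_length =
      is_first_prompt ? 0 : lengths.total_sequence_length - sequence_length;
  return lengths;
}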
+ +#include "contrib_ops/cpu/bert/group_query_attention_helper.h" +#include "contrib_ops/webgpu/bert/attention_common.h" +#include "contrib_ops/webgpu/bert/group_query_attention.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +#include "core/providers/webgpu/webgpu_supported_types.h" + +using namespace onnxruntime::webgpu; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::contrib::group_query_attention_helper; + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + GroupQueryAttention, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()) + .MayInplace(3, 1) + .MayInplace(4, 2) + .InputMemoryType(OrtMemTypeCPUInput, 6), + GroupQueryAttention); + +Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const Tensor* query = context.Input(0); + const Tensor* key = context.Input(1); + const Tensor* value = context.Input(2); + const Tensor* past_key = context.Input(3); + const Tensor* past_value = context.Input(4); + const Tensor* seqlen_k = context.Input(5); + const Tensor* total_seqlen_tensor = context.Input(6); + const Tensor* cos_cache = context.Input(7); + const Tensor* sin_cache = context.Input(8); + + GroupQueryAttentionParameters params; + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, + key, + value, + past_key, + past_value, + cos_cache, + sin_cache, + ¶ms, + num_heads_, + kv_num_heads_, + seqlen_k, + total_seqlen_tensor, + scale_, + softcap_)); + WebgpuAttentionParameters parameters(params); + if (parameters.is_packed_qkv_) { + ORT_NOT_IMPLEMENTED("Packed QKV of shape (B, L, N, 3, H) not implemented for webgpu-ep."); + } + TensorShapeVector output_shape(3); + output_shape[0] = static_cast(parameters.batch_size_); + output_shape[1] = static_cast(parameters.sequence_length_); + output_shape[2] = static_cast(parameters.hidden_size_); + Tensor* output = context.Output(0, output_shape); + std::vector present_dims{ + parameters.batch_size_, + kv_num_heads_, + parameters.seqlen_present_kv_cache_, + parameters.head_size_}; + std::vector present_kv_shape(present_dims); + Tensor* present_key = context.Output(1, present_kv_shape); + Tensor* present_value = context.Output(2, present_kv_shape); + parameters.past_present_share_buffer_ = present_key != nullptr && present_value != nullptr && past_key != nullptr && past_value != nullptr && past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw(); + + TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_, + parameters.sequence_length_, parameters.head_size_}); + TensorShape q_new_shape(q_new_dims); + Tensor Q = context.CreateGPUTensor(query->DataType(), q_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH( + context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, nullptr, 0, &Q)); + if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format + return ApplyAttention(&Q, key, value, nullptr, past_key, past_value, output, present_key, + present_value, parameters, context, seqlen_k); + } + + TensorShapeVector k_new_dims({parameters.batch_size_, parameters.kv_num_heads_, + parameters.kv_sequence_length_, parameters.head_size_}); + TensorShape k_new_shape(k_new_dims); + Tensor K = context.CreateGPUTensor(key->DataType(), k_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, 
parameters.kv_num_heads_, parameters.kv_sequence_length_, + parameters.head_size_, key, nullptr, 0, &K)); + + TensorShapeVector v_new_dims({parameters.batch_size_, parameters.kv_num_heads_, + parameters.kv_sequence_length_, parameters.v_head_size_}); + TensorShape v_new_shape(v_new_dims); + Tensor V = context.CreateGPUTensor(value->DataType(), v_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, + parameters.v_head_size_, value, nullptr, 0, &V)); + return ApplyAttention(&Q, &K, &V, nullptr, past_key, past_value, output, present_key, + present_value, parameters, context, seqlen_k); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h new file mode 100644 index 0000000000000..04969dc778927 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; + +class GroupQueryAttention final : public WebGpuKernel { + public: + GroupQueryAttention(const OpKernelInfo& info) : WebGpuKernel(info) { + int64_t num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + num_heads_ = static_cast(num_heads); + + int64_t kv_num_heads = 0; + ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0); + kv_num_heads_ = static_cast(kv_num_heads); + + scale_ = info.GetAttrOrDefault("scale", 0.0f); + softcap_ = info.GetAttrOrDefault("softcap", 0.0f); + + do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; + rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; + + use_smooth_softmax_ = info.GetAttrOrDefault("smooth_softmax", 0) == 1; + + local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); + } + + int num_heads_; // number of attention heads of Q + int kv_num_heads_; // number of attention heads of K or V + float scale_; // the scaling factor applied before softmax + float softcap_; + bool do_rotary_; // whether or not to use rotary embeddings + bool rotary_interleaved_; + int local_window_size_; + + bool use_smooth_softmax_; + Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc index 5583f296fae42..424556c66bd9d 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. 
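GroupQueryAttention::ComputeInternal above performs two checks before dispatching: it shapes the present KV cache outputs and detects whether the past and present caches alias one buffer, which enables the in-place KV-cache path. The same logic, pulled out as a standalone sketch (helper names are illustrative, not part of the change):

#include <cstdint>
#include <vector>

// Matches present_dims: {batch, kv_num_heads, seqlen_present_kv_cache, head_size}.
std::vector<int64_t> PresentKvDims(const WebgpuAttentionParameters& p, int kv_num_heads) {
  return {p.batch_size_, kv_num_heads, p.seqlen_present_kv_cache_, p.head_size_};
}

// Shared-buffer mode is detected purely by pointer identity of the underlying data.
bool PastPresentShareBuffer(const Tensor* past_key, const Tensor* past_value,
                            const Tensor* present_key, const Tensor* present_value) {
  return present_key != nullptr && present_value != nullptr &&
         past_key != nullptr && past_value != nullptr &&
         past_key->DataRaw() == present_key->DataRaw() &&
         past_value->DataRaw() == present_value->DataRaw();
}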
#include "contrib_ops/cpu/bert/multihead_attention_helper.h" +#include "contrib_ops/webgpu/bert/attention_common.h" #include "contrib_ops/webgpu/bert/multihead_attention.h" #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" @@ -25,392 +26,8 @@ ONNX_OPERATOR_KERNEL_EX( .TypeConstraint("T", WebGpuSupportedFloatTypes()), MultiHeadAttention); -Status TransferBSDToBNSHProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("qkv_input", ShaderUsage::UseUniform); - const auto& qkv_output = shader.AddOutput("qkv_output", ShaderUsage::UseUniform | ShaderUsage::UseOffsetToIndices); - - if (has_bias_) { - shader.AddInput("bias", ShaderUsage::UseUniform); - } - - shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size") - << "let output_indices = " << qkv_output.OffsetToIndices("global_idx") << ";\n" - << "let input_offset_idx = output_indices[0] * uniforms.batch_offset + output_indices[1] *" - << " uniforms.head_offset + output_indices[2] * uniforms.sequence_offset + output_indices[3];\n"; - if (has_bias_) { - shader.MainFunctionBody() << "let bias_offset_idx = (input_offset_idx % uniforms.sequence_offset) + uniforms.bias_offset;\n"; - } - shader.MainFunctionBody() << "qkv_output[global_idx] = qkv_input[input_offset_idx]"; - if (has_bias_) { - shader.MainFunctionBody() << " + bias[bias_offset_idx];\n"; - } else { - shader.MainFunctionBody() << ";\n"; - } - - return Status::OK(); -} - -Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_heads, int sequence_length, - int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor) { - assert(input_tensor->Shape().GetDims().size() == 3); - assert(output_tensor->Shape().GetDims().size() == 4); - - uint32_t data_size = gsl::narrow(output_tensor->Shape().Size()); - const int batch_offset = num_heads * sequence_length * head_size; - const int sequence_offset = num_heads * head_size; - const int head_offset = head_size; - bool has_bias = bias != nullptr; - - TransferBSDToBNSHProgram program{has_bias}; - program.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) - .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) - .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) - .AddUniformVariables({{data_size}, - {static_cast(batch_offset)}, - {static_cast(sequence_offset)}, - {static_cast(head_offset)}, - {static_cast(bias_offset)}}); - - if (has_bias) { - program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank}); - } - - return context.RunProgram(program); -}; - -Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - if (feed_past_key_) { - shader.AddInput("past_key", ShaderUsage::UseUniform); - } - if (has_attention_bias_) { - shader.AddInput("attention_bias", ShaderUsage::UseUniform); - } - - shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - if (has_present_key_) { - shader.AddOutput("present_key", ShaderUsage::UseUniform); - } - - shader.AdditionalImplementation() << "var tileQ: array;\n" - << "var tileK: array;\n" - << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? 
"vec2" : "f32")) << ";\n"; - - shader.MainFunctionBody() << "// x holds the N and y holds the M\n" - "let headIdx = workgroup_id.z;\n" - "let m = workgroup_id.y * TILE_SIZE;\n" - "let n = workgroup_id.x * TILE_SIZE;\n" - "let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K;\n"; - - if (feed_past_key_ && has_present_key_) { - shader.MainFunctionBody() << "let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx;\n" - << "let pastKeyOffset = uniforms.past_sequence_length * uniforms.K * headIdx;\n"; - } else { - shader.MainFunctionBody() << "let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;\n"; - } - - if (has_present_key_) { - shader.MainFunctionBody() << "let presentKeyOffset = headIdx * uniforms.N * uniforms.K;\n"; - } - - shader.MainFunctionBody() << "var value = f32_val_t(0);\n" - "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" - " if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n" - " tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n" - " }\n" - " if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n" - " var idx = TILE_SIZE * local_id.y + local_id.x;\n"; - - if (feed_past_key_ && has_present_key_) { - shader.MainFunctionBody() << " if (n + local_id.y < uniforms.past_sequence_length) {\n" - " tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" - " } else {\n" - " tileK[idx] = key[kOffset + (n + local_id.y - uniforms.past_sequence_length) * uniforms.K + w + local_id.x];\n" - " }\n"; - } else { - shader.MainFunctionBody() << " tileK[idx] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];\n"; - } - - if (has_present_key_) { - shader.MainFunctionBody() << " present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];\n"; - } - - shader.MainFunctionBody() << " }\n" - << " workgroupBarrier();\n" - << " for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {\n" - << " value += f32_val_t(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]);\n" - << " }\n" - << " workgroupBarrier();\n" - << "}\n"; - - shader.MainFunctionBody() << "let headOffset = headIdx * uniforms.M * uniforms.N;\n" - << "if (global_id.y < uniforms.M && global_id.x < uniforms.N) {\n" - << " let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n" - << " var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? "value.x + value.y" : "value")) << ";\n"; - - shader.MainFunctionBody() << " output[outputIdx] = output_value_t(sum * uniforms.alpha)"; - if (has_attention_bias_) { - shader.MainFunctionBody() << " + attention_bias[outputIdx]"; - } - shader.MainFunctionBody() << ";\n" - << "}\n"; - - return Status::OK(); -} - -Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int output_count, const Tensor* Q, - const Tensor* K, const Tensor* past_key, const Tensor* attention_bias, Tensor* probs, Tensor* present_key, - AttentionParameters& parameters, int past_sequence_length, int total_sequence_length) { - const float alpha = parameters.scale == 0.0f ? 
1.f / sqrt(static_cast(parameters.head_size)) - : parameters.scale; - - const bool feed_past_key = present_key != nullptr && past_key != nullptr && past_key->SizeInBytes() > 0; - const bool has_present_key = output_count > 1 && past_key; - const bool has_attention_bias = attention_bias != nullptr; - constexpr int tile_size = 12; - const int components = parameters.head_size % 4 == 0 ? 4 : (parameters.head_size % 2 == 0 ? 2 : 1); - - AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size, - components}; - program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, - {K, ProgramTensorMetadataDependency::TypeAndRank, components}}); - if (feed_past_key) { - program.AddInput({past_key, ProgramTensorMetadataDependency::TypeAndRank, components}); - } - if (has_attention_bias) { - program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank}); - } - program.AddOutputs({{probs, ProgramTensorMetadataDependency::Rank}}); - if (has_present_key) { - program.AddOutput({present_key, ProgramTensorMetadataDependency::Rank, components}); - } - - const uint32_t vectorized_head_size = parameters.head_size / components; - program.SetDispatchGroupSize((total_sequence_length + tile_size - 1) / tile_size, - (parameters.sequence_length + tile_size - 1) / tile_size, - parameters.batch_size * parameters.num_heads) - .SetWorkgroupSize(tile_size, tile_size) - .CacheHint(std::to_string(tile_size)) - .AddUniformVariables({{static_cast(parameters.sequence_length)}, - {static_cast(vectorized_head_size)}, - {static_cast(total_sequence_length)}, - {static_cast(parameters.num_heads)}, - {static_cast(alpha)}, - {static_cast(past_sequence_length)}, - {static_cast(parameters.kv_sequence_length)}}) - .SetOverridableConstants({{static_cast(tile_size)}}); - - return context.RunProgram(program); -} - -Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddOutput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - shader.AdditionalImplementation() << "var thread_max: array;\n" - << "var thread_sum: array;\n" - << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32")) << ";\n"; - - shader.MainFunctionBody() << "let local_offset = local_idx * uniforms.elements_per_thread;\n" - << "let offset = (global_idx / " << work_group_size_ << ") * uniforms.d_comp + local_offset;\n" - << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n" - << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n" - << " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n" - << "}\n" - << "thread_max[local_idx] = " << (components_ == 4 ? "max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))" : (components_ == 2 ? "max(thread_max_vector.x, thread_max_vector.y)" : "thread_max_vector")) << ";\n" - << "workgroupBarrier();\n" - << "var max_value = f32(-3.402823e+38f);\n" - << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" - << " max_value = max(thread_max[i], max_value);\n" - << "}\n" - << "var sum_vector = f32_val_t(0);\n" - << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n" - << " sum_vector += exp(f32_val_t(x[offset + i]) - max_value);\n" - << "}\n" - << "thread_sum[local_idx] = " << (components_ == 4 ? 
"sum_vector.x + sum_vector.y + sum_vector.z + sum_vector.w" : (components_ == 2 ? "sum_vector.x + sum_vector.y" : "sum_vector")) << ";\n" - << "workgroupBarrier();\n" - << "var sum: f32 = 0;\n" - << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" - << " sum += thread_sum[i]\n;" - << "}\n" - << "if (sum == 0) {\n" - << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n" - << " x[offset + i] = x_value_t(x_element_t(uniforms.d_inv));\n" - << " }\n" - << "} else {\n" - << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n" - << " var f32input = f32_val_t(x[offset + i]);\n" - << " x[offset + i] = x_value_t(exp(f32input - max_value) / sum);\n" - << " }\n" - << "}\n"; - - return Status::OK(); -} - -Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tensor* probs, int n, int d) { - const int components = d % 4 == 0 ? 4 : (d % 2 == 0 ? 2 : 1); - int work_group_size = 64; - const int d_comp = d / components; - if (d_comp < work_group_size) { - work_group_size = 32; - } - const int elementsPerThread = (d_comp + work_group_size - 1) / work_group_size; - - InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components}; - program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}}) - .SetDispatchGroupSize(n) - .SetWorkgroupSize(work_group_size) - .AddUniformVariables({{static_cast(1.f / static_cast(d))}, - {static_cast(d_comp)}, - {static_cast(elementsPerThread)}}); - - return context.RunProgram(program); -} - -Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("probs", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - shader.AddInput("v", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - if (feed_past_value_) { - shader.AddInput("past_value", ShaderUsage::UseUniform); - } - - shader.AddOutput("output", ShaderUsage::UseUniform); - if (has_present_value_) { - shader.AddOutput("present_value", ShaderUsage::UseUniform); - } - - shader.AdditionalImplementation() << "var tileQ: array;\n" - << "var tileK: array;\n"; - - shader.MainFunctionBody() << "let headIdx = workgroup_id.z;\n" - << "let m = global_id.y;\n" - << "let n = global_id.x;\n" - << "let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K;\n"; - - if (feed_past_value_ && has_present_value_) { - shader.MainFunctionBody() << "let pastValueOffset = headIdx * uniforms.N * uniforms.past_sequence_length + n;\n" - << "let vOffset = headIdx * uniforms.N * uniforms.kv_sequence_length + n;\n"; - } else { - shader.MainFunctionBody() << "let offsetB = headIdx * uniforms.N * uniforms.K + n;\n"; - } - - if (has_present_value_) { - shader.MainFunctionBody() << "let presentValueOffset = headIdx * uniforms.N * uniforms.K + n;\n"; - } - - shader.MainFunctionBody() << "var value = probs_element_t(0);\n" - << "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" - << " if (m < uniforms.M && w + local_id.x < uniforms.K) {\n" - << " tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];\n" - << " }\n" - << " if (n < uniforms.N && w + local_id.y < uniforms.K) {\n" - << " var idx = TILE_SIZE * local_id.y + local_id.x;\n"; - - if (feed_past_value_ && has_present_value_) { - shader.MainFunctionBody() << " if (w + local_id.y < uniforms.past_sequence_length) {\n" - << " tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N];\n" - << " 
} else {\n" - << " tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N];\n" - << " }\n"; - } else { - shader.MainFunctionBody() << " tileK[idx] = v[offsetB + (w + local_id.y) * uniforms.N];\n"; - } - - if (has_present_value_) { - shader.MainFunctionBody() << " present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n"; - } - - shader.MainFunctionBody() << " }\n" - << " workgroupBarrier();\n" - << " for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {\n" - << " value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];\n" - << " }\n" - << " workgroupBarrier();\n" - << "}\n"; - - shader.MainFunctionBody() << "// we need to transpose output from BNSH_v to BSND_v\n" - << "let batchIdx = workgroup_id.z / uniforms.num_heads;\n" - << "let currentBatchHeadNumber = workgroup_id.z % uniforms.num_heads;\n" - << "if (m < uniforms.M && n < uniforms.N) {\n" - << " let outputIdx = batchIdx * uniforms.M * uniforms.v_hidden_size + " - << " m * uniforms.v_hidden_size + currentBatchHeadNumber * uniforms.N + n;\n" - << " output[outputIdx] = value;\n" - << "}\n"; - - return Status::OK(); -} - -Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int output_count, - const Tensor* probs, - const Tensor* V, - const Tensor* past_value, - Tensor* output, - Tensor* present_value, - AttentionParameters& parameters, - int past_sequence_length, - int total_sequence_length) { - const bool feed_past_value = present_value != nullptr && past_value != nullptr && past_value->SizeInBytes() > 0; - const bool has_present_value = output_count > 1 && past_value != nullptr; - constexpr int tile_size = 12; - - VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size}; - program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank}, - {V, ProgramTensorMetadataDependency::TypeAndRank}}); - if (feed_past_value) { - program.AddInput({past_value, ProgramTensorMetadataDependency::TypeAndRank}); - } - program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}}); - if (has_present_value) { - program.AddOutput({present_value, ProgramTensorMetadataDependency::TypeAndRank}); - } - - program.SetDispatchGroupSize((parameters.v_head_size + tile_size - 1) / tile_size, - (parameters.sequence_length + tile_size - 1) / tile_size, - parameters.batch_size * parameters.num_heads) - .SetWorkgroupSize(tile_size, tile_size) - .AddUniformVariables({{static_cast(parameters.sequence_length)}, - {static_cast(total_sequence_length)}, - {static_cast(parameters.v_head_size)}, - {static_cast(parameters.num_heads)}, - {static_cast(parameters.v_hidden_size)}, - {static_cast(past_sequence_length)}, - {static_cast(parameters.kv_sequence_length)}}) - .SetOverridableConstants({{static_cast(tile_size)}}); - ; - - return context.RunProgram(program); -} - -Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, - const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, - AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { - const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)}); - const int past_sequence_length = output_count > 1 ? 
parameters.past_sequence_length : 0; - const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length; - - const TensorShapeVector probs_dims({parameters.batch_size, parameters.num_heads, - parameters.sequence_length, total_sequence_length}); - const TensorShape probs_shape(probs_dims); - Tensor probs = context.CreateGPUTensor(Q->DataType(), probs_shape); - ORT_RETURN_IF_ERROR(ComputeAttentionProbs(context, output_count, Q, K, past_key, attention_bias, &probs, present_key, - parameters, past_sequence_length, total_sequence_length)); - - ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs, - parameters.batch_size * parameters.num_heads * parameters.sequence_length, total_sequence_length)); - - ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value, - parameters, past_sequence_length, total_sequence_length)); - - return Status::OK(); -} - MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) - : WebGpuKernel(info) { - int64_t num_heads = 0; - ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); - num_heads_ = static_cast(num_heads); - mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); - scale_ = info.GetAttrOrDefault("scale", 0.0f); - is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; + : WebGpuKernel(info), AttentionBase(info, false) { ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support webgpu kernel"); } @@ -434,54 +51,54 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& ORT_NOT_IMPLEMENTED("input `key_padding_mask` not implemented for webgpu"); } - AttentionParameters parameters; + AttentionParameters params; ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, key, value, - bias, key_padding_mask, attention_bias, past_key, past_value, nullptr, ¶meters, + bias, key_padding_mask, attention_bias, past_key, past_value, nullptr, ¶ms, num_heads_, mask_filter_value_, scale_, is_unidirectional_, false, kMultiHeadAttention, context.DeviceLimits().maxComputeInvocationsPerWorkgroup)); - + WebgpuAttentionParameters parameters(params); TensorShapeVector output_shape(3); - output_shape[0] = static_cast(parameters.batch_size); - output_shape[1] = static_cast(parameters.sequence_length); - output_shape[2] = static_cast(parameters.v_hidden_size); + output_shape[0] = static_cast(parameters.batch_size_); + output_shape[1] = static_cast(parameters.sequence_length_); + output_shape[2] = static_cast(parameters.v_hidden_size_); Tensor* output = context.Output(0, output_shape); // If optional outputs aren't needed, present_key and present_value will be null std::vector present_dims{ - parameters.batch_size, - parameters.num_heads, - parameters.total_sequence_length, - parameters.head_size, + parameters.batch_size_, + parameters.num_heads_, + parameters.total_sequence_length_, + parameters.head_size_, }; TensorShape present_shape(present_dims); Tensor* present_key = context.Output(1, present_shape); Tensor* present_value = context.Output(2, present_shape); - TensorShapeVector q_new_dims({parameters.batch_size, parameters.num_heads, - parameters.sequence_length, parameters.head_size}); + TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_, + parameters.sequence_length_, parameters.head_size_}); TensorShape q_new_shape(q_new_dims); Tensor Q = context.CreateGPUTensor(query->DataType(), q_new_shape); ORT_RETURN_IF_ERROR(TransferBSDToBNSH( - context, 
parameters.num_heads, parameters.sequence_length, parameters.head_size, query, bias, 0, &Q)); + context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, bias, 0, &Q)); - if (parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format + if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format return ApplyAttention(&Q, key, value, attention_bias, past_key, past_value, output, present_key, present_value, parameters, context); } - TensorShapeVector k_new_dims({parameters.batch_size, parameters.num_heads, - parameters.kv_sequence_length, parameters.head_size}); + TensorShapeVector k_new_dims({parameters.batch_size_, parameters.num_heads_, + parameters.kv_sequence_length_, parameters.head_size_}); TensorShape k_new_shape(k_new_dims); Tensor K = context.CreateGPUTensor(key->DataType(), k_new_shape); - ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads, parameters.kv_sequence_length, - parameters.head_size, key, bias, parameters.hidden_size, &K)); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.kv_sequence_length_, + parameters.head_size_, key, bias, parameters.hidden_size_, &K)); - TensorShapeVector v_new_dims({parameters.batch_size, parameters.num_heads, - parameters.kv_sequence_length, parameters.v_head_size}); + TensorShapeVector v_new_dims({parameters.batch_size_, parameters.num_heads_, + parameters.kv_sequence_length_, parameters.v_head_size_}); TensorShape v_new_shape(v_new_dims); Tensor V = context.CreateGPUTensor(value->DataType(), v_new_shape); - ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads, parameters.kv_sequence_length, - parameters.v_head_size, value, bias, 2 * parameters.hidden_size, &V)); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.kv_sequence_length_, + parameters.v_head_size_, value, bias, 2 * parameters.hidden_size_, &V)); // Compute the attention score and apply the score to V return ApplyAttention(&Q, &K, &V, attention_bias, past_key, past_value, output, present_key, diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h index 36803e3027b4c..d983236422c9e 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h @@ -7,6 +7,9 @@ #include "core/providers/webgpu/program.h" #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_kernel.h" +#include "contrib_ops/webgpu/bert/attention.h" + +#include "contrib_ops/cpu/bert/attention_base.h" namespace onnxruntime { namespace contrib { @@ -14,100 +17,10 @@ namespace webgpu { using namespace onnxruntime::webgpu; -class TransferBSDToBNSHProgram final : public Program { - public: - TransferBSDToBNSHProgram(bool has_bias) : Program{"TransferBSDToBNSH"}, has_bias_(has_bias) {} - - Status GenerateShaderCode(ShaderHelper& sh) const override; - - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32}, - {"batch_offset", ProgramUniformVariableDataType::Uint32}, - {"sequence_offset", ProgramUniformVariableDataType::Uint32}, - {"head_offset", ProgramUniformVariableDataType::Uint32}, - {"bias_offset", ProgramUniformVariableDataType::Uint32}); - - private: - bool has_bias_; -}; - -class AttentionProbsProgram final : public Program { - public: - AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool 
has_present_key, - bool has_attention_bias, int tile_size, int components) - : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components) { - } - - Status GenerateShaderCode(ShaderHelper& sh) const override; - - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32}, - {"K", ProgramUniformVariableDataType::Uint32}, - {"N", ProgramUniformVariableDataType::Uint32}, - {"num_heads", ProgramUniformVariableDataType::Uint32}, - {"alpha", ProgramUniformVariableDataType::Float32}, - {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, - {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}); - - WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); - - private: - bool feed_past_key_; - bool has_present_key_; - bool has_attention_bias_; - int tile_size_; - int components_; -}; - -class InPlaceSoftmaxProgram final : public Program { - public: - InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components) - : Program{kernel_name}, work_group_size_(work_group_size), components_(components) { - } - - Status GenerateShaderCode(ShaderHelper& sh) const override; - - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"d_inv", ProgramUniformVariableDataType::Float32}, - {"d_comp", ProgramUniformVariableDataType::Uint32}, - {"elements_per_thread", ProgramUniformVariableDataType::Uint32}); - - private: - int work_group_size_; - int components_; -}; - -class VxAttentionScoreProgram final : public Program { - public: - VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size) - : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size) { - } - - Status GenerateShaderCode(ShaderHelper& sh) const override; - - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32}, - {"K", ProgramUniformVariableDataType::Uint32}, - {"N", ProgramUniformVariableDataType::Uint32}, - {"num_heads", ProgramUniformVariableDataType::Uint32}, - {"v_hidden_size", ProgramUniformVariableDataType::Uint32}, - {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, - {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}); - - WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); - - private: - bool feed_past_value_; - bool has_present_value_; - int tile_size_; -}; - -class MultiHeadAttention final : public WebGpuKernel { +class MultiHeadAttention final : public WebGpuKernel, public AttentionBase { public: MultiHeadAttention(const OpKernelInfo& info); Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; - - protected: - int num_heads_; - float mask_filter_value_; - float scale_; - bool is_unidirectional_{false}; }; } // namespace webgpu diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 4006006a76ba8..2e7ed5a16a2f0 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -42,7 +42,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, 
      // BuildKernelCreateInfo,