From 7482606092f92389c0addc27e9e9de3e61760df9 Mon Sep 17 00:00:00 2001
From: Jie Chen
Date: Sat, 15 Feb 2025 18:18:07 +0800
Subject: [PATCH] Add MaxPool and AveragePool

---
Reviewer notes (this area is ignored when the patch is applied):
  * This copy of the patch had its line breaks collapsed and every
    angle-bracketed span removed. Line breaks and the template arguments in
    pool.cc / pool.h were restored from the surrounding code; please diff
    against the original commit before merging.
  * The BuildKernelCreateInfo template arguments in
    webgpu_execution_provider.cc and the author e-mail could not be recovered
    from this text and are left as found, so that hunk will not apply until
    they are restored from the original commit.
  * Fixes: AveragePool now divides by a count of the same scalar type as
    `value` (f16 inputs previously generated `value /= f32(count)`), the f32
    MaxPool identity is spelled with full precision, `kernal_rank` is renamed
    to `kernel_rank`, the ceil_mode error message no longer repeats "ceil",
    and the Pool constructor is explicit.

 onnxruntime/core/providers/webgpu/nn/pool.cc  | 260 ++++++++++++++++++
 onnxruntime/core/providers/webgpu/nn/pool.h   |  57 ++++
 .../webgpu/webgpu_execution_provider.cc       |  46 ++--
 3 files changed, 340 insertions(+), 23 deletions(-)
 create mode 100644 onnxruntime/core/providers/webgpu/nn/pool.cc
 create mode 100644 onnxruntime/core/providers/webgpu/nn/pool.h

diff --git a/onnxruntime/core/providers/webgpu/nn/pool.cc b/onnxruntime/core/providers/webgpu/nn/pool.cc
new file mode 100644
index 0000000000000..0baedc7a2fe0d
--- /dev/null
+++ b/onnxruntime/core/providers/webgpu/nn/pool.cc
@@ -0,0 +1,260 @@
+
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/webgpu/shader_helper.h"
+#include "core/providers/webgpu/webgpu_supported_types.h"
+#include "core/providers/webgpu/nn/pool.h"
+
+namespace onnxruntime {
+namespace webgpu {
+
+namespace {
+
+// Narrows every dimension of `shape` to uint32_t so it can be uploaded as a Uint32 uniform.
+std::vector<uint32_t> NarrowToU32(const TensorShapeVector& shape) {
+  std::vector<uint32_t> result;
+  result.reserve(shape.size());
+  for (auto dim : shape) {
+    result.push_back(gsl::narrow_cast<uint32_t>(dim));
+  }
+  return result;
+}
+
+}  // namespace
+
+#define POOLING_KERNEL(op_name, domain, is_nhwc, pool_type, since_version) \
+  ONNX_OPERATOR_KERNEL_EX(op_name, domain, since_version, kWebGpuExecutionProvider, \
+                          (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), \
+                          Pool<pool_type, is_nhwc>);
+
+#define POOLING_KERNEL_VERSIONED(op_name, domain, is_nhwc, pool_type, since_version, end_version) \
+  ONNX_OPERATOR_VERSIONED_KERNEL_EX(op_name, domain, since_version, end_version, kWebGpuExecutionProvider, \
+                                    (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), \
+                                    Pool<pool_type, is_nhwc>);
+
+#define POOLING_KERNEL_WITH_INDICES(op_name, domain, is_nhwc, pool_type, since_version) \
+  ONNX_OPERATOR_KERNEL_EX(op_name, domain, since_version, kWebGpuExecutionProvider, \
+                          (*KernelDefBuilder::Create()) \
+                              .TypeConstraint("T", WebGpuSupportedFloatTypes()) \
+                              .TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>()), \
+                          Pool<pool_type, is_nhwc>);
+
+#define POOLING_KERNEL_VERSIONED_WITH_INDICES(op_name, domain, is_nhwc, pool_type, since_version, end_version) \
+  ONNX_OPERATOR_VERSIONED_KERNEL_EX(op_name, domain, since_version, end_version, kWebGpuExecutionProvider, \
+                                    (*KernelDefBuilder::Create()) \
+                                        .TypeConstraint("T", WebGpuSupportedFloatTypes()) \
+                                        .TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>()), \
+                                    Pool<pool_type, is_nhwc>);
+
+POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, AveragePool, 7, 9)
+POOLING_KERNEL_VERSIONED(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 7, 9)
+POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, AveragePool, 10, 10)
+POOLING_KERNEL_VERSIONED(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 10, 10)
+POOLING_KERNEL(AveragePool, kOnnxDomain, false, AveragePool, 11)
+POOLING_KERNEL(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 11)
+POOLING_KERNEL(GlobalAveragePool, kOnnxDomain, false, AveragePool, 1)
+POOLING_KERNEL(GlobalAveragePool, kMSInternalNHWCDomain, true, AveragePool, 1)
+
+POOLING_KERNEL_VERSIONED(MaxPool, kOnnxDomain, false, MaxPool<1>, 1, 7)
+POOLING_KERNEL_VERSIONED(MaxPool, kMSInternalNHWCDomain, true, MaxPool<1>, 1, 7)
+POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 8, 9)
+POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, MaxPool<8>, 8, 9)
+POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 10, 10)
+POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, MaxPool<8>, 10, 10)
+POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 11, 11)
+POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, MaxPool<8>, 11, 11)
+POOLING_KERNEL_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 12)
+POOLING_KERNEL_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, MaxPool<8>, 12)
+POOLING_KERNEL(GlobalMaxPool, kOnnxDomain, false, MaxPool<1>, 1)
+POOLING_KERNEL(GlobalMaxPool, kMSInternalNHWCDomain, true, MaxPool<1>, 1)
+
+// Emits the WGSL body: each invocation produces one output element by walking the
+// flattened pooling window, skipping positions that fall into padding, and reducing
+// the sampled inputs with max (MaxPool) or sum / count (AveragePool).
+Status PoolProgram::GenerateShaderCode(ShaderHelper& shader) const {
+  const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
+  const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
+
+  // Declare and initialize the variables needed.
+  std::string var_decl_code;
+  // Process each element in the pooling window.
+  std::string sampling_code;
+  // Calculate the output value for each pooling window.
+  std::string downsampling_code;
+  if (is_max_pool_) {
+    std::string f16_min = "f16(-65504)";
+
+    // Lowest finite f32, written out in full: streaming the float at the default precision
+    // yields -3.40282e+38, which is greater than the true minimum and breaks max() for it.
+    std::string f32_min = "f32(-3.4028234663852886e38)";
+
+    std::stringstream var_decl_ss;
+    var_decl_ss << "  var value = " << (is_float16_ ? f16_min : f32_min) << ";\n";
+    var_decl_code = var_decl_ss.str();
+
+    sampling_code = "      value = max(value, x_val);\n";
+  } else {
+    std::stringstream var_decl_ss;
+    var_decl_ss << "  var value = " << (is_float16_ ? "f16(0)" : "f32(0)") << ";\n";
+    if (!count_include_pad_) {
+      var_decl_ss << "  var count = u32(0);\n";
+    } else {
+      var_decl_ss << "  var count = uniforms.kernel_size;\n";
+    }
+    var_decl_code = var_decl_ss.str();
+
+    std::stringstream sampling_ss;
+    sampling_ss << "      value += x_val;\n";
+    if (!count_include_pad_) {
+      sampling_ss << "      count++;\n";
+    }
+    sampling_code = sampling_ss.str();
+
+    // The divisor must have the same scalar type as `value`; WGSL does not convert f32 to f16 implicitly.
+    downsampling_code = is_float16_ ? "  value /= f16(count);\n" : "  value /= f32(count);\n";
+  }
+
+  const auto kernel_rank = kernel_shape_.size();
+  const auto pads_rank = kernel_shape_.size() * 2;
+  // The dimension index for H or D1
+  const auto data_dim_begin = is_nhwc_ ? 1 : 2;
+  // The dimension index after W or Dn
+  auto data_dim_end = input.Rank();
+  data_dim_end = is_nhwc_ ? data_dim_end - 1 : data_dim_end;
+
+  std::stringstream d_idx_ss;
+  d_idx_ss << "j - " << data_dim_begin;
+  std::string d_idx_code = d_idx_ss.str();
+
+  // clang-format off
+  auto& body = shader.MainFunctionBody();
+  body << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
+       << "  let y_indices = " << output.OffsetToIndices("global_idx") << ";\n"
+       << "  var x_indices = y_indices;\n"
+       << "  var k_indices: array<u32, " << kernel_rank << ">;\n"
+       << var_decl_code
+       << "  for (var i: u32 = 0; i < uniforms.kernel_size; i++) {\n"
+       << "    var offset = i;\n"
+       // ---- Compute offset to indices in pooling window.
+       << "    for (var j = 0; j < " << kernel_rank << "; j++) {\n"
+       << "      k_indices[j] = offset / " << GetElementAt("uniforms.kernel_strides", "j", kernel_rank) << ";\n"
+       << "      offset = offset % " << GetElementAt("uniforms.kernel_strides", "j", kernel_rank) << ";\n"
+       << "    }\n"
+       // ---- Apply dilations in pooling window.
+       << "    for (var j = 0; j < " << kernel_rank << "; j++) {\n"
+       << "      k_indices[j] *= " << GetElementAt("uniforms.dilations", "j", kernel_rank) << ";\n"
+       << "    }\n"
+       << "    var is_pad = false;\n"
+       // ---- Compute x_indices in each data dimension
+       << "    for (var j = " << data_dim_begin << "; j < " << data_dim_end << "; j++) {\n"
+       << "      x_indices[j] = y_indices[j] * " << GetElementAt("uniforms.strides", d_idx_code, kernel_rank) << ";\n"
+       << "      x_indices[j] += k_indices[" << d_idx_code << "];\n"
+       << "      x_indices[j] -= " << GetElementAt("uniforms.pads", d_idx_code, pads_rank) << ";\n"
+       << "      let j_dim_len = " << input.IndicesGet("uniforms.input_shape", "j") << ";\n"
+       // ------ Check if x_indices[j] is out of bounds to handle padding.
+       << "      if (x_indices[j] < 0 || x_indices[j] >= j_dim_len) {\n"
+       << "        is_pad = true;\n"
+       << "        break;\n"
+       << "      }\n"
+       << "    }\n"
+       << "    if (!is_pad) {\n"
+       << "      let x_val = " << input.GetByIndices("x_indices") << ";\n"
+       << sampling_code
+       << "    }\n"
+       << "  }\n"
+       << downsampling_code
+       << "  " << output.SetByOffset("global_idx", "value") << ";\n";
+  // clang-format on
+
+  return Status::OK();
+}
+
+// Rejects the attributes this kernel cannot handle yet, resolves global pooling into an
+// explicit kernel/pads/strides/dilations configuration, computes the output shape and
+// runs PoolProgram with those values passed as uniforms.
+template <typename PoolType, bool is_nhwc>
+Status Pool<PoolType, is_nhwc>::ComputeInternal(ComputeContext& context) const {
+  // TODO: support 'ceil' mode.
+  ORT_RETURN_IF_NOT(pool_attrs_.ceil_mode == 0, "Using ceil_mode is not supported yet.");
+  // TODO: support 'column major' storage_order.
+  ORT_RETURN_IF_NOT(pool_attrs_.storage_order == 0, "Using column major storage_order is not supported yet.");
+
+  // TODO: support 'Indices' output.
+  ORT_RETURN_IF_NOT(context.OutputCount() == 1, "The Indices output is not supported yet.");
+
+  const auto* X = context.Input(0);
+  const TensorShape& x_shape = X->Shape();
+  const auto input_shape = x_shape.AsShapeVector();
+  ORT_RETURN_IF_NOT(input_shape.size() >= 3, "Input dimension cannot be less than 3.");
+
+  auto kernel_shape = pool_attrs_.kernel_shape;
+  auto strides = pool_attrs_.strides;
+  auto pads = pool_attrs_.pads;
+  auto dilations = pool_attrs_.dilations;
+  // Global pooling is equivalent to having the kernel size equal to the spatial dimension of input tensor.
+  if (pool_attrs_.global_pooling) {
+    if (!is_nhwc) {
+      kernel_shape.assign(input_shape.begin() + 2, input_shape.end());
+    } else {
+      kernel_shape.assign(input_shape.begin() + 1, input_shape.end() - 1);
+    }
+    // No padding.
+    pads.assign(2 * kernel_shape.size(), 0);
+    // Stride of 1.
+    strides.assign(kernel_shape.size(), 1);
+    // Dilation of 1.
+    dilations.assign(kernel_shape.size(), 1);
+  }
+
+  // Calculate the output shape
+  const auto out_channel = x_shape[is_nhwc ? input_shape.size() - 1 : 1];
+  const auto output_shape = pool_attrs_.SetOutputSize(x_shape, out_channel, &pads, is_nhwc);
+  Tensor* Y = context.Output(0, output_shape);
+
+  std::vector<uint32_t> kernel_strides(kernel_shape.size());
+  ORT_ENFORCE(kernel_shape.size() > 0, "kernel_shape must have at least one element.");
+  // Calculate the kernel element strides for each dimension in reverse order. For example:
+  //   kernel_shape = [3, 2], kernel_strides = [2, 1]
+  //   kernel_shape = [2, 3, 2], kernel_strides = [6, 2, 1]
+  for (size_t i = kernel_shape.size(); i > 0; --i) {
+    if (i == kernel_shape.size()) {
+      kernel_strides[i - 1] = 1;
+    } else {
+      kernel_strides[i - 1] = kernel_strides[i] * gsl::narrow_cast<uint32_t>(kernel_shape[i]);
+    }
+  }
+
+  bool is_max_pool = false;
+  if constexpr (PoolType::type == onnxruntime::PoolType::kMaxPool) {
+    is_max_pool = true;
+  } else if constexpr (PoolType::type != onnxruntime::PoolType::kAveragePool) {
+    ORT_NOT_IMPLEMENTED("Unsupported PoolType.");
+  }
+  bool is_float16 = X->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16;
+  bool count_include_pad = pool_attrs_.count_include_pad;
+  PoolProgram program{is_max_pool, is_nhwc, kernel_shape, is_float16, count_include_pad};
+
+  // Number of elements
+  uint32_t output_size = gsl::narrow_cast<uint32_t>(Y->Shape().Size());
+  uint32_t kernel_size = gsl::narrow_cast<uint32_t>(TensorShape{kernel_shape}.Size());
+
+  const auto pads_u32 = NarrowToU32(pads);
+  const auto strides_u32 = NarrowToU32(strides);
+  const auto dilations_u32 = NarrowToU32(dilations);
+
+  program.CacheHint(kernel_shape.size(), is_max_pool, is_nhwc, is_float16, count_include_pad)
+      .AddInputs({{X, ProgramTensorMetadataDependency::TypeAndRank}})
+      .AddOutputs({{Y}})
+      .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
+      .AddUniformVariables({output_size, kernel_size,
+                            gsl::span<const uint32_t>(kernel_strides.data(), kernel_strides.size()),
+                            gsl::span<const uint32_t>(pads_u32.data(), pads_u32.size()),
+                            gsl::span<const uint32_t>(strides_u32.data(), strides_u32.size()),
+                            gsl::span<const uint32_t>(dilations_u32.data(), dilations_u32.size())});
+
+  return context.RunProgram(program);
+}
+
+}  // namespace webgpu
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/nn/pool.h b/onnxruntime/core/providers/webgpu/nn/pool.h
new file mode 100644
index 0000000000000..99a989cb9ee28
--- /dev/null
+++ b/onnxruntime/core/providers/webgpu/nn/pool.h
@@ -0,0 +1,57 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/providers/webgpu/program.h"
+#include "core/providers/webgpu/webgpu_kernel.h"
+#include "core/providers/common.h"
+#include "core/providers/cpu/nn/pool_base.h"
+
+namespace onnxruntime {
+namespace webgpu {
+
+// WebGPU program that generates the pooling shader shared by the MaxPool, AveragePool,
+// GlobalMaxPool and GlobalAveragePool kernels.
+class PoolProgram final : public Program<PoolProgram> {
+ public:
+  PoolProgram(bool is_max_pool, bool is_nhwc, const TensorShapeVector& kernel_shape, bool is_float16,
+              bool count_include_pad)
+      : Program{"Pool"},
+        is_max_pool_{is_max_pool},
+        is_nhwc_{is_nhwc},
+        kernel_shape_{kernel_shape},
+        is_float16_{is_float16},
+        count_include_pad_{count_include_pad} {}
+
+  Status GenerateShaderCode(ShaderHelper& sh) const override;
+
+  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32},
+                                          {"kernel_size", ProgramUniformVariableDataType::Uint32},
+                                          {"kernel_strides", ProgramUniformVariableDataType::Uint32},
+                                          {"pads", ProgramUniformVariableDataType::Uint32},
+                                          {"strides", ProgramUniformVariableDataType::Uint32},
+                                          {"dilations", ProgramUniformVariableDataType::Uint32});
+
+ private:
+  // Whether it is max pool or average pool.
+  const bool is_max_pool_;
+
+  const bool is_nhwc_;
+  const TensorShapeVector kernel_shape_;
+  const bool is_float16_;
+  const bool count_include_pad_;
+};
+
+// Pooling kernel. `PoolType` selects max vs. average pooling and `is_nhwc` selects the
+// channels-last layout that is registered under kMSInternalNHWCDomain.
+template <typename PoolType, bool is_nhwc>
+class Pool : public WebGpuKernel, public PoolBase {
+ public:
+  explicit Pool(const OpKernelInfo& info) : WebGpuKernel(info), PoolBase(info) {}
+
+  Status ComputeInternal(ComputeContext& context) const override;
+};
+
+}  // namespace webgpu
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
index f517ef9d36458..171d543b6d3d0 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
+++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
@@ -586,29 +586,29 @@ std::unique_ptr RegisterKernels() {
       // BuildKernelCreateInfo,
       // BuildKernelCreateInfo,
 
-      // BuildKernelCreateInfo,
-      // BuildKernelCreateInfo,
-      // BuildKernelCreateInfo,
-      // BuildKernelCreateInfo,
-      // BuildKernelCreateInfo,
-      // BuildKernelCreateInfo,
-      // BuildKernelCreateInfo,
-      // BuildKernelCreateInfo,
-      // BuildKernelCreateInfo,
-      // BuildKernelCreateInfo,
-
-      // BuildKernelCreateInfo,
-      // BuildKernelCreateInfo,
-      // BuildKernelCreateInfo,
-      // BuildKernelCreateInfo,
-      // BuildKernelCreateInfo,
-      // BuildKernelCreateInfo,
-
-      // BuildKernelCreateInfo,
-      // BuildKernelCreateInfo,
-
-      // BuildKernelCreateInfo,
-      // BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
 
       // BuildKernelCreateInfo,
       // BuildKernelCreateInfo,