Skip to content

Commit

Permalink
[webgpu] support resize operator
Browse files Browse the repository at this point in the history
  • Loading branch information
xhcao committed Feb 21, 2025
1 parent 754ee21 commit 493ceb2
Show file tree
Hide file tree
Showing 9 changed files with 995 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,15 @@ bool ConvertNodeLayout(const api::NodeRef& node) {
}
#endif

// NHWC for Resize operator is not implemented on kWebGpuExecutionProvider
#if defined(USE_WEBGPU)
if (node.GetExecutionProviderType() == kWebGpuExecutionProvider) {
if (node.OpType() == "Resize") {
return false;
}
}
#endif

#if defined(USE_CUDA) && ENABLE_CUDA_NHWC_OPS
if (node.GetExecutionProviderType() == kCudaExecutionProvider) {
if (layout_sensitive_ops.count(node.OpType())) {
Expand Down
69 changes: 69 additions & 0 deletions onnxruntime/core/providers/webgpu/tensor/resize.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/tensor/resize.h"
#include "core/providers/webgpu/webgpu_supported_types.h"

namespace onnxruntime {
namespace webgpu {

// Kernel registrations for the ONNX Resize operator on the WebGPU EP.
// The roi/scales/sizes inputs are pinned to CPU memory (InputMemoryType)
// because they are read on the host to compute the output shape before the
// GPU program is dispatched.

// Resize-10: inputs are X and scales; scales is input 1. Opset 10 uses the
// single type constraint "T".
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Resize,
    kOnnxDomain,
    10, 10,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .InputMemoryType(OrtMemTypeCPUInput, 1)
        .TypeConstraint("T", WebGpuSupportedNumberTypes()),
    Resize);

// Resize-11/12: opset 11 added roi (input 1) and sizes (input 3), moved
// scales to input 2, and renamed the data type constraint to "T1".
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Resize,
    kOnnxDomain,
    11, 12,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .InputMemoryType(OrtMemTypeCPUInput, 1)
        .InputMemoryType(OrtMemTypeCPUInput, 2)
        .InputMemoryType(OrtMemTypeCPUInput, 3)
        .TypeConstraint("T1", WebGpuSupportedNumberTypes()),
    Resize);

// Resize-13..17: same input layout as opset 11.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Resize,
    kOnnxDomain,
    13, 17,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .InputMemoryType(OrtMemTypeCPUInput, 1)
        .InputMemoryType(OrtMemTypeCPUInput, 2)
        .InputMemoryType(OrtMemTypeCPUInput, 3)
        .TypeConstraint("T1", WebGpuSupportedNumberTypes()),
    Resize);

// Resize-18: registered separately because opset 18 changed the operator
// definition (e.g. new attributes); input layout is unchanged.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Resize,
    kOnnxDomain,
    18, 18,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .InputMemoryType(OrtMemTypeCPUInput, 1)
        .InputMemoryType(OrtMemTypeCPUInput, 2)
        .InputMemoryType(OrtMemTypeCPUInput, 3)
        .TypeConstraint("T1", WebGpuSupportedNumberTypes()),
    Resize);

// Resize-19 and later (unversioned registration).
ONNX_OPERATOR_KERNEL_EX(
    Resize,
    kOnnxDomain,
    19,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .InputMemoryType(OrtMemTypeCPUInput, 1)
        .InputMemoryType(OrtMemTypeCPUInput, 2)
        .InputMemoryType(OrtMemTypeCPUInput, 3)
        .TypeConstraint("T1", WebGpuSupportedNumberTypes()),
    Resize);

} // namespace webgpu
} // namespace onnxruntime
24 changes: 24 additions & 0 deletions onnxruntime/core/providers/webgpu/tensor/resize.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/tensor/upsample.h"

namespace onnxruntime {
namespace webgpu {

// WebGPU kernel for the ONNX Resize operator.
//
// All real work (attribute parsing, roi/scales/sizes resolution, shader
// dispatch) lives in the shared Upsample base class; this subclass exists so
// that Resize can be registered as a distinct kernel.
class Resize : public Upsample {
 public:
  // explicit: single-argument constructor must not act as an implicit
  // conversion from OpKernelInfo (cpplint runtime/explicit).
  explicit Resize(const OpKernelInfo& info) : Upsample(info) {
  }

  Status ComputeInternal(ComputeContext& context) const override {
    return Upsample::ComputeInternal(context);
  }
};

} // namespace webgpu
} // namespace onnxruntime
603 changes: 603 additions & 0 deletions onnxruntime/core/providers/webgpu/tensor/resize_impl.cc

Large diffs are not rendered by default.

123 changes: 123 additions & 0 deletions onnxruntime/core/providers/webgpu/tensor/resize_impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/webgpu/program.h"
#include "core/providers/cpu/tensor/upsample.h"

namespace onnxruntime {
namespace webgpu {

// WebGPU shader program for nearest-neighbor Resize.
class ResizeNearestProgram final : public Program<ResizeNearestProgram> {
 public:
  ResizeNearestProgram(onnxruntime::ResizeCoordinateTransformationMode coordinate_transform_mode,
                       onnxruntime::ResizeNearestMode nearest_mode,
                       bool extrapolation_enabled,
                       int32_t rank) : Program{"ResizeNearest2D"},
                                       coordinate_transform_mode_{coordinate_transform_mode},
                                       nearest_mode_{nearest_mode},
                                       extrapolation_enabled_{extrapolation_enabled},
                                       rank_{rank} {}

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  // Uniforms shared by the resize shaders: normalized roi, per-axis scales,
  // total output element count, and the fill value used when extrapolation
  // is enabled.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"roi", ProgramUniformVariableDataType::Float32},
                                          {"scales", ProgramUniformVariableDataType::Float32},
                                          {"output_size", ProgramUniformVariableDataType::Uint32},
                                          {"extrapolation_value", ProgramUniformVariableDataType::Float32});

 private:
  onnxruntime::ResizeCoordinateTransformationMode coordinate_transform_mode_;  // output->input coordinate mapping
  onnxruntime::ResizeNearestMode nearest_mode_;                                // rounding rule for nearest lookup
  bool extrapolation_enabled_;  // use extrapolation_value outside the roi
  int32_t rank_;                // input tensor rank the shader is generated for
  // NOTE(review): program name "ResizeNearest2D" while rank is parameterized —
  // confirm against resize_impl.cc which ranks this program actually handles.
};

// WebGPU shader program for (bi)linear Resize.
class ResizeBilinearProgram final : public Program<ResizeBilinearProgram> {
 public:
  ResizeBilinearProgram(onnxruntime::ResizeCoordinateTransformationMode coordinate_transform_mode,
                        bool extrapolation_enabled,
                        int32_t rank) : Program{"ResizeBilinear"},
                                        coordinate_transform_mode_{coordinate_transform_mode},
                                        extrapolation_enabled_{extrapolation_enabled},
                                        rank_{rank} {}

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  // Same uniform layout as the other resize programs.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"roi", ProgramUniformVariableDataType::Float32},
                                          {"scales", ProgramUniformVariableDataType::Float32},
                                          {"output_size", ProgramUniformVariableDataType::Uint32},
                                          {"extrapolation_value", ProgramUniformVariableDataType::Float32});

 private:
  onnxruntime::ResizeCoordinateTransformationMode coordinate_transform_mode_;  // output->input coordinate mapping
  bool extrapolation_enabled_;  // use extrapolation_value outside the roi
  int32_t rank_;                // input tensor rank the shader is generated for
};

// WebGPU shader program for trilinear (3-D linear) Resize.
class ResizeTrilinearProgram final : public Program<ResizeTrilinearProgram> {
 public:
  ResizeTrilinearProgram(onnxruntime::ResizeCoordinateTransformationMode coordinate_transform_mode,
                         bool extrapolation_enabled,
                         int32_t rank) : Program{"ResizeTrilinear"},
                                         coordinate_transform_mode_{coordinate_transform_mode},
                                         extrapolation_enabled_{extrapolation_enabled},
                                         rank_{rank} {}

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  // Same uniform layout as the other resize programs.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"roi", ProgramUniformVariableDataType::Float32},
                                          {"scales", ProgramUniformVariableDataType::Float32},
                                          {"output_size", ProgramUniformVariableDataType::Uint32},
                                          {"extrapolation_value", ProgramUniformVariableDataType::Float32});

 private:
  onnxruntime::ResizeCoordinateTransformationMode coordinate_transform_mode_;  // output->input coordinate mapping
  bool extrapolation_enabled_;  // use extrapolation_value outside the roi
  int32_t rank_;                // input tensor rank the shader is generated for
};

// WebGPU shader program for (bi)cubic Resize.
class ResizeBiCubicProgram final : public Program<ResizeBiCubicProgram> {
 public:
  ResizeBiCubicProgram(onnxruntime::ResizeCoordinateTransformationMode coordinate_transform_mode,
                       bool extrapolation_enabled,
                       bool exclude_outside,
                       int32_t rank) : Program{"ResizeBiCubic"},
                                       coordinate_transform_mode_{coordinate_transform_mode},
                                       extrapolation_enabled_{extrapolation_enabled},
                                       exclude_outside_{exclude_outside},
                                       rank_{rank} {}

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  // Cubic mode carries one extra uniform: the cubic coefficient "a"
  // (the Resize cubic_coeff_a attribute).
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"roi", ProgramUniformVariableDataType::Float32},
                                          {"scales", ProgramUniformVariableDataType::Float32},
                                          {"output_size", ProgramUniformVariableDataType::Uint32},
                                          {"extrapolation_value", ProgramUniformVariableDataType::Float32},
                                          {"cubic_coeff_a", ProgramUniformVariableDataType::Float32});

 private:
  onnxruntime::ResizeCoordinateTransformationMode coordinate_transform_mode_;  // output->input coordinate mapping
  bool extrapolation_enabled_;  // use extrapolation_value outside the roi
  bool exclude_outside_;        // zero out weights of samples outside the input (exclude_outside attribute)
  int32_t rank_;                // input tensor rank the shader is generated for
};

// Entry point into the WebGPU resize shaders (defined in resize_impl.cc).
// Selects and runs the program matching `upsample_mode`; `roi` is only
// meaningful for the tf_crop_and_resize coordinate transformation mode, and
// `cubic_coeff_a`/`exclude_outside` only apply to cubic mode.
// NOTE(review): `output_dims` is a span passed by reference — spans are
// normally passed by value; confirm the reference is intentional.
Status ResizeImpl(
    ComputeContext& context,
    const Tensor* input,
    const onnxruntime::UpsampleMode upsample_mode,
    gsl::span<const int64_t>& output_dims,
    gsl::span<const float> roi,
    gsl::span<const float> scales,
    bool extrapolation_enabled,
    const float extrapolation_value,
    float cubic_coeff_a,
    bool exclude_outside,
    onnxruntime::ResizeCoordinateTransformationMode coordinate_transform_mode,
    onnxruntime::ResizeNearestMode nearest_mode);

} // namespace webgpu
} // namespace onnxruntime
132 changes: 132 additions & 0 deletions onnxruntime/core/providers/webgpu/tensor/upsample.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/tensor/resize_impl.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/tensor/upsample.h"
#include "core/providers/webgpu/webgpu_supported_types.h"

using namespace onnxruntime::common;

Check warning on line 9 in onnxruntime/core/providers/webgpu/tensor/upsample.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5] Raw Output: onnxruntime/core/providers/webgpu/tensor/upsample.cc:9: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]

namespace onnxruntime {
namespace webgpu {

// Validates the resolved roi/scales against the input shape, allocates the
// output tensor, and forwards to ResizeImpl.
// Returns INVALID_ARGUMENT for scalar input, mismatched scales rank, or a roi
// that is not 2*rank long; NOT_IMPLEMENTED for antialiased Resize and for the
// Upsample operator itself.
Status Upsample::BaseCompute(ComputeContext& context,
                             gsl::span<const float> roi,
                             gsl::span<const float> scales,
                             gsl::span<const int64_t> output_dims) const {
  const auto* X = context.Input(0);
  auto dims = X->Shape().GetDims();
  ORT_ENFORCE(output_dims.size() == dims.size(), "Rank of input and output tensor should be same.");

  if (dims.size() == 0) {
    return Status(ONNXRUNTIME, INVALID_ARGUMENT,
                  is_resize_ ? "Resize: input tensor cannot be scalar."
                             : "Upsample: input tensor cannot be scalar.");
  }
  if (dims.size() != scales.size()) {
    return Status(ONNXRUNTIME, INVALID_ARGUMENT,
                  is_resize_ ? "Resize: input tensor's dimension does not match the scales."
                             : "Upsample: input tensor's dimension does not match the scales.");
  }
  // roi holds [start_0..start_{N-1}, end_0..end_{N-1}] per the Resize spec.
  if (roi.size() != 2 * dims.size()) {
    return Status(ONNXRUNTIME, INVALID_ARGUMENT,
                  "Resize: size of roi array should be 2 * N where N is the rank of input tensor X.");
  }

  Tensor* Y = context.Output(0, output_dims);
  // Return early if the output tensor is going to be of size 0
  if (Y->Shape().Size() == 0) {
    return Status::OK();
  }

  if (is_resize_) {
    if (!antialias_) {
      return ResizeImpl(context, X, mode_, output_dims, roi, scales, use_extrapolation_, extrapolation_value_,
                        cubic_coeff_a_, exclude_outside_, coordinate_transform_mode_, nearest_mode_);
    } else {
      return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
                             "The antialias attribute of Resize operator is NOT implemented.");
    }
  } else {
    return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Upsample operator is NOT implemented.");
  }
}

// Resolves the effective roi, scales, and output shape — from attributes,
// cached constant initializers, or the node's runtime inputs (which are
// CPU-resident per the kernel registrations) — then delegates to BaseCompute.
// Mirrors the resolution order of the CPU Upsample kernel.
Status Upsample::ComputeInternal(ComputeContext& context) const {
  const auto* X = context.Input(0);
  auto input_dims = X->Shape().GetDims();
  TensorShapeVector output_dims(input_dims.size());

  // Get roi data
  // Initialize the roi array to all zeros as this will be the most common case
  // Roi data is needed only when coordinate transformation mode is set to tf_crop_and_resize
  // for all other cases we need a 0 initialized roi array
  InlinedVector<float> roi_array(roi_);

  if (!roi_cached_) {
    bool use_default_roi = true;
    if (need_roi_input_) {
      ORT_ENFORCE(roi_input_idx_ > 0, "Invalid roi input index.");
      const auto* roi = context.Input(roi_input_idx_);
      if (roi != nullptr) {
        ParseRoiData(roi, roi_array);
        use_default_roi = false;
      }
    }
    if (use_default_roi) {
      // default roi includes ensures all the values in that axis are included in the roi
      // normalized roi is thus : [start, end] = [0, 1]
      size_t input_rank = input_dims.size();
      roi_array.resize(input_rank * 2);
      for (size_t i = 0; i < input_rank; ++i) {
        roi_array[i] = 0;
        roi_array[i + input_rank] = 1;
      }
    }
  }

  // Expand an axes-restricted roi to cover the full input rank.
  ComputeROIWithAxes(roi_array, input_dims.size());

  InlinedVector<float> scales_array(input_dims.size());
  // opset < 10: scales come from the node attribute, not an input.
  if (OpKernel::Node().InputDefs().size() == 1) {
    scales_array = scales_;
    // Compute output shape from scales attributes and input dims
    ComputeOutputShape(scales_array, input_dims, output_dims);
    return BaseCompute(context, roi_array, scales_array, output_dims);
  }

  const auto* scales = context.Input(scales_input_idx_);
  const auto* sizes = context.Input(sizes_input_idx_);

  // This is when scales are obtained and cached from a constant initializer
  if (scales_cached_) {
    ORT_RETURN_IF_NOT(sizes == nullptr, "Only one of scales or sizes must be provided as input.");
    scales_array = scales_;
    // Compute output shape from scales and input dims
    ComputeOutputShape(scales_array, input_dims, output_dims);
    return BaseCompute(context, roi_array, scales_array, output_dims);
  }

  // Scales and sizes are input to the node
  if (scales != nullptr && scales->Shape().Size() != 0) {
    // use scales input data
    // NOTE(review): this path aborts via ORT_ENFORCE while the cached path
    // above returns a Status for the same condition — confirm intentional.
    ORT_ENFORCE(sizes == nullptr, "Only one of scales or sizes must be provided as input.");
    ORT_RETURN_IF_ERROR(ParseScalesData(scales, scales_array, input_dims.size()));

    // Compute output shape from scales and input dims
    ComputeOutputShape(scales_array, input_dims, output_dims);
  } else {
    // When sizes input is available directly populate it into the output_dims array.
    ORT_ENFORCE(sizes != nullptr && sizes->Shape().Size() != 0,
                "Either scales or sizes MUST be provided as input.");
    ORT_RETURN_IF_ERROR(ParseSizesData(sizes, output_dims, input_dims));
    ORT_RETURN_IF_ERROR(ParseScalesDataAndAdjustOutputSize(output_dims, input_dims, scales_array));
  }

  return BaseCompute(context, roi_array, scales_array, output_dims);
}

} // namespace webgpu
} // namespace onnxruntime
23 changes: 23 additions & 0 deletions onnxruntime/core/providers/webgpu/tensor/upsample.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/webgpu/program.h"
#include "core/providers/cpu/tensor/upsample.h"

namespace onnxruntime {
namespace webgpu {

// Shared base for the WebGPU Upsample/Resize kernels.
//
// Attribute parsing and roi/scales/sizes plumbing are inherited from the CPU
// UpsampleBase; ComputeInternal resolves the effective roi/scales/output
// shape and delegates the actual work to BaseCompute.
class Upsample : public UpsampleBase, public WebGpuKernel {
 public:
  // Stray ';' after the constructor body removed (cpplint readability/braces).
  explicit Upsample(const OpKernelInfo& info) : UpsampleBase(info), WebGpuKernel(info) {}

  Status ComputeInternal(ComputeContext& context) const override;

  // Validates roi/scales against the input shape, allocates the output
  // tensor, and runs the resize shader program.
  Status BaseCompute(ComputeContext& context, gsl::span<const float> roi, gsl::span<const float> scales,
                     gsl::span<const int64_t> output_dims) const;
};

} // namespace webgpu
} // namespace onnxruntime
Loading

0 comments on commit 493ceb2

Please sign in to comment.