-
Notifications
You must be signed in to change notification settings - Fork 3.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
995 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#include "core/providers/webgpu/tensor/resize.h" | ||
#include "core/providers/webgpu/webgpu_supported_types.h" | ||
|
||
namespace onnxruntime { | ||
namespace webgpu { | ||
|
||
// Registers the Resize kernel class with the WebGPU execution provider for | ||
// ONNX-domain opset versions 10 through 10. | ||
// Input index 1 is declared as CPU memory (presumably the `scales` input of | ||
// the ONNX Resize-10 schema - confirm against the operator spec), and type | ||
// parameter "T" is constrained to WebGpuSupportedNumberTypes(). | ||
ONNX_OPERATOR_VERSIONED_KERNEL_EX( | ||
Resize, | ||
kOnnxDomain, | ||
10, 10, | ||
kWebGpuExecutionProvider, | ||
(*KernelDefBuilder::Create()) | ||
.InputMemoryType(OrtMemTypeCPUInput, 1) | ||
.TypeConstraint("T", WebGpuSupportedNumberTypes()), | ||
Resize); | ||
|
||
// Registers the Resize kernel class with the WebGPU execution provider for | ||
// ONNX-domain opset versions 11 through 12. | ||
// Input indices 1, 2 and 3 are declared as CPU memory (presumably roi, scales | ||
// and sizes per the ONNX Resize schema - confirm against the operator spec). | ||
// Type parameter "T1" is constrained to WebGpuSupportedNumberTypes(). | ||
ONNX_OPERATOR_VERSIONED_KERNEL_EX( | ||
Resize, | ||
kOnnxDomain, | ||
11, 12, | ||
kWebGpuExecutionProvider, | ||
(*KernelDefBuilder::Create()) | ||
.InputMemoryType(OrtMemTypeCPUInput, 1) | ||
.InputMemoryType(OrtMemTypeCPUInput, 2) | ||
.InputMemoryType(OrtMemTypeCPUInput, 3) | ||
.TypeConstraint("T1", WebGpuSupportedNumberTypes()), | ||
Resize); | ||
|
||
// Registers the Resize kernel class with the WebGPU execution provider for | ||
// ONNX-domain opset versions 13 through 17. | ||
// Input indices 1, 2 and 3 are declared as CPU memory (presumably roi, scales | ||
// and sizes per the ONNX Resize schema - confirm against the operator spec). | ||
// Type parameter "T1" is constrained to WebGpuSupportedNumberTypes(). | ||
ONNX_OPERATOR_VERSIONED_KERNEL_EX( | ||
Resize, | ||
kOnnxDomain, | ||
13, 17, | ||
kWebGpuExecutionProvider, | ||
(*KernelDefBuilder::Create()) | ||
.InputMemoryType(OrtMemTypeCPUInput, 1) | ||
.InputMemoryType(OrtMemTypeCPUInput, 2) | ||
.InputMemoryType(OrtMemTypeCPUInput, 3) | ||
.TypeConstraint("T1", WebGpuSupportedNumberTypes()), | ||
Resize); | ||
|
||
// Registers the Resize kernel class with the WebGPU execution provider for | ||
// ONNX-domain opset versions 18 through 18. | ||
// Input indices 1, 2 and 3 are declared as CPU memory (presumably roi, scales | ||
// and sizes per the ONNX Resize schema - confirm against the operator spec). | ||
// Type parameter "T1" is constrained to WebGpuSupportedNumberTypes(). | ||
ONNX_OPERATOR_VERSIONED_KERNEL_EX( | ||
Resize, | ||
kOnnxDomain, | ||
18, 18, | ||
kWebGpuExecutionProvider, | ||
(*KernelDefBuilder::Create()) | ||
.InputMemoryType(OrtMemTypeCPUInput, 1) | ||
.InputMemoryType(OrtMemTypeCPUInput, 2) | ||
.InputMemoryType(OrtMemTypeCPUInput, 3) | ||
.TypeConstraint("T1", WebGpuSupportedNumberTypes()), | ||
Resize); | ||
|
||
// Registers the Resize kernel class with the WebGPU execution provider for | ||
// the ONNX domain starting at opset version 19. Unlike the registrations | ||
// above, this uses the non-versioned macro, so no end version is given. | ||
// Input indices 1, 2 and 3 are declared as CPU memory (presumably roi, scales | ||
// and sizes per the ONNX Resize schema - confirm against the operator spec). | ||
// Type parameter "T1" is constrained to WebGpuSupportedNumberTypes(). | ||
ONNX_OPERATOR_KERNEL_EX( | ||
Resize, | ||
kOnnxDomain, | ||
19, | ||
kWebGpuExecutionProvider, | ||
(*KernelDefBuilder::Create()) | ||
.InputMemoryType(OrtMemTypeCPUInput, 1) | ||
.InputMemoryType(OrtMemTypeCPUInput, 2) | ||
.InputMemoryType(OrtMemTypeCPUInput, 3) | ||
.TypeConstraint("T1", WebGpuSupportedNumberTypes()), | ||
Resize); | ||
|
||
} // namespace webgpu | ||
} // namespace onnxruntime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
#include "core/providers/webgpu/webgpu_kernel.h" | ||
#include "core/providers/webgpu/program.h" | ||
#include "core/providers/webgpu/tensor/upsample.h" | ||
|
||
namespace onnxruntime { | ||
namespace webgpu { | ||
|
||
class Resize : public Upsample { | ||
public: | ||
Resize(const OpKernelInfo& info) : Upsample(info) { | ||
Check warning on line 15 in onnxruntime/core/providers/webgpu/tensor/resize.h
|
||
} | ||
|
||
Status ComputeInternal(ComputeContext& context) const override { | ||
return Upsample::ComputeInternal(context); | ||
} | ||
}; | ||
|
||
} // namespace webgpu | ||
} // namespace onnxruntime |
603 changes: 603 additions & 0 deletions
603
onnxruntime/core/providers/webgpu/tensor/resize_impl.cc
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
#include "core/providers/webgpu/webgpu_kernel.h" | ||
#include "core/providers/webgpu/program.h" | ||
#include "core/providers/cpu/tensor/upsample.h" | ||
|
||
namespace onnxruntime { | ||
namespace webgpu { | ||
|
||
// Program type for the nearest-mode Resize path (going by its name; the | ||
// shader body produced by GenerateShaderCode is defined outside this header). | ||
// The constructor records the coordinate transformation mode, nearest mode, | ||
// extrapolation flag and rank, and passes the name "ResizeNearest2D" to the | ||
// Program base. | ||
class ResizeNearestProgram final : public Program<ResizeNearestProgram> { | ||
public: | ||
ResizeNearestProgram(onnxruntime::ResizeCoordinateTransformationMode coordinate_transform_mode, | ||
onnxruntime::ResizeNearestMode nearest_mode, | ||
bool extrapolation_enabled, | ||
int32_t rank) : Program{"ResizeNearest2D"}, | ||
coordinate_transform_mode_{coordinate_transform_mode}, | ||
nearest_mode_{nearest_mode}, | ||
extrapolation_enabled_{extrapolation_enabled}, | ||
rank_{rank} {} | ||
|
||
// Declared here, implemented elsewhere. | ||
Status GenerateShaderCode(ShaderHelper& sh) const override; | ||
|
||
// Uniforms declared for this program: roi, scales and extrapolation_value | ||
// as Float32; output_size as Uint32. | ||
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"roi", ProgramUniformVariableDataType::Float32}, | ||
{"scales", ProgramUniformVariableDataType::Float32}, | ||
{"output_size", ProgramUniformVariableDataType::Uint32}, | ||
{"extrapolation_value", ProgramUniformVariableDataType::Float32}); | ||
|
||
private: | ||
// Values captured at construction; none are modified by this header. | ||
onnxruntime::ResizeCoordinateTransformationMode coordinate_transform_mode_; | ||
onnxruntime::ResizeNearestMode nearest_mode_; | ||
bool extrapolation_enabled_; | ||
int32_t rank_; | ||
}; | ||
|
||
// Program type for the bilinear Resize path (going by its name; the shader | ||
// body produced by GenerateShaderCode is defined outside this header). | ||
// The constructor records the coordinate transformation mode, extrapolation | ||
// flag and rank, and passes the name "ResizeBilinear" to the Program base. | ||
class ResizeBilinearProgram final : public Program<ResizeBilinearProgram> { | ||
public: | ||
ResizeBilinearProgram(onnxruntime::ResizeCoordinateTransformationMode coordinate_transform_mode, | ||
bool extrapolation_enabled, | ||
int32_t rank) : Program{"ResizeBilinear"}, | ||
coordinate_transform_mode_{coordinate_transform_mode}, | ||
extrapolation_enabled_{extrapolation_enabled}, | ||
rank_{rank} {} | ||
|
||
// Declared here, implemented elsewhere. | ||
Status GenerateShaderCode(ShaderHelper& sh) const override; | ||
|
||
// Uniforms declared for this program: roi, scales and extrapolation_value | ||
// as Float32; output_size as Uint32. | ||
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"roi", ProgramUniformVariableDataType::Float32}, | ||
{"scales", ProgramUniformVariableDataType::Float32}, | ||
{"output_size", ProgramUniformVariableDataType::Uint32}, | ||
{"extrapolation_value", ProgramUniformVariableDataType::Float32}); | ||
|
||
private: | ||
// Values captured at construction; none are modified by this header. | ||
onnxruntime::ResizeCoordinateTransformationMode coordinate_transform_mode_; | ||
bool extrapolation_enabled_; | ||
int32_t rank_; | ||
}; | ||
|
||
// Program type for the trilinear Resize path (going by its name; the shader | ||
// body produced by GenerateShaderCode is defined outside this header). | ||
// The constructor records the coordinate transformation mode, extrapolation | ||
// flag and rank, and passes the name "ResizeTrilinear" to the Program base. | ||
class ResizeTrilinearProgram final : public Program<ResizeTrilinearProgram> { | ||
public: | ||
ResizeTrilinearProgram(onnxruntime::ResizeCoordinateTransformationMode coordinate_transform_mode, | ||
bool extrapolation_enabled, | ||
int32_t rank) : Program{"ResizeTrilinear"}, | ||
coordinate_transform_mode_{coordinate_transform_mode}, | ||
extrapolation_enabled_{extrapolation_enabled}, | ||
rank_{rank} {} | ||
|
||
// Declared here, implemented elsewhere. | ||
Status GenerateShaderCode(ShaderHelper& sh) const override; | ||
|
||
// Uniforms declared for this program: roi, scales and extrapolation_value | ||
// as Float32; output_size as Uint32. | ||
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"roi", ProgramUniformVariableDataType::Float32}, | ||
{"scales", ProgramUniformVariableDataType::Float32}, | ||
{"output_size", ProgramUniformVariableDataType::Uint32}, | ||
{"extrapolation_value", ProgramUniformVariableDataType::Float32}); | ||
|
||
private: | ||
// Values captured at construction; none are modified by this header. | ||
onnxruntime::ResizeCoordinateTransformationMode coordinate_transform_mode_; | ||
bool extrapolation_enabled_; | ||
int32_t rank_; | ||
}; | ||
|
||
// Program type for the bicubic Resize path (going by its name; the shader | ||
// body produced by GenerateShaderCode is defined outside this header). | ||
// The constructor records the coordinate transformation mode, extrapolation | ||
// flag, exclude_outside flag and rank, and passes the name "ResizeBiCubic" | ||
// to the Program base. | ||
class ResizeBiCubicProgram final : public Program<ResizeBiCubicProgram> { | ||
public: | ||
ResizeBiCubicProgram(onnxruntime::ResizeCoordinateTransformationMode coordinate_transform_mode, | ||
bool extrapolation_enabled, | ||
bool exclude_outside, | ||
int32_t rank) : Program{"ResizeBiCubic"}, | ||
coordinate_transform_mode_{coordinate_transform_mode}, | ||
extrapolation_enabled_{extrapolation_enabled}, | ||
exclude_outside_{exclude_outside}, | ||
rank_{rank} {} | ||
|
||
// Declared here, implemented elsewhere. | ||
Status GenerateShaderCode(ShaderHelper& sh) const override; | ||
|
||
// Uniforms declared for this program: roi, scales, extrapolation_value and | ||
// cubic_coeff_a as Float32; output_size as Uint32. | ||
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"roi", ProgramUniformVariableDataType::Float32}, | ||
{"scales", ProgramUniformVariableDataType::Float32}, | ||
{"output_size", ProgramUniformVariableDataType::Uint32}, | ||
{"extrapolation_value", ProgramUniformVariableDataType::Float32}, | ||
{"cubic_coeff_a", ProgramUniformVariableDataType::Float32}); | ||
|
||
private: | ||
// Values captured at construction; none are modified by this header. | ||
onnxruntime::ResizeCoordinateTransformationMode coordinate_transform_mode_; | ||
bool extrapolation_enabled_; | ||
bool exclude_outside_; | ||
int32_t rank_; | ||
}; | ||
|
||
// Declaration of the Resize implementation entry point; the definition is | ||
// not in this header. Takes the compute context, the input tensor, the | ||
// upsample mode, the output dimensions, the roi and scales values, the | ||
// extrapolation settings, the cubic coefficient, the exclude_outside flag, | ||
// and the coordinate-transformation and nearest modes. Returns a Status. | ||
// NOTE(review): output_dims is a non-const reference to a span, so callers | ||
// must pass an lvalue span; whether the callee reseats it is not visible | ||
// here - confirm against the definition. | ||
Status ResizeImpl( | ||
ComputeContext& context, | ||
const Tensor* input, | ||
const onnxruntime::UpsampleMode upsample_mode, | ||
gsl::span<const int64_t>& output_dims, | ||
gsl::span<const float> roi, | ||
gsl::span<const float> scales, | ||
bool extrapolation_enabled, | ||
const float extrapolation_value, | ||
float cubic_coeff_a, | ||
bool exclude_outside, | ||
onnxruntime::ResizeCoordinateTransformationMode coordinate_transform_mode, | ||
onnxruntime::ResizeNearestMode nearest_mode); | ||
|
||
} // namespace webgpu | ||
} // namespace onnxruntime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#include "core/providers/webgpu/tensor/resize_impl.h" | ||
#include "core/providers/webgpu/shader_helper.h" | ||
#include "core/providers/webgpu/tensor/upsample.h" | ||
#include "core/providers/webgpu/webgpu_supported_types.h" | ||
|
||
using namespace onnxruntime::common; | ||
Check warning on line 9 in onnxruntime/core/providers/webgpu/tensor/upsample.cc
|
||
|
||
namespace onnxruntime { | ||
namespace webgpu { | ||
|
||
// Validates the resolved roi / scales / output dimensions against input 0,
// allocates output 0, and dispatches to ResizeImpl.
//
// Returns INVALID_ARGUMENT when the input is a scalar, when the number of
// scales differs from the input rank, or when roi does not hold 2 * rank
// values. Returns OK without dispatching when the output has zero elements.
// Returns NOT_IMPLEMENTED for the non-Resize (Upsample) path and for Resize
// with antialias enabled.
Status Upsample::BaseCompute(ComputeContext& context,
                             gsl::span<const float> roi,
                             gsl::span<const float> scales,
                             gsl::span<const int64_t> output_dims) const {
  const auto* input_tensor = context.Input(0);
  auto dims = input_tensor->Shape().GetDims();
  ORT_ENFORCE(output_dims.size() == dims.size(), "Rank of input and output tensor should be same.");

  const auto rank = dims.size();

  if (rank == 0) {
    return Status(ONNXRUNTIME, INVALID_ARGUMENT,
                  is_resize_ ? "Resize: input tensor cannot be scalar."
                             : "Upsample: input tensor cannot be scalar.");
  }

  if (scales.size() != rank) {
    return Status(ONNXRUNTIME, INVALID_ARGUMENT,
                  is_resize_ ? "Resize: input tensor's dimension does not match the scales."
                             : "Upsample: input tensor's dimension does not match the scales.");
  }

  if (roi.size() != 2 * rank) {
    return Status(ONNXRUNTIME, INVALID_ARGUMENT,
                  "Resize: size of roi array should be 2 * N where N is the rank of input tensor X.");
  }

  // The output is requested before the emptiness check so that an empty
  // result still has its output tensor created.
  Tensor* output_tensor = context.Output(0, output_dims);
  if (output_tensor->Shape().Size() == 0) {
    return Status::OK();
  }

  // Only the Resize flavour has a WebGPU implementation.
  if (!is_resize_) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Upsample operator is NOT implemented.");
  }

  if (antialias_) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
                           "The antialias attribute of Resize operator is NOT implemented.");
  }

  return ResizeImpl(context, input_tensor, mode_, output_dims, roi, scales, use_extrapolation_,
                    extrapolation_value_, cubic_coeff_a_, exclude_outside_, coordinate_transform_mode_,
                    nearest_mode_);
}
|
||
// Resolves the roi, scales and output dimensions for this invocation and
// hands them to BaseCompute.
//
// roi starts as a copy of the stored roi_ member; unless it is cached it is
// either parsed from the roi input or reset to the full range [0, 1] per axis.
// scales / output dims come from, in order of precedence: the stored scales_
// member when the node has a single input, the cached scales_, the scales
// input, or the sizes input.
Status Upsample::ComputeInternal(ComputeContext& context) const {
  const auto* input_tensor = context.Input(0);
  auto input_dims = input_tensor->Shape().GetDims();
  const size_t input_rank = input_dims.size();
  TensorShapeVector output_dims(input_rank);

  // Begin from the stored roi; it is replaced below when it was not cached.
  InlinedVector<float> roi_array(roi_);

  if (!roi_cached_) {
    const Tensor* roi_tensor = nullptr;
    if (need_roi_input_) {
      ORT_ENFORCE(roi_input_idx_ > 0, "Invalid roi input index.");
      roi_tensor = context.Input(roi_input_idx_);
    }

    if (roi_tensor != nullptr) {
      ParseRoiData(roi_tensor, roi_array);
    } else {
      // No roi supplied: cover every axis completely, i.e. starts are 0 in
      // the first half of the array and ends are 1 in the second half.
      roi_array.resize(input_rank * 2);
      for (size_t axis = 0; axis < input_rank; ++axis) {
        roi_array[axis] = 0;
        roi_array[axis + input_rank] = 1;
      }
    }
  }

  ComputeROIWithAxes(roi_array, input_rank);

  InlinedVector<float> scales_array(input_rank);

  // Node wired with a single input: scales are taken from the stored member.
  if (OpKernel::Node().InputDefs().size() == 1) {
    scales_array = scales_;
    ComputeOutputShape(scales_array, input_dims, output_dims);
    return BaseCompute(context, roi_array, scales_array, output_dims);
  }

  const auto* scales = context.Input(scales_input_idx_);
  const auto* sizes = context.Input(sizes_input_idx_);

  // Scales were cached earlier, so the sizes input must be absent.
  if (scales_cached_) {
    ORT_RETURN_IF_NOT(sizes == nullptr, "Only one of scales or sizes must be provided as input.");
    scales_array = scales_;
    ComputeOutputShape(scales_array, input_dims, output_dims);
    return BaseCompute(context, roi_array, scales_array, output_dims);
  }

  const bool has_scales_input = scales != nullptr && scales->Shape().Size() != 0;
  if (!has_scales_input) {
    // Fall back to the sizes input: it fills output_dims directly, and the
    // scales are then derived from the input and output dimensions.
    ORT_ENFORCE(sizes != nullptr && sizes->Shape().Size() != 0,
                "Either scales or sizes MUST be provided as input.");
    ORT_RETURN_IF_ERROR(ParseSizesData(sizes, output_dims, input_dims));
    ORT_RETURN_IF_ERROR(ParseScalesDataAndAdjustOutputSize(output_dims, input_dims, scales_array));
  } else {
    // Scales input is present and non-empty; sizes must not be given as well.
    ORT_ENFORCE(sizes == nullptr, "Only one of scales or sizes must be provided as input.");
    ORT_RETURN_IF_ERROR(ParseScalesData(scales, scales_array, input_dims.size()));
    ComputeOutputShape(scales_array, input_dims, output_dims);
  }

  return BaseCompute(context, roi_array, scales_array, output_dims);
}
|
||
} // namespace webgpu | ||
} // namespace onnxruntime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
#include "core/providers/webgpu/webgpu_kernel.h" | ||
#include "core/providers/webgpu/program.h" | ||
#include "core/providers/cpu/tensor/upsample.h" | ||
|
||
namespace onnxruntime { | ||
namespace webgpu { | ||
|
||
class Upsample : public UpsampleBase, public WebGpuKernel { | ||
public: | ||
explicit Upsample(const OpKernelInfo& info) : UpsampleBase(info), WebGpuKernel(info) {}; | ||
Check warning on line 15 in onnxruntime/core/providers/webgpu/tensor/upsample.h
|
||
|
||
Status ComputeInternal(ComputeContext& context) const override; | ||
Status BaseCompute(ComputeContext& context, gsl::span<const float> roi, gsl::span<const float> scales, | ||
gsl::span<const int64_t> output_dims) const; | ||
}; | ||
|
||
} // namespace webgpu | ||
} // namespace onnxruntime |
Oops, something went wrong.