-
Notifications
You must be signed in to change notification settings - Fork 3.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
NHWC DepthToSpace U8 and its transformation #23784
base: main
Are you sure you want to change the base?
Changes from all commits
2a4a88b
91c1139
35029a2
e9acabe
cd572f7
9652b1e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -3683,6 +3683,82 @@ | |||||||||||||||||||||
} | ||||||||||||||||||||||
}); | ||||||||||||||||||||||
|
||||||||||||||||||||||
static const char* DepthToSpace_ver1_doc = R"DOC( | ||||||||||||||||||||||
It is similar to DepthToSpace (https://github.com/onnx/onnx/blob/main/docs/Operators.md#DepthToSpace) with differences: | ||||||||||||||||||||||
1. It has additional attribute channels_last. | ||||||||||||||||||||||
2. Input and output data type is uint8. | ||||||||||||||||||||||
)DOC"; | ||||||||||||||||||||||
|
||||||||||||||||||||||
ONNX_CONTRIB_OPERATOR_SCHEMA(DepthToSpace) | ||||||||||||||||||||||
.SetDomain(kMSDomain) | ||||||||||||||||||||||
.SinceVersion(1) | ||||||||||||||||||||||
.SetDoc(DepthToSpace_ver1_doc) | ||||||||||||||||||||||
.Attr("blocksize", "Blocks of [blocksize, blocksize] are moved.", AttributeProto::INT) | ||||||||||||||||||||||
.Attr( | ||||||||||||||||||||||
"channels_last", | ||||||||||||||||||||||
"1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 0.", | ||||||||||||||||||||||
AttributeProto::INT, | ||||||||||||||||||||||
static_cast<int64_t>(0)) | ||||||||||||||||||||||
.Attr( | ||||||||||||||||||||||
"mode", | ||||||||||||||||||||||
"DCR (default) for depth-column-row order re-arrangement. Use CRD for column-row-depth order.", | ||||||||||||||||||||||
AttributeProto::STRING, | ||||||||||||||||||||||
std::string("DCR")) | ||||||||||||||||||||||
Check warning on line 3706 in onnxruntime/core/graph/contrib_ops/contrib_defs.cc
|
||||||||||||||||||||||
.Input( | ||||||||||||||||||||||
0, | ||||||||||||||||||||||
"input", | ||||||||||||||||||||||
"Input data tensor. Dimensions are [N,H,W,C] when channels_last is 1 or [N,C,H,W] otherwise, where N is the" | ||||||||||||||||||||||
"batch axis, C is the channel or depth, H is the height and W is the width.", | ||||||||||||||||||||||
"T", | ||||||||||||||||||||||
OpSchema::Single, | ||||||||||||||||||||||
true, | ||||||||||||||||||||||
1, | ||||||||||||||||||||||
OpSchema::Differentiable) | ||||||||||||||||||||||
.Output( | ||||||||||||||||||||||
0, | ||||||||||||||||||||||
"output", | ||||||||||||||||||||||
"Output data tensor. Dimensions are [N, H * blocksize, W * blocksize, C/(blocksize * blocksize)] when" | ||||||||||||||||||||||
"channels_last is 1 or [N, C/(blocksize * blocksize), H * blocksize, W * blocksize] otherwise.", | ||||||||||||||||||||||
"T", | ||||||||||||||||||||||
OpSchema::Single, | ||||||||||||||||||||||
true, | ||||||||||||||||||||||
1, | ||||||||||||||||||||||
OpSchema::Differentiable) | ||||||||||||||||||||||
.TypeConstraint("T", {"tensor(uint8)"}, "") | ||||||||||||||||||||||
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) { | ||||||||||||||||||||||
propagateElemTypeFromInputToOutput(ctx, 0, 0); | ||||||||||||||||||||||
auto blocksize = getAttribute(ctx, "blocksize", 0); | ||||||||||||||||||||||
if (blocksize <= 0) { | ||||||||||||||||||||||
fail_shape_inference("Blocksize must be positive"); | ||||||||||||||||||||||
} | ||||||||||||||||||||||
if (hasInputShape(ctx, 0)) { | ||||||||||||||||||||||
auto& input_shape = getInputShape(ctx, 0); | ||||||||||||||||||||||
if (input_shape.dim_size() == 4) { | ||||||||||||||||||||||
// TODO: Clarify what behavior should be if C is not a multiple of | ||||||||||||||||||||||
Check warning on line 3737 in onnxruntime/core/graph/contrib_ops/contrib_defs.cc
|
||||||||||||||||||||||
// blocksize*blocksize. | ||||||||||||||||||||||
if (getAttribute(ctx, "channels_last", 0) == 0) { | ||||||||||||||||||||||
updateOutputShape( | ||||||||||||||||||||||
ctx, | ||||||||||||||||||||||
0, | ||||||||||||||||||||||
{input_shape.dim(0), | ||||||||||||||||||||||
input_shape.dim(1) / (blocksize * blocksize), | ||||||||||||||||||||||
input_shape.dim(2) * blocksize, | ||||||||||||||||||||||
input_shape.dim(3) * blocksize}); | ||||||||||||||||||||||
} else { // channels_last | ||||||||||||||||||||||
updateOutputShape( | ||||||||||||||||||||||
ctx, | ||||||||||||||||||||||
0, | ||||||||||||||||||||||
{input_shape.dim(0), | ||||||||||||||||||||||
input_shape.dim(1) * blocksize, | ||||||||||||||||||||||
input_shape.dim(2) * blocksize, | ||||||||||||||||||||||
input_shape.dim(3) / (blocksize * blocksize)}); | ||||||||||||||||||||||
} | ||||||||||||||||||||||
Comment on lines
+3751
to
+3755
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||
} else { | ||||||||||||||||||||||
fail_shape_inference("Input tensor must be 4-dimensional"); | ||||||||||||||||||||||
} | ||||||||||||||||||||||
} | ||||||||||||||||||||||
}); | ||||||||||||||||||||||
|
||||||||||||||||||||||
#ifdef ENABLE_ATEN | ||||||||||||||||||||||
ONNX_CONTRIB_OPERATOR_SCHEMA(ATen) | ||||||||||||||||||||||
.SetDomain(kPytorchAtenDomain) | ||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -7,6 +7,7 @@ | |||||||||||||||||||||||
#endif | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
#include "core/providers/cpu/tensor/space_depth_ops.h" | ||||||||||||||||||||||||
#include "core/providers/cpu/tensor/transpose.h" | ||||||||||||||||||||||||
#include "core/common/eigen_common_wrapper.h" | ||||||||||||||||||||||||
#include <array> | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
@@ -56,6 +57,18 @@ | |||||||||||||||||||||||
DataTypeImpl::GetTensorType<uint8_t>()}), | ||||||||||||||||||||||||
DepthToSpace); | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
namespace contrib { | ||||||||||||||||||||||||
ONNX_OPERATOR_TYPED_KERNEL_EX( | ||||||||||||||||||||||||
DepthToSpace, | ||||||||||||||||||||||||
kMSDomain, | ||||||||||||||||||||||||
1, | ||||||||||||||||||||||||
uint8_t, | ||||||||||||||||||||||||
kCpuExecutionProvider, | ||||||||||||||||||||||||
KernelDefBuilder() | ||||||||||||||||||||||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<uint8_t>()), | ||||||||||||||||||||||||
DepthToSpace); | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
Check warning on line 70 in onnxruntime/core/providers/cpu/tensor/space_depth_ops.cc
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
// intermediate tensor shapes are: | ||||||||||||||||||||||||
// (batch, blocksize, blocksize, input_depth / (blocksize * blocksize), input_height, input_width) for DepthToSpace | ||||||||||||||||||||||||
// (batch, input_depth, input_height / blocksize, blocksize, input_width / blocksize, blocksize) for SpaceToDepth | ||||||||||||||||||||||||
|
@@ -157,6 +170,47 @@ | |||||||||||||||||||||||
int64_t output_height = -1; | ||||||||||||||||||||||||
int64_t output_width = -1; | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
if (is_nhwc_) { | ||||||||||||||||||||||||
ORT_RETURN_IF_ERROR(InputValidationsAndOutputDimsCalc<true>(input, | ||||||||||||||||||||||||
batch, | ||||||||||||||||||||||||
input_depth, input_height, input_width, | ||||||||||||||||||||||||
output_depth, output_height, output_width, | ||||||||||||||||||||||||
false)); | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
Comment on lines
+174
to
+179
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||
Tensor& output = *context->Output(0, {batch, output_height, output_width, output_depth}); | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
int64_t virtual_input_depth = input_depth / blocksize_ / blocksize_; | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
TensorShape virtual_input_shape; | ||||||||||||||||||||||||
if (is_dcr_) { | ||||||||||||||||||||||||
virtual_input_shape = TensorShape{batch, input_height, input_width, | ||||||||||||||||||||||||
blocksize_, blocksize_, virtual_input_depth}; | ||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||
virtual_input_shape = TensorShape{batch, input_height, input_width, | ||||||||||||||||||||||||
virtual_input_depth, blocksize_, blocksize_}; | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
TensorShape virtual_output_shape = TensorShape{batch, | ||||||||||||||||||||||||
input_height, blocksize_, | ||||||||||||||||||||||||
input_width, blocksize_, | ||||||||||||||||||||||||
virtual_input_depth}; | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
std::vector<size_t> permutation = is_dcr_ ? std::vector<size_t>{0, 1, 3, 2, 4, 5} | ||||||||||||||||||||||||
: std::vector<size_t>{0, 1, 4, 2, 5, 3}; | ||||||||||||||||||||||||
Check warning on line 199 in onnxruntime/core/providers/cpu/tensor/space_depth_ops.cc
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
if (input.IsDataType<uint8_t>()) { | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
return Transpose::DoTranspose( | ||||||||||||||||||||||||
permutation, input, output, &virtual_input_shape, &virtual_output_shape, context->GetOperatorThreadPool()); | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
Comment on lines
+201
to
+205
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||
// user will not see this as the kernel doesn't claim support for types other than float and double | ||||||||||||||||||||||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported input type in DepthToSpace (channels_last = 1) op: ", input.DataType()); | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
return Status::OK(); | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
ORT_RETURN_IF_ERROR(InputValidationsAndOutputDimsCalc(input, | ||||||||||||||||||||||||
batch, | ||||||||||||||||||||||||
input_depth, input_height, input_width, | ||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.