Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NHWC DepthToSpace U8 and its transformation #23784

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cpu/bert/attention_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ inline Status Transpose_BSNH_to_BNSH(const Tensor* qkv,
std::vector<size_t> permutations({0, 2, 1, 3});
gsl::span<const size_t> permutations_span{permutations};
size_t from = 2, to = 1;
SingleAxisTranspose(permutations, *qkv, *qkv_transposed.GetMutable<Tensor>(), from, to, nullptr, tp);
SingleAxisTranspose(permutations, *qkv, *qkv_transposed.GetMutable<Tensor>(), from, to, nullptr, nullptr, tp);
return Status::OK();
}

Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Quick
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention);

// ******** Start: Quantization ******************* //
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DepthToSpace);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulInteger16);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearGlobalAveragePool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearConcat);
Expand Down Expand Up @@ -216,6 +217,7 @@ Status RegisterFp16Kernels(KernelRegistry& kernel_registry) {
Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DepthToSpace)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulInteger16)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearGlobalAveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearConcat)>,
Expand Down
16 changes: 11 additions & 5 deletions onnxruntime/core/framework/transpose_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,15 @@ typename std::enable_if<has_mlas_transpose<T>::value, void>::type SimpleTranspos

// `input_shape_override` overrides the shape of `input` for compute purposes.
void TransposeSingleAxisOutwards(gsl::span<const size_t> permutations, const Tensor& input, Tensor& output,
size_t from, size_t to, const TensorShape* input_shape_override = nullptr,
size_t from, size_t to,
const TensorShape* input_shape_override = nullptr,
const TensorShape* output_shape_override = nullptr,
concurrency::ThreadPool* tp = nullptr) {
ORT_UNUSED_PARAMETER(permutations);

const auto& input_shape = input_shape_override ? *input_shape_override : input.Shape();
const auto& input_dims = input_shape.GetDims();
const auto& output_shape = output_shape_override ? *output_shape_override : output.Shape();

const auto element_size = input.DataType()->Size();

Expand Down Expand Up @@ -106,7 +109,7 @@ void TransposeSingleAxisOutwards(gsl::span<const size_t> permutations, const Ten
default: {
TensorPitches src_strides(input_dims);

TensorPitches contig_dst_strides(output);
TensorPitches contig_dst_strides(output_shape);

const auto dims = input_dims.size();
TensorShapeVector dst_strides(dims);
Expand Down Expand Up @@ -231,10 +234,13 @@ void TransposeSingleAxisInwards(gsl::span<const size_t> permutations, const Tens
}

// `input_shape_override` overrides the shape of `input` for compute purposes.
void SingleAxisTranspose(gsl::span<const size_t> permutations, const Tensor& input, Tensor& output, size_t from,
size_t to, const TensorShape* input_shape_override, concurrency::ThreadPool* tp) {
void SingleAxisTranspose(gsl::span<const size_t> permutations, const Tensor& input, Tensor& output,
size_t from, size_t to,
const TensorShape* input_shape_override, const TensorShape* output_shape_override,
concurrency::ThreadPool* tp) {
if (from > to) {
TransposeSingleAxisOutwards(permutations, input, output, from, to, input_shape_override, tp);
TransposeSingleAxisOutwards(permutations, input, output, from, to,
input_shape_override, output_shape_override, tp);
} else {
TransposeSingleAxisInwards(permutations, input, output, from, to, input_shape_override);
}
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/core/framework/transpose_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ We fall back to the default implementation in all other cases, and if the input

namespace onnxruntime {
bool IsTransposeMovingSingleAxis(gsl::span<const size_t> permutations, size_t& from, size_t& to);
void SingleAxisTranspose(gsl::span<const size_t> permutations, const Tensor& input, Tensor& output, size_t from,
size_t to, const TensorShape* input_shape_override = nullptr,
void SingleAxisTranspose(gsl::span<const size_t> permutations, const Tensor& input, Tensor& output,
size_t from, size_t to,
const TensorShape* input_shape_override = nullptr,
const TensorShape* output_shape_override = nullptr,
concurrency::ThreadPool* tp = nullptr);
} // namespace onnxruntime
76 changes: 76 additions & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3683,6 +3683,82 @@
}
});

static const char* DepthToSpace_ver1_doc = R"DOC(
It is similar to DepthToSpace (https://github.com/onnx/onnx/blob/main/docs/Operators.md#DepthToSpace) with differences:
1. It has additional attribute channels_last.
2. Input and output data type is uint8.
)DOC";

// com.microsoft.DepthToSpace: uint8 variant of the ONNX DepthToSpace op with an
// extra `channels_last` attribute so it can operate directly on NHWC tensors.
ONNX_CONTRIB_OPERATOR_SCHEMA(DepthToSpace)
    .SetDomain(kMSDomain)
    .SinceVersion(1)
    .SetDoc(DepthToSpace_ver1_doc)
    .Attr("blocksize", "Blocks of [blocksize, blocksize] are moved.", AttributeProto::INT)
    .Attr(
        "channels_last",
        "1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 0.",
        AttributeProto::INT,
        static_cast<int64_t>(0))
    .Attr(
        "mode",
        "DCR (default) for depth-column-row order re-arrangement. Use CRD for column-row-depth order.",
        AttributeProto::STRING,
        std::string("DCR"))
    .Input(
        0,
        "input",
        // NOTE: trailing space before the literal break is required — adjacent C++
        // string literals concatenate with no separator.
        "Input data tensor. Dimensions are [N,H,W,C] when channels_last is 1 or [N,C,H,W] otherwise, where N is the "
        "batch axis, C is the channel or depth, H is the height and W is the width.",
        "T",
        OpSchema::Single,
        true,
        1,
        OpSchema::Differentiable)
    .Output(
        0,
        "output",
        "Output data tensor. Dimensions are [N, H * blocksize, W * blocksize, C/(blocksize * blocksize)] when "
        "channels_last is 1 or [N, C/(blocksize * blocksize), H * blocksize, W * blocksize] otherwise.",
        "T",
        OpSchema::Single,
        true,
        1,
        OpSchema::Differentiable)
    .TypeConstraint("T", {"tensor(uint8)"}, "")
    .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
      propagateElemTypeFromInputToOutput(ctx, 0, 0);
      auto blocksize = getAttribute(ctx, "blocksize", 0);
      if (blocksize <= 0) {
        fail_shape_inference("Blocksize must be positive");
      }
      if (hasInputShape(ctx, 0)) {
        auto& input_shape = getInputShape(ctx, 0);
        if (input_shape.dim_size() == 4) {
          // TODO(onnxruntime): Clarify what behavior should be if C is not a
          // multiple of blocksize * blocksize.
          if (getAttribute(ctx, "channels_last", 0) == 0) {
            // NCHW: [N, C, H, W] -> [N, C/(b*b), H*b, W*b]
            updateOutputShape(
                ctx,
                0,
                {input_shape.dim(0),
                 input_shape.dim(1) / (blocksize * blocksize),
                 input_shape.dim(2) * blocksize,
                 input_shape.dim(3) * blocksize});
          } else {  // channels_last; NHWC: [N, H, W, C] -> [N, H*b, W*b, C/(b*b)]
            updateOutputShape(
                ctx,
                0,
                {input_shape.dim(0),
                 input_shape.dim(1) * blocksize,
                 input_shape.dim(2) * blocksize,
                 input_shape.dim(3) / (blocksize * blocksize)});
          }
        } else {
          fail_shape_inference("Input tensor must be 4-dimensional");
        }
      }
    });

#ifdef ENABLE_ATEN
ONNX_CONTRIB_OPERATOR_SCHEMA(ATen)
.SetDomain(kPytorchAtenDomain)
Expand Down
16 changes: 16 additions & 0 deletions onnxruntime/core/optimizer/nhwc_transformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,22 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator,
}
}

{
// uint8 DepthToSpace -> uint8 nhwc DepthToSpace
// Registers a layout conversion: if the CPU EP has the com.microsoft (kMSDomain)
// uint8 DepthToSpace kernel available, ONNX-domain uint8 DepthToSpace nodes are
// mapped to it via conv_table_ so the NHWC transformer can rewrite them.
OpKernelRegistryId depthtospace_uint8{
"DepthToSpace", kMSDomain, 1, {{"T", {DataTypeImpl::GetTensorType<uint8_t>()}}}};
const KernelCreateInfo* kernel_create_info{};
// Probe the registry; only install the conversion when the target kernel exists.
const auto status = cpu_kernel_registry->TryFindKernel(
kCpuExecutionProvider, depthtospace_uint8.op_type_, depthtospace_uint8.domain_,
depthtospace_uint8.version_, depthtospace_uint8.type_constraints_, logger, &kernel_create_info);
if (status.IsOK() && kernel_create_info != nullptr) {
// Pointer was only needed for the existence check; clear it immediately.
kernel_create_info = nullptr;
// Key is the source op in the ONNX domain; value is the kMSDomain replacement.
// NOTE(review): the trailing `true` presumably marks the op as layout-sensitive
// (NHWC conversion required) — confirm against OpTransformInfo's definition.
conv_table_.emplace(
OpIdInfo("DepthToSpace", kOnnxDomain, api::DataType::UINT8),
OpTransformInfo{depthtospace_uint8.op_type_, depthtospace_uint8.domain_, depthtospace_uint8.version_, true});
}
}

{
// fp16 MaxPool -> fp16 nhwc MaxPool
OpKernelRegistryId nhwc_maxpool_fp16{
Expand Down
54 changes: 54 additions & 0 deletions onnxruntime/core/providers/cpu/tensor/space_depth_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#endif

#include "core/providers/cpu/tensor/space_depth_ops.h"
#include "core/providers/cpu/tensor/transpose.h"
#include "core/common/eigen_common_wrapper.h"
#include <array>

Expand Down Expand Up @@ -56,6 +57,18 @@
DataTypeImpl::GetTensorType<uint8_t>()}),
DepthToSpace);

namespace contrib {
// CPU kernel registration for the com.microsoft uint8 DepthToSpace
// (the channels_last-capable contrib variant of the ONNX op).
ONNX_OPERATOR_TYPED_KERNEL_EX(
    DepthToSpace,
    kMSDomain,
    1,
    uint8_t,
    kCpuExecutionProvider,
    KernelDefBuilder()
        .TypeConstraint("T", DataTypeImpl::GetTensorType<uint8_t>()),
    DepthToSpace);
}  // namespace contrib

Check warning on line 70 in onnxruntime/core/providers/cpu/tensor/space_depth_ops.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Namespace should be terminated with "// namespace contrib" [readability/namespace] [5] Raw Output: onnxruntime/core/providers/cpu/tensor/space_depth_ops.cc:70: Namespace should be terminated with "// namespace contrib" [readability/namespace] [5]

// intermediate tensor shapes are:
// (batch, blocksize, blocksize, input_depth / (blocksize * blocksize), input_height, input_width) for DepthToSpace
// (batch, input_depth, input_height / blocksize, blocksize, input_width / blocksize, blocksize) for SpaceToDepth
Expand Down Expand Up @@ -157,6 +170,47 @@
int64_t output_height = -1;
int64_t output_width = -1;

if (is_nhwc_) {
ORT_RETURN_IF_ERROR(InputValidationsAndOutputDimsCalc<true>(input,
batch,
input_depth, input_height, input_width,
output_depth, output_height, output_width,
false));

Comment on lines +174 to +179
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ORT_RETURN_IF_ERROR(InputValidationsAndOutputDimsCalc<true>(input,
batch,
input_depth, input_height, input_width,
output_depth, output_height, output_width,
false));
ORT_RETURN_IF_ERROR(InputValidationsAndOutputDimsCalc<true>(input,
batch,
input_depth, input_height, input_width,
output_depth, output_height, output_width,
false));

Tensor& output = *context->Output(0, {batch, output_height, output_width, output_depth});

int64_t virtual_input_depth = input_depth / blocksize_ / blocksize_;

TensorShape virtual_input_shape;
if (is_dcr_) {
virtual_input_shape = TensorShape{batch, input_height, input_width,
blocksize_, blocksize_, virtual_input_depth};
} else {
virtual_input_shape = TensorShape{batch, input_height, input_width,
virtual_input_depth, blocksize_, blocksize_};
}

TensorShape virtual_output_shape = TensorShape{batch,
input_height, blocksize_,
input_width, blocksize_,
virtual_input_depth};

std::vector<size_t> permutation = is_dcr_ ? std::vector<size_t>{0, 1, 3, 2, 4, 5}
: std::vector<size_t>{0, 1, 4, 2, 5, 3};

Check warning on line 199 in onnxruntime/core/providers/cpu/tensor/space_depth_ops.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <vector> for vector<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/cpu/tensor/space_depth_ops.cc:199: Add #include <vector> for vector<> [build/include_what_you_use] [4]

if (input.IsDataType<uint8_t>()) {

return Transpose::DoTranspose(
permutation, input, output, &virtual_input_shape, &virtual_output_shape, context->GetOperatorThreadPool());

Comment on lines +201 to +205
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (input.IsDataType<uint8_t>()) {
return Transpose::DoTranspose(
permutation, input, output, &virtual_input_shape, &virtual_output_shape, context->GetOperatorThreadPool());
if (input.IsDataType<uint8_t>()) {
return Transpose::DoTranspose(
permutation, input, output, &virtual_input_shape, &virtual_output_shape, context->GetOperatorThreadPool());

} else {
// user will not see this as the kernel doesn't claim support for types other than float and double
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported input type in DepthToSpace (channels_last = 1) op: ", input.DataType());
}

return Status::OK();
}

ORT_RETURN_IF_ERROR(InputValidationsAndOutputDimsCalc(input,
batch,
input_depth, input_height, input_width,
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/cpu/tensor/space_depth_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class SpaceToDepth final : public OpKernel, SpaceDepthBase {
class DepthToSpace final : public OpKernel, SpaceDepthBase {
public:
explicit DepthToSpace(const OpKernelInfo& info) : OpKernel(info), SpaceDepthBase(info) {
is_nhwc_ = (info.GetAttrOrDefault<int64_t>("channels_last", static_cast<int64_t>(0)) != 0);
std::string mode;
// if mode doesn't exist, then it is the default "DCR" mode
// (or) it is an opset < 11 model for which the only mode is "DCR" mode
Expand All @@ -95,6 +96,7 @@ class DepthToSpace final : public OpKernel, SpaceDepthBase {

private:
bool is_dcr_ = true;
bool is_nhwc_ = false;
};

} // namespace onnxruntime
Loading
Loading