diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc b/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc
index c8fe9c77d8ff8..a8b0bb0193240 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc
@@ -27,7 +27,7 @@ inline Status Transpose_BSNH_to_BNSH(const Tensor* qkv,
   std::vector<size_t> permutations({0, 2, 1, 3});
   gsl::span<const size_t> permutations_span{permutations};
   size_t from = 2, to = 1;
-  SingleAxisTranspose(permutations, *qkv, *qkv_transposed.GetMutable<Tensor>(), from, to, nullptr, tp);
+  SingleAxisTranspose(permutations, *qkv, *qkv_transposed.GetMutable<Tensor>(), from, to, nullptr, nullptr, tp);
 
   return Status::OK();
 }
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index c742cd1e95bdd..c38582b36239f 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -64,6 +64,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Quick
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention);
 
 // ******** Start: Quantization ******************* //
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DepthToSpace);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulInteger16);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearGlobalAveragePool);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearConcat);
@@ -216,6 +217,7 @@ Status RegisterFp16Kernels(KernelRegistry& kernel_registry) {
 Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) {
   static const BuildKernelCreateInfoFn function_table[] = {
       BuildKernelCreateInfo<void>,  // default entry to avoid the list become empty after ops-reducing
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DepthToSpace)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulInteger16)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearGlobalAveragePool)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearConcat)>,
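Note on the call-site change above: `SingleAxisTranspose` gains an output-shape override ahead of its trailing thread-pool parameter (see the transpose_helper changes below), so callers that pass the thread pool positionally, such as `Transpose_BSNH_to_BNSH`, must thread an extra `nullptr` through. A minimal before/after sketch of the call-site pattern (argument names are illustrative):

    // Before: permutations, tensors, axis move, input-shape override, thread pool.
    SingleAxisTranspose(perm, src, dst, from, to, /*input_shape_override*/ nullptr, tp);
    // After: the new output-shape override slots in before the thread pool;
    // nullptr means "use dst.Shape()", which preserves the old behavior.
    SingleAxisTranspose(perm, src, dst, from, to, /*input_shape_override*/ nullptr,
                        /*output_shape_override*/ nullptr, tp);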
diff --git a/onnxruntime/core/framework/transpose_helper.cc b/onnxruntime/core/framework/transpose_helper.cc
index 32d15bdf9060b..b1a5b85fe84db 100644
--- a/onnxruntime/core/framework/transpose_helper.cc
+++ b/onnxruntime/core/framework/transpose_helper.cc
@@ -59,12 +59,15 @@ typename std::enable_if<!std::is_same<T, std::string>::value, void>::type SimpleTranspos
 
 // `input_shape_override` overrides the shape of `input` for compute purposes.
 void TransposeSingleAxisOutwards(gsl::span<const size_t> permutations, const Tensor& input, Tensor& output,
-                                 size_t from, size_t to, const TensorShape* input_shape_override = nullptr,
+                                 size_t from, size_t to,
+                                 const TensorShape* input_shape_override = nullptr,
+                                 const TensorShape* output_shape_override = nullptr,
                                  concurrency::ThreadPool* tp = nullptr) {
   ORT_UNUSED_PARAMETER(permutations);
 
   const auto& input_shape = input_shape_override ? *input_shape_override : input.Shape();
   const auto& input_dims = input_shape.GetDims();
+  const auto& output_shape = output_shape_override ? *output_shape_override : output.Shape();
 
   const auto element_size = input.DataType()->Size();
@@ -106,7 +109,7 @@ void TransposeSingleAxisOutwards(gsl::span<const size_t> permutations, const Ten
 
     default: {
       TensorPitches src_strides(input_dims);
 
-      TensorPitches contig_dst_strides(output);
+      TensorPitches contig_dst_strides(output_shape);
 
       const auto dims = input_dims.size();
       TensorShapeVector dst_strides(dims);
@@ -231,10 +234,13 @@ void TransposeSingleAxisInwards(gsl::span<const size_t> permutations, const Tens
 }
 
 // `input_shape_override` overrides the shape of `input` for compute purposes.
-void SingleAxisTranspose(gsl::span<const size_t> permutations, const Tensor& input, Tensor& output, size_t from,
-                         size_t to, const TensorShape* input_shape_override, concurrency::ThreadPool* tp) {
+void SingleAxisTranspose(gsl::span<const size_t> permutations, const Tensor& input, Tensor& output,
+                         size_t from, size_t to,
+                         const TensorShape* input_shape_override, const TensorShape* output_shape_override,
+                         concurrency::ThreadPool* tp) {
   if (from > to) {
-    TransposeSingleAxisOutwards(permutations, input, output, from, to, input_shape_override, tp);
+    TransposeSingleAxisOutwards(permutations, input, output, from, to,
+                                input_shape_override, output_shape_override, tp);
   } else {
     TransposeSingleAxisInwards(permutations, input, output, from, to, input_shape_override);
   }
diff --git a/onnxruntime/core/framework/transpose_helper.h b/onnxruntime/core/framework/transpose_helper.h
index e33044117f89a..16f5f8c9aa193 100644
--- a/onnxruntime/core/framework/transpose_helper.h
+++ b/onnxruntime/core/framework/transpose_helper.h
@@ -41,7 +41,9 @@ We fall back to the default implementation in all other cases, and if the input
 namespace onnxruntime {
 bool IsTransposeMovingSingleAxis(gsl::span<const size_t> permutations, size_t& from, size_t& to);
 
-void SingleAxisTranspose(gsl::span<const size_t> permutations, const Tensor& input, Tensor& output, size_t from,
-                         size_t to, const TensorShape* input_shape_override = nullptr,
+void SingleAxisTranspose(gsl::span<const size_t> permutations, const Tensor& input, Tensor& output,
+                         size_t from, size_t to,
+                         const TensorShape* input_shape_override = nullptr,
+                         const TensorShape* output_shape_override = nullptr,
                          concurrency::ThreadPool* tp = nullptr);
 }  // namespace onnxruntime
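The `contig_dst_strides` fix above is the crux of the helper change: destination strides must be derived from the (possibly overridden) compute shape, not from the output tensor itself, because a caller may describe the same contiguous buffer at a different rank; the NHWC DepthToSpace below passes 6-D virtual shapes over 4-D tensors. A self-contained sketch of the row-major stride ("pitch") computation that `TensorPitches` performs, hand-rolled here purely for illustration:

    #include <cstdint>
    #include <vector>

    // Row-major strides for a shape: stride[i] = product of all dims after i.
    std::vector<int64_t> RowMajorStrides(const std::vector<int64_t>& dims) {
      std::vector<int64_t> strides(dims.size(), 1);
      for (size_t i = dims.size(); i > 1; --i) {
        strides[i - 2] = strides[i - 1] * dims[i - 1];
      }
      return strides;
    }

    // For the 6-D virtual output {2, 3, 2, 2, 2, 3} this yields {72, 24, 12, 6, 3, 1};
    // strides from the real 4-D output shape {2, 6, 4, 3} would be {72, 12, 3, 1} --
    // the wrong rank for 6-D indexing, which is why the override must win.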
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index e45787299f3ad..853bd0abae825 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3683,6 +3683,82 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
         }
       });
 
+  static const char* DepthToSpace_ver1_doc = R"DOC(
+It is similar to DepthToSpace (https://github.com/onnx/onnx/blob/main/docs/Operators.md#DepthToSpace) with differences:
+  1. It has an additional attribute channels_last.
+  2. The input and output data type is uint8.
+)DOC";
+
+  ONNX_CONTRIB_OPERATOR_SCHEMA(DepthToSpace)
+      .SetDomain(kMSDomain)
+      .SinceVersion(1)
+      .SetDoc(DepthToSpace_ver1_doc)
+      .Attr("blocksize", "Blocks of [blocksize, blocksize] are moved.", AttributeProto::INT)
+      .Attr(
+          "channels_last",
+          "1 if the input and output are in the NHWC layout, 0 if they are in the NCHW layout. Defaults to 0.",
+          AttributeProto::INT,
+          static_cast<int64_t>(0))
+      .Attr(
+          "mode",
+          "DCR (default) for depth-column-row order re-arrangement. Use CRD for column-row-depth order.",
+          AttributeProto::STRING,
+          std::string("DCR"))
+      .Input(
+          0,
+          "input",
+          "Input data tensor. Dimensions are [N,H,W,C] when channels_last is 1 or [N,C,H,W] otherwise, where N is the "
+          "batch axis, C is the channel or depth, H is the height and W is the width.",
+          "T",
+          OpSchema::Single,
+          true,
+          1,
+          OpSchema::Differentiable)
+      .Output(
+          0,
+          "output",
+          "Output data tensor. Dimensions are [N, H * blocksize, W * blocksize, C/(blocksize * blocksize)] when "
+          "channels_last is 1 or [N, C/(blocksize * blocksize), H * blocksize, W * blocksize] otherwise.",
+          "T",
+          OpSchema::Single,
+          true,
+          1,
+          OpSchema::Differentiable)
+      .TypeConstraint("T", {"tensor(uint8)"}, "")
+      .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
+        propagateElemTypeFromInputToOutput(ctx, 0, 0);
+        auto blocksize = getAttribute(ctx, "blocksize", 0);
+        if (blocksize <= 0) {
+          fail_shape_inference("Blocksize must be positive");
+        }
+        if (hasInputShape(ctx, 0)) {
+          auto& input_shape = getInputShape(ctx, 0);
+          if (input_shape.dim_size() == 4) {
+            // TODO: Clarify what the behavior should be if C is not a multiple of
+            // blocksize * blocksize.
+            if (getAttribute(ctx, "channels_last", 0) == 0) {  // NCHW
+              updateOutputShape(
+                  ctx,
+                  0,
+                  {input_shape.dim(0),
+                   input_shape.dim(1) / (blocksize * blocksize),
+                   input_shape.dim(2) * blocksize,
+                   input_shape.dim(3) * blocksize});
+            } else {  // channels_last (NHWC)
+              updateOutputShape(
+                  ctx,
+                  0,
+                  {input_shape.dim(0),
+                   input_shape.dim(1) * blocksize,
+                   input_shape.dim(2) * blocksize,
+                   input_shape.dim(3) / (blocksize * blocksize)});
+            }
+          } else {
+            fail_shape_inference("Input tensor must be 4-dimensional");
+          }
+        }
+      });
+
 #ifdef ENABLE_ATEN
   ONNX_CONTRIB_OPERATOR_SCHEMA(ATen)
       .SetDomain(kPytorchAtenDomain)
diff --git a/onnxruntime/core/optimizer/nhwc_transformer.cc b/onnxruntime/core/optimizer/nhwc_transformer.cc
index cd654991c92d5..3516c600bac41 100644
--- a/onnxruntime/core/optimizer/nhwc_transformer.cc
+++ b/onnxruntime/core/optimizer/nhwc_transformer.cc
@@ -117,6 +117,22 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator,
     }
   }
 
+  {
+    // uint8 DepthToSpace -> uint8 nhwc DepthToSpace
+    OpKernelRegistryId depthtospace_uint8{
+        "DepthToSpace", kMSDomain, 1, {{"T", {DataTypeImpl::GetTensorType<uint8_t>()}}}};
+    const KernelCreateInfo* kernel_create_info{};
+    const auto status = cpu_kernel_registry->TryFindKernel(
+        kCpuExecutionProvider, depthtospace_uint8.op_type_, depthtospace_uint8.domain_,
+        depthtospace_uint8.version_, depthtospace_uint8.type_constraints_, logger, &kernel_create_info);
+    if (status.IsOK() && kernel_create_info != nullptr) {
+      kernel_create_info = nullptr;
+      conv_table_.emplace(
+          OpIdInfo("DepthToSpace", kOnnxDomain, api::DataType::UINT8),
+          OpTransformInfo{depthtospace_uint8.op_type_, depthtospace_uint8.domain_, depthtospace_uint8.version_, true});
+    }
+  }
+
   {
     // fp16 MaxPool -> fp16 nhwc MaxPool
     OpKernelRegistryId nhwc_maxpool_fp16{
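As a concrete check of the shape inference above: with blocksize b, NCHW (N, C, H, W) maps to (N, C/(b*b), H*b, W*b) and NHWC (N, H, W, C) maps to (N, H*b, W*b, C/(b*b)). A standalone sketch of the same arithmetic (names are illustrative, not part of the PR):

    #include <array>
    #include <cassert>
    #include <cstdint>

    // Mirrors the schema's TypeAndShapeInferenceFunction for static 4-D shapes.
    std::array<int64_t, 4> DepthToSpaceOutShape(const std::array<int64_t, 4>& in,
                                                int64_t b, bool channels_last) {
      if (channels_last) {  // in = {N, H, W, C}
        return {in[0], in[1] * b, in[2] * b, in[3] / (b * b)};
      }
      return {in[0], in[1] / (b * b), in[2] * b, in[3] * b};  // in = {N, C, H, W}
    }

    int main() {
      // Shapes used by the tests below: (2, 12, 3, 2) NCHW, blocksize 2 -> (2, 3, 6, 4),
      // and (2, 3, 2, 12) NHWC, blocksize 2 -> (2, 6, 4, 3).
      assert((DepthToSpaceOutShape({2, 12, 3, 2}, 2, false) == std::array<int64_t, 4>{2, 3, 6, 4}));
      assert((DepthToSpaceOutShape({2, 3, 2, 12}, 2, true) == std::array<int64_t, 4>{2, 6, 4, 3}));
      return 0;
    }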
.TypeConstraint("T", DataTypeImpl::GetTensorType()), + DepthToSpace); +} + // intermediate tensor shapes are: // (batch, blocksize, blocksize, input_depth / (blocksize * blocksize), input_height, input_width) for DepthToSpace // (batch, input_depth, input_height / blocksize, blocksize, input_width / blocksize, blocksize) for SpaceToDepth @@ -157,6 +170,47 @@ Status DepthToSpace::Compute(OpKernelContext* context) const { int64_t output_height = -1; int64_t output_width = -1; + if (is_nhwc_) { + ORT_RETURN_IF_ERROR(InputValidationsAndOutputDimsCalc(input, + batch, + input_depth, input_height, input_width, + output_depth, output_height, output_width, + false)); + + Tensor& output = *context->Output(0, {batch, output_height, output_width, output_depth}); + + int64_t virtual_input_depth = input_depth / blocksize_ / blocksize_; + + TensorShape virtual_input_shape; + if (is_dcr_) { + virtual_input_shape = TensorShape{batch, input_height, input_width, + blocksize_, blocksize_, virtual_input_depth}; + } else { + virtual_input_shape = TensorShape{batch, input_height, input_width, + virtual_input_depth, blocksize_, blocksize_}; + } + + TensorShape virtual_output_shape = TensorShape{batch, + input_height, blocksize_, + input_width, blocksize_, + virtual_input_depth}; + + std::vector permutation = is_dcr_ ? std::vector{0, 1, 3, 2, 4, 5} + : std::vector{0, 1, 4, 2, 5, 3}; + + if (input.IsDataType()) { + + return Transpose::DoTranspose( + permutation, input, output, &virtual_input_shape, &virtual_output_shape, context->GetOperatorThreadPool()); + + } else { + // user will not see this as the kernel doesn't claim support for types other than float and double + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported input type in DepthToSpace (channels_last = 1) op: ", input.DataType()); + } + + return Status::OK(); + } + ORT_RETURN_IF_ERROR(InputValidationsAndOutputDimsCalc(input, batch, input_depth, input_height, input_width, diff --git a/onnxruntime/core/providers/cpu/tensor/space_depth_ops.h b/onnxruntime/core/providers/cpu/tensor/space_depth_ops.h index 3218c8952d6ec..d2676c2cc4891 100644 --- a/onnxruntime/core/providers/cpu/tensor/space_depth_ops.h +++ b/onnxruntime/core/providers/cpu/tensor/space_depth_ops.h @@ -79,6 +79,7 @@ class SpaceToDepth final : public OpKernel, SpaceDepthBase { class DepthToSpace final : public OpKernel, SpaceDepthBase { public: explicit DepthToSpace(const OpKernelInfo& info) : OpKernel(info), SpaceDepthBase(info) { + is_nhwc_ = (info.GetAttrOrDefault("channels_last", static_cast(0)) != 0); std::string mode; // if mode doesn't exist, then it is the default "DCR" mode // (or) it is an opset < 11 model for which the only mode is "DCR" mode @@ -95,6 +96,7 @@ class DepthToSpace final : public OpKernel, SpaceDepthBase { private: bool is_dcr_ = true; + bool is_nhwc_ = false; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/transpose.cc b/onnxruntime/core/providers/cpu/tensor/transpose.cc index 5b904e85848d0..e9427716b0c01 100644 --- a/onnxruntime/core/providers/cpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/cpu/tensor/transpose.cc @@ -261,9 +261,12 @@ static void DoTransposeEltWise(int64_t num_axes, gsl::span target // `input_shape_override` overrides the shape of `input` for compute purposes. 
diff --git a/onnxruntime/core/providers/cpu/tensor/transpose.cc b/onnxruntime/core/providers/cpu/tensor/transpose.cc
index 5b904e85848d0..e9427716b0c01 100644
--- a/onnxruntime/core/providers/cpu/tensor/transpose.cc
+++ b/onnxruntime/core/providers/cpu/tensor/transpose.cc
@@ -261,9 +261,12 @@ static void DoTransposeEltWise(int64_t num_axes, gsl::span<const int64_t> target
 
 // `input_shape_override` overrides the shape of `input` for compute purposes.
 static Status DoUntypedTranspose(const gsl::span<const size_t>& permutations, const Tensor& input, Tensor& output,
-                                 const TensorShape* input_shape_override = nullptr) {
+                                 const TensorShape* input_shape_override = nullptr,
+                                 const TensorShape* output_shape_override = nullptr) {
   const auto& input_shape = input_shape_override ? *input_shape_override : input.Shape();
   const auto& input_dims = input_shape.GetDims();
+  const auto& output_shape = output_shape_override ? *output_shape_override : output.Shape();
+  const auto& output_dims = output_shape.GetDims();
   auto rank = input_shape.NumDimensions();
 
   const auto element_size = input.DataType()->Size();
@@ -307,10 +310,10 @@ static Status DoUntypedTranspose(const gsl::span<const size_t>& permutations, co
       if (1 == prefix_blocksize) {
         DoTransposeSingleBlock(suffix_blocksize, input_data, output_data);
       } else if (1 == suffix_blocksize) {
-        DoTransposeEltWise(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, stride,
+        DoTransposeEltWise(num_axes_in_prefix, output_dims, prefix_blocksize, stride,
                            input_data, output_data);
       } else {
-        DoTransposeImpl(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, suffix_blocksize, stride,
+        DoTransposeImpl(num_axes_in_prefix, output_dims, prefix_blocksize, suffix_blocksize, stride,
                         input_data, output_data);
       }
     } else {
@@ -323,10 +326,10 @@ static Status DoUntypedTranspose(const gsl::span<const size_t>& permutations, co
         DoTransposeSingleBlock(suffix_blocksize, input_data, output_data, element_size);
       } else if (1 == suffix_blocksize) {
         // this may return a failed status if the data size is not supported in this build
-        status = DoTransposeEltWise(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, stride,
+        status = DoTransposeEltWise(num_axes_in_prefix, output_dims, prefix_blocksize, stride,
                                     input_data, output_data, element_size);
       } else {
-        DoTransposeImpl(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, suffix_blocksize, stride,
+        DoTransposeImpl(num_axes_in_prefix, output_dims, prefix_blocksize, suffix_blocksize, stride,
                         input_data, output_data, element_size);
       }
     }
@@ -349,7 +352,8 @@ bool IsTransposeReshape(const gsl::span<const size_t>& perm, gsl::span
 
 //`input_shape_override` overrides the shape of `input` for compute purposes.
 static Status TransposeImpl(const gsl::span<const size_t>& permutations, const Tensor& input, Tensor& output,
-                            const TensorShape* input_shape_override, concurrency::ThreadPool* tp) {
+                            const TensorShape* input_shape_override, const TensorShape* output_shape_override,
+                            concurrency::ThreadPool* tp) {
   TensorShape shape = input_shape_override ? *input_shape_override : input.Shape();
 
   if (IsTransposeReshape(permutations, shape.GetDims())) {
@@ -363,12 +367,12 @@ static Status TransposeImpl(const gsl::span<const size_t>& permutations, const T
   bool moving_single_axis = IsTransposeMovingSingleAxis(permutations, from, to);
 
   if (moving_single_axis && !input.IsDataTypeString()) {
-    SingleAxisTranspose(permutations, input, output, from, to, input_shape_override, tp);
+    SingleAxisTranspose(permutations, input, output, from, to, input_shape_override, output_shape_override, tp);
     return Status::OK();
   }
 
   // fall back to default implementation
-  return DoUntypedTranspose(permutations, input, output, input_shape_override);
+  return DoUntypedTranspose(permutations, input, output, input_shape_override, output_shape_override);
 }
 
 template <typename Int4Type>
@@ -388,7 +392,8 @@ static Status UnpackInt4Tensor(const Tensor& src, Tensor& dst, AllocatorPtr cpu_
 
 template <typename Int4Type>
 static Status DoTransposeInt4(const gsl::span<const size_t>& permutations, const Tensor& input, Tensor& output,
-                              const TensorShape* input_shape_override, concurrency::ThreadPool* tp) {
+                              const TensorShape* input_shape_override, const TensorShape* output_shape_override,
+                              concurrency::ThreadPool* tp) {
   using Int8Type = typename Int4Type::UnpackedType;
 
   ORT_RETURN_IF_NOT(input.IsDataType<Int4Type>() && output.IsDataType<Int4Type>(),
@@ -400,7 +405,7 @@ static Status DoTransposeInt4(const gsl::span<const size_t>& permutations, const
   Tensor output_unpacked(DataTypeImpl::GetType<Int8Type>(), output.Shape(), cpu_allocator);
 
   ORT_RETURN_IF_ERROR((UnpackInt4Tensor<Int4Type>(input, input_unpacked, cpu_allocator)));
-  ORT_RETURN_IF_ERROR(TransposeImpl(permutations, input_unpacked, output_unpacked, input_shape_override, tp));
+  ORT_RETURN_IF_ERROR(TransposeImpl(permutations, input_unpacked, output_unpacked, input_shape_override, output_shape_override, tp));
 
   ORT_RETURN_IF_NOT(Int4Type::Pack(output.MutableDataAsSpan<Int4Type>(), output_unpacked.DataAsSpan<Int8Type>()),
                     "Failed to pack 8-bit Tensor into 4-bit Tensor");
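`DoTransposeInt4` above follows an unpack-transpose-repack pattern: the packed 4-bit tensor is widened to one value per byte, transposed through the same `TransposeImpl` (now with both shape overrides forwarded, so virtual-shape callers work for int4 too), and packed back. A self-contained sketch of that pattern for nibble-packed data, illustrative only (ORT's Int4x2/UInt4x2 helpers differ in detail):

    #include <cstdint>
    #include <vector>

    // Transpose a rows x cols matrix of 4-bit values packed two per byte
    // (low nibble first): unpack to bytes, transpose, repack.
    std::vector<uint8_t> TransposePacked4Bit(const std::vector<uint8_t>& packed,
                                             int64_t rows, int64_t cols) {
      const int64_t n = rows * cols;  // element count; assumed even here
      std::vector<uint8_t> unpacked(static_cast<size_t>(n));
      for (int64_t i = 0; i < n; ++i) {
        const uint8_t byte = packed[static_cast<size_t>(i / 2)];
        unpacked[static_cast<size_t>(i)] = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
      }
      std::vector<uint8_t> transposed(static_cast<size_t>(n));
      for (int64_t r = 0; r < rows; ++r)
        for (int64_t c = 0; c < cols; ++c)
          transposed[static_cast<size_t>(c * rows + r)] =
              unpacked[static_cast<size_t>(r * cols + c)];
      std::vector<uint8_t> repacked(static_cast<size_t>((n + 1) / 2), 0);
      for (int64_t i = 0; i < n; ++i) {
        const uint8_t v = transposed[static_cast<size_t>(i)] & 0x0F;
        repacked[static_cast<size_t>(i / 2)] |= (i % 2 == 0) ? v : static_cast<uint8_t>(v << 4);
      }
      return repacked;
    }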
@@ -409,7 +414,8 @@ static Status DoTransposeInt4(const gsl::span<const size_t>& permutations, const
 
 //`input_shape_override` overrides the shape of `input` for compute purposes.
 Status TransposeBase::DoTranspose(const gsl::span<const size_t>& permutations, const Tensor& input, Tensor& output,
-                                  const TensorShape* input_shape_override, concurrency::ThreadPool* tp) {
+                                  const TensorShape* input_shape_override, const TensorShape* output_shape_override,
+                                  concurrency::ThreadPool* tp) {
   auto input_type = input.DataType();
   auto output_type = output.DataType();
 
@@ -418,14 +424,14 @@ Status TransposeBase::DoTranspose(const gsl::span<const size_t>& permutations, c
                            input_type, " != ", output_type);
   }
 
   if (input.IsDataType<Int4x2>()) {
-    return DoTransposeInt4<Int4x2>(permutations, input, output, input_shape_override, tp);
+    return DoTransposeInt4<Int4x2>(permutations, input, output, input_shape_override, output_shape_override, tp);
   }
 
   if (input.IsDataType<UInt4x2>()) {
-    return DoTransposeInt4<UInt4x2>(permutations, input, output, input_shape_override, tp);
+    return DoTransposeInt4<UInt4x2>(permutations, input, output, input_shape_override, output_shape_override, tp);
   }
 
-  return TransposeImpl(permutations, input, output, input_shape_override, tp);
+  return TransposeImpl(permutations, input, output, input_shape_override, output_shape_override, tp);
 }
 
 Status Transpose::Compute(OpKernelContext* ctx) const {
@@ -450,7 +456,7 @@ Status Transpose::Compute(OpKernelContext* ctx) const {
     return Status::OK();
   }
 
-  return DoTranspose(*p_perm, X, Y, nullptr, ctx->GetOperatorThreadPool());
+  return DoTranspose(*p_perm, X, Y, nullptr, nullptr, ctx->GetOperatorThreadPool());
 }
 
 ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
diff --git a/onnxruntime/core/providers/cpu/tensor/transpose.h b/onnxruntime/core/providers/cpu/tensor/transpose.h
index 54d3584ba0dad..f14282986a119 100644
--- a/onnxruntime/core/providers/cpu/tensor/transpose.h
+++ b/onnxruntime/core/providers/cpu/tensor/transpose.h
@@ -34,6 +34,7 @@ class TransposeBase {
    */
   static Status DoTranspose(const gsl::span<const size_t>& permutations, const Tensor& input, Tensor& output,
                             const TensorShape* input_shape_override = nullptr,
+                            const TensorShape* output_shape_override = nullptr,
                             concurrency::ThreadPool* tp = nullptr);
 
  protected:
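The expected vectors in the new unit test below are just the DCR/CRD mapping applied to an iota input (values 0..143). They could be regenerated with a helper like the following (hypothetical, mirroring the reference loop sketched earlier; CRD decodes the packed channel index as (c, bh, bw) instead of DCR's (bh, bw, c)):

    #include <cstdint>
    #include <vector>

    // Expected NHWC DepthToSpace output when the input is 0, 1, 2, ..., N*H*W*C - 1.
    std::vector<uint8_t> ExpectedNhwcDepthToSpace(int64_t N, int64_t H, int64_t W, int64_t C,
                                                  int64_t b, bool dcr) {
      const int64_t Cp = C / (b * b);
      std::vector<uint8_t> out(static_cast<size_t>(N * H * W * C));
      for (int64_t n = 0; n < N; ++n)
        for (int64_t h = 0; h < H; ++h)
          for (int64_t w = 0; w < W; ++w)
            for (int64_t bh = 0; bh < b; ++bh)
              for (int64_t bw = 0; bw < b; ++bw)
                for (int64_t c = 0; c < Cp; ++c) {
                  const int64_t ch = dcr ? (bh * b + bw) * Cp + c
                                         : c * b * b + bh * b + bw;
                  const int64_t src = ((n * H + h) * W + w) * C + ch;
                  const int64_t dst = ((((n * H + h) * b + bh) * W + w) * b + bw) * Cp + c;
                  out[static_cast<size_t>(dst)] = static_cast<uint8_t>(src);
                }
      return out;
    }

For the test's N=2, H=3, W=2, C=12, b=2 this reproduces both expected arrays; e.g. the CRD output begins 0, 4, 8, 1, 5, 9, ... because consecutive output channels sit b*b = 4 apart in the CRD input.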
diff --git a/onnxruntime/test/contrib_ops/depth_to_space_op_test.cc b/onnxruntime/test/contrib_ops/depth_to_space_op_test.cc
new file mode 100644
index 0000000000000..bf247b16291ab
--- /dev/null
+++ b/onnxruntime/test/contrib_ops/depth_to_space_op_test.cc
@@ -0,0 +1,230 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "core/common/common.h"
+#include "core/framework/execution_provider.h"
+#include "gtest/gtest.h"
+#include "test/providers/provider_test_utils.h"
+#include "test/util/include/default_providers.h"
+
+namespace onnxruntime {
+namespace test {
+
+template <typename T>
+void RunDepthToSpace(const std::vector<T>& input,
+                     const std::vector<int64_t>& input_shape,
+                     const int64_t blocksize,
+                     const int64_t channels_last,
+                     const std::string& mode,
+                     const std::vector<T>& output,
+                     const std::vector<int64_t>& output_shape,
+                     OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess) {
+  auto run_test = [&]() {
+    OpTester test("DepthToSpace", 1, kMSDomain);
+
+    test.AddAttribute("blocksize", blocksize);
+    test.AddAttribute("channels_last", channels_last);
+    test.AddAttribute("mode", mode);
+
+    test.AddInput<T>("input", input_shape, input);
+    test.AddOutput<T>("output", output_shape, output);
+
+    std::vector<std::unique_ptr<IExecutionProvider>> eps;
+    eps.push_back(DefaultCpuExecutionProvider());
+    test.Run(expect_result, "", {}, nullptr, &eps);
+  };
+
+  run_test();
+}
+
+TEST(DepthToSpaceOpTest, ContribDCR) {
+  constexpr int64_t N = 2, H = 3, W = 2, C = 12;
+  constexpr int64_t blocksize = 2;
+  std::vector<uint8_t> input = {
+      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+      12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+
+      24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
+      36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
+
+      48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
+      60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
+
+      72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
+      84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95,
+
+      96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107,
+      108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119,
+
+      120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131,
+      132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143};
+  std::vector<int64_t> input_shape = {N, H, W, C};
+  std::vector<uint8_t> output = {
+      0, 1, 2,
+      3, 4, 5,
+      12, 13, 14,
+      15, 16, 17,
+
+      6, 7, 8,
+      9, 10, 11,
+      18, 19, 20,
+      21, 22, 23,
+
+      24, 25, 26,
+      27, 28, 29,
+      36, 37, 38,
+      39, 40, 41,
+
+      30, 31, 32,
+      33, 34, 35,
+      42, 43, 44,
+      45, 46, 47,
+
+      48, 49, 50,
+      51, 52, 53,
+      60, 61, 62,
+      63, 64, 65,
+
+      54, 55, 56,
+      57, 58, 59,
+      66, 67, 68,
+      69, 70, 71,
+
+      72, 73, 74,
+      75, 76, 77,
+      84, 85, 86,
+      87, 88, 89,
+
+      78, 79, 80,
+      81, 82, 83,
+      90, 91, 92,
+      93, 94, 95,
+
+      96, 97, 98,
+      99, 100, 101,
+      108, 109, 110,
+      111, 112, 113,
+
+      102, 103, 104,
+      105, 106, 107,
+      114, 115, 116,
+      117, 118, 119,
+
+      120, 121, 122,
+      123, 124, 125,
+      132, 133, 134,
+      135, 136, 137,
+
+      126, 127, 128,
+      129, 130, 131,
+      138, 139, 140,
+      141, 142, 143};
+  std::vector<int64_t> output_shape = {N, H * blocksize, W * blocksize, C / (blocksize * blocksize)};
+
+  RunDepthToSpace(input, input_shape, blocksize, 1, "DCR", output, output_shape);
+}
+
+TEST(DepthToSpaceOpTest, ContribCRD) {
+  constexpr int64_t N = 2, H = 3, W = 2, C = 12;
+  constexpr int64_t blocksize = 2;
+  std::vector<uint8_t> input = {
+      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+      12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+
+      24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
+      36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
+
+      48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
+      60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
+
+      72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
+      84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95,
+
+      96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107,
+      108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119,
+
+      120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131,
+      132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143};
+  std::vector<int64_t> input_shape = {N, H, W, C};
+  std::vector<uint8_t> output = {
+      0, 4, 8,
+      1, 5, 9,
+      12, 16, 20,
+      13, 17, 21,
+
+      2, 6, 10,
+      3, 7, 11,
+      14, 18, 22,
+      15, 19, 23,
+
+      24, 28, 32,
+      25, 29, 33,
+      36, 40, 44,
+      37, 41, 45,
+
+      26, 30, 34,
+      27, 31, 35,
+      38, 42, 46,
+      39, 43, 47,
+
+      48, 52, 56,
+      49, 53, 57,
+      60, 64, 68,
+      61, 65, 69,
+
+      50, 54, 58,
+      51, 55, 59,
+      62, 66, 70,
+      63, 67, 71,
+
+      72, 76, 80,
+      73, 77, 81,
+      84, 88, 92,
+      85, 89, 93,
+
+      74, 78, 82,
+      75, 79, 83,
+      86, 90, 94,
+      87, 91, 95,
+
+      96, 100, 104,
+      97, 101, 105,
+      108, 112, 116,
+      109, 113, 117,
+
+      98, 102, 106,
+      99, 103, 107,
+      110, 114, 118,
+      111, 115, 119,
+
+      120, 124, 128,
+      121, 125, 129,
+      132, 136, 140,
+      133, 137, 141,
+
+      122, 126, 130,
+      123, 127, 131,
+      134, 138, 142,
+      135, 139, 143};
+  std::vector<int64_t> output_shape = {N, H * blocksize, W * blocksize, C / (blocksize * blocksize)};
+
+  RunDepthToSpace(input, input_shape, blocksize, 1, "CRD", output, output_shape);
+}
+
+}  // namespace test
+}  // namespace onnxruntime
diff --git a/onnxruntime/test/optimizer/nhwc_transformer_test.cc b/onnxruntime/test/optimizer/nhwc_transformer_test.cc
index a247fea7e5f53..7dc46a7a65576 100644
--- a/onnxruntime/test/optimizer/nhwc_transformer_test.cc
+++ b/onnxruntime/test/optimizer/nhwc_transformer_test.cc
@@ -8,6 +8,7 @@
 #include "graph_transform_test_builder.h"
 #include "core/mlas/inc/mlas.h"
 #include "core/graph/graph.h"
+#include "core/graph/node_attr_utils.h"
 
 namespace onnxruntime {
 namespace test {
@@ -516,6 +517,36 @@ TEST(NhwcTransformerTests, ConvMixTensorRanks) {
                     TransformerLevel::Level3);
 }
 
+TEST(NhwcTransformerTests, DepthToSpace) {
+  auto test_case = [&](const std::vector<int64_t>& input_shape, const int64_t blocksize, const std::string& mode) {
+    auto build_test_case = [&](ModelTestBuilder& builder) {
+      auto* input_arg = builder.MakeInput<uint8_t>(input_shape, 0, 255);
+      auto* output_arg = builder.MakeOutput();
+      NodeAttributes attrs;
+      utils::SetNodeAttribute(utils::MakeAttribute("blocksize", blocksize), attrs);
+      utils::SetNodeAttribute(utils::MakeAttribute("mode", mode), attrs);
+
+      builder.AddNode("DepthToSpace", {input_arg}, {output_arg}, "", &attrs);
+    };
+
+    auto check_nhwc_graph = [&](InferenceSessionWrapper& session) {
+      auto op_to_count = CountOpsInGraph(session.GetGraph());
+      EXPECT_EQ(op_to_count["com.microsoft.DepthToSpace"], 1);
+      EXPECT_EQ(op_to_count["Transpose"], 2);
+    };
+
+    TransformerTester(build_test_case,
+                      check_nhwc_graph,
+                      TransformerLevel::Level2,
+                      TransformerLevel::Level3);
+  };
+
+  test_case({2, 12, 3, 2}, 2, "DCR");
+  test_case({1, 1024, 48, 48}, 4, "DCR");
+  test_case({2, 12, 3, 2}, 2, "CRD");
+  test_case({1, 1024, 48, 48}, 4, "CRD");
+}
+
 #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
 
 static std::vector<MLFloat16> ARangeOfFP16Values(const std::vector<int64_t>& shape, MLFloat16 min, MLFloat16 max) {