110 changes: 101 additions & 9 deletions onnxruntime/core/optimizer/utils.cc
@@ -305,6 +305,11 @@ bool IsOperationDeterministic(const std::string& domain, const std::string& op)

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

bool IsScalarOr1Element1DTensor(gsl::span<const int64_t> tensor_shape) {
  const size_t rank = tensor_shape.size();
  return (rank == 0) || ((rank == 1) && (tensor_shape[0] == 1));
}
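For reference, a minimal standalone sketch of what this shape check accepts. The predicate is restated locally under a hypothetical name (and with std::vector in place of gsl::span) so it compiles on its own:

#include <cassert>
#include <cstdint>
#include <vector>

// Same predicate as above, restated so the sketch is self-contained.
static bool IsScalarOr1Element1DTensorSketch(const std::vector<int64_t>& shape) {
  const size_t rank = shape.size();
  return (rank == 0) || ((rank == 1) && (shape[0] == 1));
}

int main() {
  assert(IsScalarOr1Element1DTensorSketch({}));       // rank-0 scalar -> accepted
  assert(IsScalarOr1Element1DTensorSketch({1}));      // 1-D with one element -> accepted
  assert(!IsScalarOr1Element1DTensorSketch({2}));     // 1-D with two elements -> rejected
  assert(!IsScalarOr1Element1DTensorSketch({1, 1}));  // rank-2 -> rejected
  return 0;
}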

bool GetClipConstantMinMax(const Graph& graph, const Node& node, float& min, float& max) {
  min = std::numeric_limits<float>::lowest();
  max = std::numeric_limits<float>::max();
@@ -330,28 +335,115 @@ bool GetClipConstantMinMax(const Graph& graph, const Node& node, float& min, float& max)
    return true;
  }

  bool is_constant = true;
  bool is_constant = false;
  const ONNX_NAMESPACE::TensorProto* initializer = graph.GetConstantInitializer(input->Name(), true);
  if (initializer) {
    Initializer i(graph, *initializer, graph.ModelPath());
    switch (initializer->data_type()) {
      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
        value = *i.data<float>();
        is_constant = true;
        break;
      // double isn't currently supported
      // case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
      //   value = static_cast<float>(*i.data<double>());
      //   break;
      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
        value = math::halfToFloat(i.data<MLFloat16>()->val);
        is_constant = true;
        break;
      default:
        ORT_THROW("Unexpected data type for Clip input of ", initializer->data_type());
        is_constant = false;
        break;
    }
    return is_constant;
  }
  const Node* producer = graph.GetProducerNode(input->Name());
  if (producer && producer->OpType() == "DequantizeLinear") {
    const auto& dq_inputs = producer->InputDefs();
    if (dq_inputs.size() < 3) {
      // zero_point is an optional input of DequantizeLinear; require it here.
      return false;
    }
    const ONNX_NAMESPACE::TensorProto* dq_input = graph.GetConstantInitializer(dq_inputs[0]->Name(), true);
    const ONNX_NAMESPACE::TensorProto* dq_scale = graph.GetConstantInitializer(dq_inputs[1]->Name(), true);
    const ONNX_NAMESPACE::TensorProto* dq_zero_point = graph.GetConstantInitializer(dq_inputs[2]->Name(), true);
    if (!dq_input || !dq_scale || !dq_zero_point) {
      return false;
    }
    // Check that scale and zero_point are scalars (or single-element 1-D tensors).
    Initializer scale_initializer(graph, *dq_scale, graph.ModelPath());
    Initializer zero_point_initializer(graph, *dq_zero_point, graph.ModelPath());
    if (!IsScalarOr1Element1DTensor(scale_initializer.dims()) ||
        !IsScalarOr1Element1DTensor(zero_point_initializer.dims())) {
      return false;
    }
    float scale = 1.0f;
    float zero_point = 0.0f;
    // Get the scale.
    switch (dq_scale->data_type()) {
      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
        scale = *scale_initializer.data<float>();
        break;
      }
      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: {
        scale = math::halfToFloat(scale_initializer.data<MLFloat16>()->val);
        break;
      }
      default:
        return false;
    }
    // Get the zero_point.
    switch (dq_zero_point->data_type()) {
      case ONNX_NAMESPACE::TensorProto_DataType_UINT8: {
        zero_point = static_cast<float>(*zero_point_initializer.data<uint8_t>());
        break;
      }
      case ONNX_NAMESPACE::TensorProto_DataType_INT8: {
        zero_point = static_cast<float>(*zero_point_initializer.data<int8_t>());
        break;
      }
      case ONNX_NAMESPACE::TensorProto_DataType_UINT16: {
        zero_point = static_cast<float>(*zero_point_initializer.data<uint16_t>());
        break;
      }
      case ONNX_NAMESPACE::TensorProto_DataType_INT16: {
        zero_point = static_cast<float>(*zero_point_initializer.data<int16_t>());
        break;
      }
      case ONNX_NAMESPACE::TensorProto_DataType_INT32: {
        zero_point = static_cast<float>(*zero_point_initializer.data<int32_t>());
        break;
      }
      default:
        return false;
    }
    // Recover the original float value: value = scale * (x - zero_point).
    Initializer x_initializer(graph, *dq_input, graph.ModelPath());
    if (!IsScalarOr1Element1DTensor(x_initializer.dims())) {
      return false;
    }
    switch (dq_input->data_type()) {
      case ONNX_NAMESPACE::TensorProto_DataType_UINT8: {
        value = scale * (static_cast<float>(*x_initializer.data<uint8_t>()) - zero_point);
        is_constant = true;
        break;
      }
      case ONNX_NAMESPACE::TensorProto_DataType_INT8: {
        value = scale * (static_cast<float>(*x_initializer.data<int8_t>()) - zero_point);
        is_constant = true;
        break;
      }
      case ONNX_NAMESPACE::TensorProto_DataType_UINT16: {
        value = scale * (static_cast<float>(*x_initializer.data<uint16_t>()) - zero_point);
        is_constant = true;
        break;
      }
      case ONNX_NAMESPACE::TensorProto_DataType_INT16: {
        value = scale * (static_cast<float>(*x_initializer.data<int16_t>()) - zero_point);
        is_constant = true;
        break;
      }
      case ONNX_NAMESPACE::TensorProto_DataType_INT32: {
        value = scale * (static_cast<float>(*x_initializer.data<int32_t>()) - zero_point);
        is_constant = true;
        break;
      }
      default:
        return false;
    }
  } else {
    is_constant = false;
  }

  return is_constant;
};
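The DequantizeLinear branch above folds a constant quantized input back to float with the standard linear-dequantization formula value = scale * (x - zero_point). As a quick sanity check with the constants used by the tests below, a minimal standalone sketch (illustrative names, not part of the change):

#include <cstdint>
#include <cstdio>

// value = scale * (x - zero_point), as in the switch above.
static float DequantizeScalar(uint8_t x, float scale, uint8_t zero_point) {
  return scale * (static_cast<float>(x) - static_cast<float>(zero_point));
}

int main() {
  // Clip min from the tests: x=128, scale=0.00784313772, zero_point=128 -> 0.0
  std::printf("min = %f\n", DequantizeScalar(128, 0.00784313772f, 128));
  // Clip max from the tests: x=255, scale=0.0235293377, zero_point=0 -> ~6.0
  std::printf("max = %f\n", DequantizeScalar(255, 0.0235293377f, 0));
  return 0;
}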

134 changes: 134 additions & 0 deletions onnxruntime/test/optimizer/qdq_transformer_test.cc
@@ -4197,6 +4197,140 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) {
  }
}

TEST(QDQTransformerTests, QDQ_Selector_Test_ConvClip) {
  const auto& logger = DefaultLoggingManager().DefaultLogger();

  auto build_test_case = [&](ModelTestBuilder& builder) {
    auto* input_arg = builder.MakeInput<uint8_t>({1, 2, 4, 4}, std::numeric_limits<uint8_t>::min(),
                                                 std::numeric_limits<uint8_t>::max());
    auto* weight_arg = builder.MakeInput<uint8_t>({2, 1, 3, 3}, std::numeric_limits<uint8_t>::min(),
                                                  std::numeric_limits<uint8_t>::max());
    auto* bias_arg =
        builder.MakeInput<int32_t>({2}, std::numeric_limits<int32_t>::min(), std::numeric_limits<int32_t>::max());
    auto* dq_input = builder.MakeIntermediate();
    auto* dq_weight = builder.MakeIntermediate();
    auto* dq_bias = builder.MakeIntermediate();
    builder.AddDequantizeLinearNode(input_arg, 0.02348f, uint8_t(0), dq_input, false);
    builder.AddDequantizeLinearNode(weight_arg, 0.307f, uint8_t(0), dq_weight, false);
    builder.AddDequantizeLinearNode(bias_arg, 0.007f, int32_t(0), dq_bias, false);

    // Conv
    auto* conv_output = builder.MakeIntermediate();
    Node& conv_node = builder.AddNode("Conv", {dq_input, dq_weight, dq_bias}, {conv_output});
    conv_node.AddAttribute("pads", std::vector<int64_t>{1, 1, 1, 1});
    conv_node.AddAttribute("strides", std::vector<int64_t>{1, 1});
    conv_node.AddAttribute("dilations", std::vector<int64_t>{1, 1});
    conv_node.AddAttribute("group", int64_t(2));
    conv_node.AddAttribute("kernel_shape", std::vector<int64_t>{3, 3});

    // Clip
    NodeArg* clip_min = builder.MakeScalarInitializer<uint8_t>(128);  // dequantizes to 0.0f
    NodeArg* clip_max = builder.MakeScalarInitializer<uint8_t>(255);  // dequantizes to ~6.0f
    NodeArg* min_dq = builder.MakeIntermediate();
    NodeArg* max_dq = builder.MakeIntermediate();
    builder.AddDequantizeLinearNode<uint8_t>(clip_min, 0.00784313772f, static_cast<uint8_t>(128), min_dq, false);
    builder.AddDequantizeLinearNode<uint8_t>(clip_max, 0.0235293377f, static_cast<uint8_t>(0), max_dq, false);
    NodeArg* clip_fp32 = builder.MakeIntermediate();
    builder.AddNode("Clip", {conv_output, min_dq, max_dq}, {clip_fp32});
    NodeArg* clip_q = builder.MakeIntermediate();
    NodeArg* clip_dq = builder.MakeOutput();
    builder.AddQuantizeLinearNode<uint8_t>(clip_fp32, 0.0082940589f, static_cast<uint8_t>(0), clip_q, false);
    builder.AddDequantizeLinearNode<uint8_t>(clip_q, 0.0082940589f, static_cast<uint8_t>(0), clip_dq, false);
  };
  // Build the model for this test.
  std::unordered_map<std::string, int> domain_to_version;
  domain_to_version[kOnnxDomain] = 18;
  domain_to_version[kMSDomain] = 1;
  Model model("TransformerTester", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
              domain_to_version, {}, logger);
  Graph& graph = model.MainGraph();
  ModelTestBuilder helper(graph);
  build_test_case(helper);
  helper.SetGraphOutputs();
  ASSERT_STATUS_OK(model.MainGraph().Resolve());
  const GraphViewer whole_graph_viewer(graph);

  // Make sure node 3 is the Conv node
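  // Node creation order in build_test_case: DQ(input)=0, DQ(weight)=1, DQ(bias)=2,
  // Conv=3, DQ(clip_min)=4, DQ(clip_max)=5, Clip=6, Q=7, DQ=8.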
  const auto* conv_node = graph.GetNode(3);
  ASSERT_TRUE(nullptr != conv_node);
  ASSERT_EQ("Conv", conv_node->OpType());

  // Make sure the Conv QDQ group is selected
  onnxruntime::QDQ::ConvNodeGroupSelector conv_selector;
  const auto result = conv_selector.GetQDQSelection(whole_graph_viewer, *conv_node);
  ASSERT_TRUE(result.has_value());
  const auto& qdq_group = *result;
  ASSERT_EQ(NodeIndex(3), qdq_group.target_node);
  ASSERT_EQ(NodeIndex(6), qdq_group.redundant_clip_node);
}

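// Same scenario as above, but the Clip min/max come from 1-D single-element
// initializers rather than scalars, exercising the other shape accepted by
// IsScalarOr1Element1DTensor.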
TEST(QDQTransformerTests, QDQ_Selector_Test_ConvClipNonScalar) {
  const auto& logger = DefaultLoggingManager().DefaultLogger();

  auto build_test_case = [&](ModelTestBuilder& builder) {
    auto* input_arg = builder.MakeInput<uint8_t>({1, 2, 4, 4}, std::numeric_limits<uint8_t>::min(),
                                                 std::numeric_limits<uint8_t>::max());
    auto* weight_arg = builder.MakeInput<uint8_t>({2, 1, 3, 3}, std::numeric_limits<uint8_t>::min(),
                                                  std::numeric_limits<uint8_t>::max());
    auto* bias_arg =
        builder.MakeInput<int32_t>({2}, std::numeric_limits<int32_t>::min(), std::numeric_limits<int32_t>::max());
    auto* dq_input = builder.MakeIntermediate();
    auto* dq_weight = builder.MakeIntermediate();
    auto* dq_bias = builder.MakeIntermediate();
    builder.AddDequantizeLinearNode(input_arg, 0.02348f, uint8_t(0), dq_input, false);
    builder.AddDequantizeLinearNode(weight_arg, 0.307f, uint8_t(0), dq_weight, false);
    builder.AddDequantizeLinearNode(bias_arg, 0.007f, int32_t(0), dq_bias, false);

    // Conv
    auto* conv_output = builder.MakeIntermediate();
    Node& conv_node = builder.AddNode("Conv", {dq_input, dq_weight, dq_bias}, {conv_output});
    conv_node.AddAttribute("pads", std::vector<int64_t>{1, 1, 1, 1});
    conv_node.AddAttribute("strides", std::vector<int64_t>{1, 1});
    conv_node.AddAttribute("dilations", std::vector<int64_t>{1, 1});
    conv_node.AddAttribute("group", int64_t(2));
    conv_node.AddAttribute("kernel_shape", std::vector<int64_t>{3, 3});

    // Clip
    NodeArg* clip_min = builder.Make1DInitializer<uint8_t>({128});  // dequantizes to 0.0f
    NodeArg* clip_max = builder.Make1DInitializer<uint8_t>({255});  // dequantizes to ~6.0f
    NodeArg* min_dq = builder.MakeIntermediate();
    NodeArg* max_dq = builder.MakeIntermediate();
    builder.AddDequantizeLinearNode<uint8_t>(clip_min, 0.00784313772f, static_cast<uint8_t>(128), min_dq, false);
    builder.AddDequantizeLinearNode<uint8_t>(clip_max, 0.0235293377f, static_cast<uint8_t>(0), max_dq, false);
    NodeArg* clip_fp32 = builder.MakeIntermediate();
    builder.AddNode("Clip", {conv_output, min_dq, max_dq}, {clip_fp32});
    NodeArg* clip_q = builder.MakeIntermediate();
    NodeArg* clip_dq = builder.MakeOutput();
    builder.AddQuantizeLinearNode<uint8_t>(clip_fp32, 0.0082940589f, static_cast<uint8_t>(0), clip_q, false);
    builder.AddDequantizeLinearNode<uint8_t>(clip_q, 0.0082940589f, static_cast<uint8_t>(0), clip_dq, false);
  };
  // Build the model for this test.
  std::unordered_map<std::string, int> domain_to_version;
  domain_to_version[kOnnxDomain] = 18;
  domain_to_version[kMSDomain] = 1;
  Model model("TransformerTester", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
              domain_to_version, {}, logger);
  Graph& graph = model.MainGraph();
  ModelTestBuilder helper(graph);
  build_test_case(helper);
  helper.SetGraphOutputs();
  ASSERT_STATUS_OK(model.MainGraph().Resolve());
  const GraphViewer whole_graph_viewer(graph);

  // Make sure node 3 is the Conv node (same creation order as the previous test)
  const auto* conv_node = graph.GetNode(3);
  ASSERT_TRUE(nullptr != conv_node);
  ASSERT_EQ("Conv", conv_node->OpType());

  // Make sure the Conv QDQ group is selected
  onnxruntime::QDQ::ConvNodeGroupSelector conv_selector;
  const auto result = conv_selector.GetQDQSelection(whole_graph_viewer, *conv_node);
  ASSERT_TRUE(result.has_value());
  const auto& qdq_group = *result;
  ASSERT_EQ(NodeIndex(3), qdq_group.target_node);
  ASSERT_EQ(NodeIndex(6), qdq_group.redundant_clip_node);
}

TEST(QDQTransformerTests, QDQ_Selector_Test_Conv_Relu) {
  const auto& logger = DefaultLoggingManager().DefaultLogger();
