Add Optional Activation node to NodeUnit #22888
@@ -27,12 +27,14 @@ struct NodeGroup {
   std::vector<NodeIndex> dq_nodes;
   std::vector<NodeIndex> q_nodes;
   NodeIndex target_node;
+  std::optional<NodeIndex> activation_node;

   // Validator to check if the set of nodes can form a valid QDQ NodeGroup.
   // Checks target node is only consumer of each DQ, and that the outputs remain valid if the QDQ node group was to
   // be converted into a single node with a quantized operator.
   static Status CanCreateNodeGroup(const GraphViewer& graph_viewer,
                                    const Node& target_node,
+                                   const Node* p_activation_node,
                                    gsl::span<const Node* const> dq_nodes,
                                    gsl::span<const Node* const> q_nodes);
 };
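For orientation, a minimal sketch of how a selector might populate the extended struct; the index variables (dq_index, conv_index, relu_index, q_index) are hypothetical placeholders, not code from this PR:

// Hypothetical QDQ group of the form DQ -> Conv -> Relu -> Q.
QDQ::NodeGroup group;
group.dq_nodes = {dq_index};         // DequantizeLinear node feeding the target
group.target_node = conv_index;      // the node being quantized
group.activation_node = relu_index;  // new in this PR; std::nullopt when absent
group.q_nodes = {q_index};           // trailing QuantizeLinear node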
@@ -68,7 +70,7 @@ class NodeUnit {
  public:
   explicit NodeUnit(const Node& node);
   explicit NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group);
-  NodeUnit(gsl::span<const Node* const> dq_nodes, const Node& target_node,
+  NodeUnit(gsl::span<const Node* const> dq_nodes, const Node& target_node, const Node* p_activation_node,
            gsl::span<const Node* const> q_nodes, Type unit_type,
            gsl::span<const NodeUnitIODef> inputs, gsl::span<const NodeUnitIODef> outputs,
            size_t input_edge_count, Node::EdgeSet output_edges);
@@ -87,6 +89,7 @@ class NodeUnit {
   ProviderType GetExecutionProviderType() const noexcept;

   const Node& GetNode() const noexcept { return target_node_; }
+  const Node* GetActivationNode() const noexcept { return p_activation_node_; }
   const std::vector<const Node*>& GetDQNodes() const noexcept { return dq_nodes_; }
   const std::vector<const Node*>& GetQNodes() const noexcept { return q_nodes_; }
   std::vector<const Node*> GetAllNodesInGroup() const noexcept;
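A minimal sketch of how an execution provider could consume the new accessor, assuming (as the fusion checks in qdq_util.cc below suggest) that the activation node is only recorded when the trailing Q node subsumes its clamping:

// Hypothetical EP-side check: a non-null activation node in the unit means the
// Relu/Clip is redundant and can be skipped when lowering the target node.
if (const Node* activation = node_unit.GetActivationNode()) {
  // Skip compiling 'activation'; the Q node's quantization range performs the same clamp.
}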
@@ -106,6 +109,7 @@ class NodeUnit {

   const std::vector<const Node*> dq_nodes_;  // dq nodes for this NodeUnit, not necessarily all inputs
   const Node& target_node_;
+  const Node* p_activation_node_;  // Optional activation node for the QDQ group, nullptr if not present.
   const std::vector<const Node*> q_nodes_;  // q-nodes for this NodeUnit. not necessarily all outputs

   const Type type_;
onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc

@@ -215,4 +215,141 @@

#endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
namespace {

bool GetDataTypeMinMax(int32_t data_type, int32_t& min, int32_t& max) {
  switch (data_type) {
    case ONNX_NAMESPACE::TensorProto::INT8:
      min = static_cast<int32_t>(std::numeric_limits<int8_t>::min());
      max = static_cast<int32_t>(std::numeric_limits<int8_t>::max());
      break;
    case ONNX_NAMESPACE::TensorProto::UINT8:
      min = static_cast<int32_t>(std::numeric_limits<uint8_t>::min());
      max = static_cast<int32_t>(std::numeric_limits<uint8_t>::max());
      break;
    case ONNX_NAMESPACE::TensorProto::INT16:
      min = static_cast<int32_t>(std::numeric_limits<int16_t>::min());
      max = static_cast<int32_t>(std::numeric_limits<int16_t>::max());
      break;
    case ONNX_NAMESPACE::TensorProto::UINT16:
      min = static_cast<int32_t>(std::numeric_limits<uint16_t>::min());
      max = static_cast<int32_t>(std::numeric_limits<uint16_t>::max());
      break;
    default:
      return false;
  }
  return true;
}
bool GetQSalarScaleZp(const GraphViewer& graph_viewer, const Node& q_node, float& scale, int32_t& zp,
                      int32_t& data_type) {

Review comment: Would be good to refactor some of the utils here, as there seems to be a fair bit of duplication. E.g. maybe a general-purpose helper that reads the scale and zp values (scalar or otherwise) and has a bool to indicate whether they're scalar; that helper could be used by many of the utils here.

  assert(q_node.OpType() == QOpName);
  const auto& q_input_defs = q_node.InputDefs();
  if (q_input_defs.size() != 3 || !q_input_defs[2]->Exists()) {
    return false;
  }

Review comment: zp is optional and defaults to zero, so do we need to require 3 inputs here?
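For reference, the ONNX QuantizeLinear operator takes inputs (x, y_scale, y_zero_point), where y_zero_point is optional and defaults to zero; that spec behavior is what the review comment above is pointing out.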

  const ONNX_NAMESPACE::TensorProto* scale_tensor_proto =
      graph_viewer.GetConstantInitializer(q_input_defs[1]->Name(), true);
  if (!scale_tensor_proto) {
    return false;
  }

  // Support scalar float scale only for now. Need to extend to other float types if needed.
  Initializer scale_initializer(*scale_tensor_proto, graph_viewer.ModelPath());
  if (scale_initializer.dims().size() != 0 || scale_initializer.data_type() != ONNX_NAMESPACE::TensorProto::FLOAT) {
    return false;
  }
  scale = *scale_initializer.data<float>();

  const ONNX_NAMESPACE::TensorProto* zp_tensor_proto =
      graph_viewer.GetConstantInitializer(q_input_defs[2]->Name(), true);
  if (!zp_tensor_proto) {
    return false;
  }

  Initializer zp_initializer(*zp_tensor_proto, graph_viewer.ModelPath());
  if (zp_initializer.dims().size() != 0) {
    return false;
  }

  data_type = zp_initializer.data_type();
  switch (data_type) {
    case ONNX_NAMESPACE::TensorProto::INT8:
      zp = static_cast<int32_t>(*zp_initializer.data<int8_t>());
      break;
    case ONNX_NAMESPACE::TensorProto::UINT8:
      zp = static_cast<int32_t>(*zp_initializer.data<uint8_t>());
      break;
    case ONNX_NAMESPACE::TensorProto::INT16:
      zp = static_cast<int32_t>(*zp_initializer.data<int16_t>());
      break;
    case ONNX_NAMESPACE::TensorProto::UINT16:
      zp = static_cast<int32_t>(*zp_initializer.data<uint16_t>());
      break;
    default:
      return false;
  }

  return true;
}

bool CanRemoveRelu(const GraphViewer& graph_viewer, const Node& q_node) {
  float scale = 0.0f;
  int32_t zp = 0;
  int32_t data_type = 0;
  if (!GetQSalarScaleZp(graph_viewer, q_node, scale, zp, data_type)) {
    return false;
  }

  int32_t data_type_min = 0;
  int32_t data_type_max = 0;
  if (!GetDataTypeMinMax(data_type, data_type_min, data_type_max)) {
    return false;
  }

  // Relu can be removed if the zero-point is set to the smallest quantized value.
  return zp == data_type_min;
}
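To spell out why that check is sufficient: dequantization maps a quantized value q to x = scale * (q - zp), so the smallest value the Q node can represent is scale * (data_type_min - zp). When zp == data_type_min, that smallest value is exactly 0. For example, with uint8 and zp = 0 the representable range is [0, 255 * scale], so quantization already clamps negative inputs to zero and the Relu adds nothing.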

bool CanRemoveClip(const GraphViewer& graph_viewer, const Node& clip_node, const Node& q_node) {
  float scale = 0.0f;
  int32_t zp = 0;
  int32_t data_type = 0;
  if (!GetQSalarScaleZp(graph_viewer, q_node, scale, zp, data_type)) {
    return false;
  }

  float min = 0.0f;
  float max = 0.0f;
  if (!optimizer_utils::GetClipConstantMinMax(graph_viewer.GetGraph(), clip_node, min, max)) {
    return false;
  }

  int32_t q_clip_min = static_cast<int32_t>(::rint(min / scale)) + zp;
  int32_t q_clip_max = static_cast<int32_t>(::rint(max / scale)) + zp;

  int32_t data_type_min = 0;
  int32_t data_type_max = 0;
  if (!GetDataTypeMinMax(data_type, data_type_min, data_type_max)) {
    return false;
  }

  // The Clip can be removed if its range entirely overlaps the quantization range.
  // QClip range: [------------------]
  // Quant range:    [-------------]
  return q_clip_min <= data_type_min && q_clip_max >= data_type_max;
}

}  // namespace
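A worked example of the removable case: a Relu6-style Clip(0, 6) feeding a uint8 Q node with scale = 6/255 and zp = 0 gives q_clip_min = rint(0 / scale) + 0 = 0 and q_clip_max = rint(6 / scale) + 0 = 255, which exactly covers the uint8 range [0, 255], so the Clip is redundant. With a coarser scale such as 0.1, q_clip_max = rint(6 / 0.1) = 60 < 255, so the Clip still constrains values and must be kept.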

bool CanFuseActivationQ(const GraphViewer& graph_viewer, const Node& activation_node, const Node& q_node) {
  const std::string& activation_op_type = activation_node.OpType();
  if (activation_op_type == "Relu") {
    return CanRemoveRelu(graph_viewer, q_node);
  } else if (activation_op_type == "Clip") {
    return CanRemoveClip(graph_viewer, activation_node, q_node);
  }
  return false;
}

}  // namespace onnxruntime::QDQ
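A hedged sketch of how a QDQ selector might use CanFuseActivationQ when recording a group; the surrounding variables (graph_viewer, relu_node, q_node, group) are hypothetical and not from this PR:

// Hypothetical selection step: only record the activation in the NodeGroup when
// folding it into the trailing Q node is provably value-preserving.
if (relu_node != nullptr && q_node != nullptr &&
    QDQ::CanFuseActivationQ(graph_viewer, *relu_node, *q_node)) {
  group.activation_node = relu_node->Index();  // fold the Relu/Clip into the Q
}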
Review comment: We should have a comment explaining what the 'activation' node is and how it is expected to be used. IIUC, it is a Relu or Clip whose clamping is subsumed by the trailing Q node's quantization range. If that's correct, I'd almost be inclined to call it something like redundant_clip_node (given Relu is a form of Clip). Also, as the OpenVINO EP (IIRC) is doing the fallback to higher precision, does it need an update to be aware of the activation node in the NodeUnit?