Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Optional Activation node to NodeUnit #22888

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions onnxruntime/core/framework/node_unit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ std::vector<NodeUnitIODef> GetQDQIODefs(const Node& target_node, const QDQ::Node

Status QDQ::NodeGroup::CanCreateNodeGroup(const GraphViewer& graph_viewer,
const Node& target_node,
const Node* p_activation_node,
gsl::span<const Node* const> dq_nodes,
gsl::span<const Node* const> q_nodes) {
// Within a QDQ node group, a target node input is the only consumer of each DQ.
Expand All @@ -176,6 +177,22 @@ Status QDQ::NodeGroup::CanCreateNodeGroup(const GraphViewer& graph_viewer,
dq_node->Name(), ", target node: ", target_node.Name());
}

// If activation node is present, currently we require target node has only one output edge, which is connected to
// the activation node. The activation node's output is consumed by the Q node that can be fused with itself.
if (p_activation_node) {
ORT_RETURN_IF_NOT(target_node.GetOutputEdgesCount() == 1 &&
target_node.OutputEdgesBegin()->GetNode().Index() == p_activation_node->Index(),
"QDQ node group cannot have target node with more than one output edge if there is activation "
"node. target node: ",
target_node.Name());
ORT_RETURN_IF_NOT(q_nodes.size() == 1 && p_activation_node->GetOutputEdgesCount() == 1 &&
p_activation_node->OutputEdgesBegin()->GetNode().Index() == q_nodes[0]->Index(),
"QDQ node group cannot have activation node that doesn't have a single output edge to a Q node. "
"activation node: ",
p_activation_node->Name());
return Status::OK();
}

// an output from the target node can have either Q consumers or direct consumers. it cannot have both.
// this must be checked on a per output basis.
// e.g. TopK produces values and indices. The indices output won't be quantized, so even if we replace the TopK QDQ
Expand Down Expand Up @@ -228,6 +245,7 @@ Status QDQ::NodeGroup::CanCreateNodeGroup(const GraphViewer& graph_viewer,

return Status::OK();
}

NodeUnit::NodeUnit(const Node& node)
: target_node_(node),
type_(Type::SingleNode),
Expand All @@ -238,11 +256,15 @@ NodeUnit::NodeUnit(const Node& node)
NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group)
: dq_nodes_{GetQDQIONodes(graph_viewer, node_group, true /* is_input */)},
target_node_(*graph_viewer.GetNode(node_group.target_node)),
p_activation_node_(
node_group.activation_node.has_value() ? graph_viewer.GetNode(node_group.activation_node.value()) : nullptr),
q_nodes_{GetQDQIONodes(graph_viewer, node_group, false /* is_input */)},
type_(Type::QDQGroup),
inputs_{GetQDQIODefs(target_node_, node_group, true /* is_input */)},
outputs_{GetQDQIODefs(target_node_, node_group, false /* is_input */)} {
ORT_THROW_IF_ERROR(QDQ::NodeGroup::CanCreateNodeGroup(graph_viewer, target_node_, dq_nodes_, q_nodes_));
outputs_{
GetQDQIODefs((p_activation_node_ ? *p_activation_node_ : target_node_), node_group, false /* is_input */)} {
ORT_THROW_IF_ERROR(
QDQ::NodeGroup::CanCreateNodeGroup(graph_viewer, target_node_, p_activation_node_, dq_nodes_, q_nodes_));

input_edge_count_ = std::accumulate(dq_nodes_.cbegin(), dq_nodes_.cend(), size_t(0),
[](size_t acc, const Node* node) { return acc + node->GetInputEdgesCount(); });
Expand All @@ -253,8 +275,10 @@ NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_g

// create output edges. each target node output either goes to Q node/s or non-Q node/s.
// ValidateNodeGroupQDQNodes ensures this.
auto cur_edge = target_node_.OutputEdgesBegin();
auto end_edge = target_node_.OutputEdgesEnd();
// If activation node is present, the target node has only one output edge, which is connected to the activation node.
const Node& output_producer = p_activation_node_ ? *p_activation_node_ : target_node_;
auto cur_edge = output_producer.OutputEdgesBegin();
auto end_edge = output_producer.OutputEdgesEnd();
for (; cur_edge != end_edge; ++cur_edge) {
const Node& node = cur_edge->GetNode();

Expand All @@ -273,12 +297,13 @@ NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_g
}
}

NodeUnit::NodeUnit(gsl::span<const Node* const> dq_nodes, const Node& target_node,
NodeUnit::NodeUnit(gsl::span<const Node* const> dq_nodes, const Node& target_node, const Node* p_activation_node,
gsl::span<const Node* const> q_nodes, Type unit_type,
gsl::span<const NodeUnitIODef> inputs, gsl::span<const NodeUnitIODef> outputs,
size_t input_edge_count, Node::EdgeSet output_edges)
: dq_nodes_(dq_nodes.begin(), dq_nodes.end()),
target_node_(target_node),
p_activation_node_(p_activation_node),
q_nodes_(q_nodes.begin(), q_nodes.end()),
type_(unit_type),
inputs_(inputs.begin(), inputs.end()),
Expand Down
6 changes: 5 additions & 1 deletion onnxruntime/core/framework/node_unit.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@ struct NodeGroup {
std::vector<NodeIndex> dq_nodes;
std::vector<NodeIndex> q_nodes;
NodeIndex target_node;
std::optional<NodeIndex> activation_node;

// Validator to check if the set of nodes can form a valid QDQ NodeGroup.
// Checks target node is only consumer of each DQ, and that the outputs remain valid if the QDQ node group was to
// be converted into a single node with a quantized operator.
static Status CanCreateNodeGroup(const GraphViewer& graph_viewer,
const Node& target_node,
const Node* p_activation_node,
gsl::span<const Node* const> dq_nodes,
gsl::span<const Node* const> q_nodes);
};
Expand Down Expand Up @@ -68,7 +70,7 @@ class NodeUnit {
public:
explicit NodeUnit(const Node& node);
explicit NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group);
NodeUnit(gsl::span<const Node* const> dq_nodes, const Node& target_node,
NodeUnit(gsl::span<const Node* const> dq_nodes, const Node& target_node, const Node* p_activation_node,
gsl::span<const Node* const> q_nodes, Type unit_type,
gsl::span<const NodeUnitIODef> inputs, gsl::span<const NodeUnitIODef> outputs,
size_t input_edge_count, Node::EdgeSet output_edges);
Expand All @@ -87,6 +89,7 @@ class NodeUnit {
ProviderType GetExecutionProviderType() const noexcept;

const Node& GetNode() const noexcept { return target_node_; }
const Node* GetActivationNode() const noexcept { return p_activation_node_; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should have a comment explaining what the 'activation' node is and how it is expected to be used.

IIUC

  • if you're using the QDQ node unit for the quantized version of the target node the activation node can be ignored as it's made redundant by the values of the Q node
    • so it's not really a 'fusion' per se as we're not combining the values of the Clip/Relu with the Q, we're ignoring it
  • if you are falling back to higher precision and dropping the DQ/Q nodes, you need to keep both the target node and activation node if present

If that's correct I'd almost be inclined to call it something like redundant_clip_node (given Relu is a form of Clip).

Also as the OpenVINO EP (IIRC) is doing the fallback to higher precision does it need an update to be aware of the activation node in the NodeUnit?

const std::vector<const Node*>& GetDQNodes() const noexcept { return dq_nodes_; }
const std::vector<const Node*>& GetQNodes() const noexcept { return q_nodes_; }
std::vector<const Node*> GetAllNodesInGroup() const noexcept;
Expand All @@ -106,6 +109,7 @@ class NodeUnit {

const std::vector<const Node*> dq_nodes_; // dq nodes for this NodeUnit, not necessarily all inputs
const Node& target_node_;
const Node* p_activation_node_; // Optional activation node for the QDQ group, nullptr if not present.
const std::vector<const Node*> q_nodes_; // q-nodes for this NodeUnit. not necessarily all outputs
Comment on lines 111 to 113
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const Node& target_node_;
const Node* p_activation_node_; // Optional activation node for the QDQ group, nullptr if not present.
const std::vector<const Node*> q_nodes_; // q-nodes for this NodeUnit. not necessarily all outputs
const Node& target_node_;
const Node* p_activation_node_; // Optional activation node for the QDQ group, nullptr if not present.
const std::vector<const Node*> q_nodes_; // q-nodes for this NodeUnit. not necessarily all outputs

const Type type_;

Expand Down
137 changes: 137 additions & 0 deletions onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,4 +215,141 @@

#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

namespace {

bool GetDataTypeMinMax(int32_t data_type, int32_t& min, int32_t& max) {
switch (data_type) {
case ONNX_NAMESPACE::TensorProto::INT8:
min = static_cast<int32_t>(std::numeric_limits<int8_t>::min());
max = static_cast<int32_t>(std::numeric_limits<int8_t>::max());
break;
case ONNX_NAMESPACE::TensorProto::UINT8:
min = static_cast<int32_t>(std::numeric_limits<uint8_t>::min());
max = static_cast<int32_t>(std::numeric_limits<uint8_t>::max());
break;
case ONNX_NAMESPACE::TensorProto::INT16:
min = static_cast<int32_t>(std::numeric_limits<int16_t>::min());
max = static_cast<int32_t>(std::numeric_limits<int16_t>::max());
break;
case ONNX_NAMESPACE::TensorProto::UINT16:
min = static_cast<int32_t>(std::numeric_limits<uint16_t>::min());
max = static_cast<int32_t>(std::numeric_limits<uint16_t>::max());
break;
default:
return false;
}
return true;
}
bool GetQSalarScaleZp(const GraphViewer& graph_viewer, const Node& q_node, float& scale, int32_t& zp,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
bool GetQSalarScaleZp(const GraphViewer& graph_viewer, const Node& q_node, float& scale, int32_t& zp,
bool GetQScalarScaleZp(const GraphViewer& graph_viewer, const Node& q_node, float& scale, int32_t& zp,

Would be good to refactor some of the utils here as there seems to be a fair bit of duplication.

e.g. maybe a general purpose helper that reads the scale and zp values (scalar or otherwise), and has a bool to indicate if they're scalar. that helper could be used by many of the utils here.

int32_t& data_type) {
assert(q_node.OpType() == QOpName);
const auto& q_input_defs = q_node.InputDefs();
if (q_input_defs.size() != 3 || !q_input_defs[2]->Exists()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

zp is optional and defaults to zero, so do we need to require 3 inputs here?

return false;
}

const ONNX_NAMESPACE::TensorProto* scale_tensor_proto =
graph_viewer.GetConstantInitializer(q_input_defs[1]->Name(), true);
if (!scale_tensor_proto) {
return false;
}

// Support scalar float scale only for now. Need to extend to other float types if needed.
Initializer scale_initializer(*scale_tensor_proto, graph_viewer.ModelPath());
if (scale_initializer.dims().size() != 0 || scale_initializer.data_type() != ONNX_NAMESPACE::TensorProto::FLOAT) {
return false;
}
scale = *scale_initializer.data<float>();

const ONNX_NAMESPACE::TensorProto* zp_tensor_proto =
graph_viewer.GetConstantInitializer(q_input_defs[2]->Name(), true);
if (!zp_tensor_proto) {
return false;
}

Initializer zp_initializer(*zp_tensor_proto, graph_viewer.ModelPath());
if (zp_initializer.dims().size() != 0) {
return false;
}

data_type = zp_initializer.data_type();
switch (data_type) {
case ONNX_NAMESPACE::TensorProto::INT8:
zp = static_cast<int32_t>(*zp_initializer.data<int8_t>());
break;
case ONNX_NAMESPACE::TensorProto::UINT8:
zp = static_cast<int32_t>(*zp_initializer.data<uint8_t>());
break;
case ONNX_NAMESPACE::TensorProto::INT16:
zp = static_cast<int32_t>(*zp_initializer.data<int16_t>());
break;
case ONNX_NAMESPACE::TensorProto::UINT16:
zp = static_cast<int32_t>(*zp_initializer.data<uint16_t>());
break;
default:
return false;
}

return true;
}

bool CanRemoveRelu(const GraphViewer& graph_viewer, const Node& q_node) {
float scale = 0.0f;
int32_t zp = 0;
int32_t data_type = 0;
if (!GetQSalarScaleZp(graph_viewer, q_node, scale, zp, data_type)) {
return false;
}

int32_t data_type_min = 0;
int32_t data_type_max = 0;
if (!GetDataTypeMinMax(data_type, data_type_min, data_type_max)) {
return false;
}

// Relu can be removed if the zero-point is set to the smallest quantized value.
return zp == data_type_min;
}

bool CanRemoveClip(const GraphViewer& graph_viewer, const Node& clip_node, const Node& q_node) {
float scale = 0.0f;
int32_t zp = 0;
int32_t data_type = 0;
if (!GetQSalarScaleZp(graph_viewer, q_node, scale, zp, data_type)) {
return false;
}

float min = 0.0f;
float max = 0.0f;
if (!optimizer_utils::GetClipConstantMinMax(graph_viewer.GetGraph(), clip_node, min, max)) {
return false;
}

int32_t q_clip_min = static_cast<int32_t>(::rint(min / scale)) + zp;
int32_t q_clip_max = static_cast<int32_t>(::rint(max / scale)) + zp;

int32_t data_type_min = 0;
int32_t data_type_max = 0;
if (!GetDataTypeMinMax(data_type, data_type_min, data_type_max)) {
return false;
}

// The Clip can be removed if its range entirely overlaps the quantization range.
// QClip range: [------------------]
// Quant range: [-------------]
return q_clip_min <= data_type_min && q_clip_max >= data_type_max;
}

} // namespace

bool CanFuseActivationQ(const GraphViewer& graph_viewer, const Node& activation_node, const Node& q_node) {
const std::string& activation_op_type = activation_node.OpType();

Check warning on line 346 in onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc:346: Add #include <string> for string [build/include_what_you_use] [4]
if (activation_op_type == "Relu") {
return CanRemoveRelu(graph_viewer, q_node);
} else if (activation_op_type == "Clip") {
return CanRemoveClip(graph_viewer, activation_node, q_node);
}
return false;
}

} // namespace onnxruntime::QDQ
5 changes: 5 additions & 0 deletions onnxruntime/core/optimizer/qdq_transformer/qdq_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ namespace onnxruntime {

class Node;
class Path;
class GraphViewer;

namespace QDQ {

Expand Down Expand Up @@ -76,5 +77,9 @@ bool MatchQNode(const Node& node);
// Check DQ node op type, version, and domain.
bool MatchDQNode(const Node& node);
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

// Check if an activation node can be fused with a Q node.
bool CanFuseActivationQ(const GraphViewer& graph_viewer, const Node& activation_node, const Node& q_node);

} // namespace QDQ
} // namespace onnxruntime
Loading
Loading