From 60613fe1a789c267437162e33855bf2967b7c1ab Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 24 Jun 2024 22:37:53 -0400 Subject: [PATCH 1/5] use set for unreachable nodes --- src/deepsparse/utils/extractor.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/deepsparse/utils/extractor.py b/src/deepsparse/utils/extractor.py index b113431a02..0ccb4b10b3 100644 --- a/src/deepsparse/utils/extractor.py +++ b/src/deepsparse/utils/extractor.py @@ -21,7 +21,7 @@ """ import os -from typing import Any, List, Optional, Sequence, Tuple +from typing import Any, List, Optional, Sequence, Tuple, Set import onnx.helper import onnx.shape_inference @@ -84,8 +84,8 @@ def _collect_new_outputs(self, names: List[str]) -> List[ValueInfoProto]: def _dfs_search_reachable_nodes( self, node_output_name: str, - graph_input_names: List[str], - reachable_nodes: List[NodeProto], + graph_input_names: Set[str], + reachable_nodes: Set[NodeProto], ) -> None: if node_output_name in graph_input_names: return @@ -95,7 +95,7 @@ def _dfs_search_reachable_nodes( continue if node in reachable_nodes: continue - reachable_nodes.append(node) + reachable_nodes.add(node) for name in node.input: self._dfs_search_reachable_nodes( name, graph_input_names, reachable_nodes @@ -106,9 +106,9 @@ def _collect_reachable_nodes( input_names: List[str], output_names: List[str], ) -> List[NodeProto]: - reachable_nodes = list() # type: ignore + reachable_nodes = set() # type: ignore for name in output_names: - self._dfs_search_reachable_nodes(name, input_names, reachable_nodes) + self._dfs_search_reachable_nodes(name, set(input_names), reachable_nodes) # needs to be topology sorted. nodes = [n for n in self.graph.node if n in reachable_nodes] return nodes From 4e40524338093e6b41145bc2d495a517b2b01e5b Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 24 Jun 2024 22:45:13 -0400 Subject: [PATCH 2/5] use set for unreachable nodes to further cut runtime --- src/deepsparse/utils/extractor.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/src/deepsparse/utils/extractor.py b/src/deepsparse/utils/extractor.py index 0ccb4b10b3..e3253ac7ce 100644 --- a/src/deepsparse/utils/extractor.py +++ b/src/deepsparse/utils/extractor.py @@ -86,19 +86,25 @@ def _dfs_search_reachable_nodes( node_output_name: str, graph_input_names: Set[str], reachable_nodes: Set[NodeProto], + unreachable_nodes: Set[NodeProto] ) -> None: if node_output_name in graph_input_names: return - for node in self.graph.node: - # check output_name first to reduce run time - if node_output_name not in node.output: - continue - if node in reachable_nodes: - continue + + nodes_to_search = [ + node + for node in unreachable_nodes + if node_output_name in node.output and node not in reachable_nodes + ] + + for node in nodes_to_search: reachable_nodes.add(node) + unreachable_nodes.remove(node) + + for node in nodes_to_search: for name in node.input: self._dfs_search_reachable_nodes( - name, graph_input_names, reachable_nodes + name, graph_input_names, reachable_nodes, unreachable_nodes ) def _collect_reachable_nodes( @@ -106,9 +112,13 @@ def _collect_reachable_nodes( input_names: List[str], output_names: List[str], ) -> List[NodeProto]: + input_names = set(input_names) reachable_nodes = set() # type: ignore + unreachable_nodes = set(self.graph.node) # type: ignore for name in output_names: - self._dfs_search_reachable_nodes(name, set(input_names), reachable_nodes) + self._dfs_search_reachable_nodes( + name, input_names, reachable_nodes, unreachable_nodes + ) # needs to be topology sorted. nodes = [n for n in self.graph.node if n in reachable_nodes] return nodes From c0f478b95332e99a75e7907f01c802a2d66ab771 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 24 Jun 2024 22:55:01 -0400 Subject: [PATCH 3/5] remove unneeded check --- src/deepsparse/utils/extractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/deepsparse/utils/extractor.py b/src/deepsparse/utils/extractor.py index e3253ac7ce..ddb64d8e01 100644 --- a/src/deepsparse/utils/extractor.py +++ b/src/deepsparse/utils/extractor.py @@ -94,7 +94,7 @@ def _dfs_search_reachable_nodes( nodes_to_search = [ node for node in unreachable_nodes - if node_output_name in node.output and node not in reachable_nodes + if node_output_name in node.output ] for node in nodes_to_search: From 45e2238ed3adfe5fcf75b320768bb6969cf7e963 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sun, 30 Jun 2024 23:59:43 -0400 Subject: [PATCH 4/5] index into node list to avoid hashing NodeProtos --- src/deepsparse/utils/extractor.py | 37 ++++++++++++++++--------------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/src/deepsparse/utils/extractor.py b/src/deepsparse/utils/extractor.py index ddb64d8e01..0b9979ab80 100644 --- a/src/deepsparse/utils/extractor.py +++ b/src/deepsparse/utils/extractor.py @@ -21,7 +21,7 @@ """ import os -from typing import Any, List, Optional, Sequence, Tuple, Set +from typing import Any, List, Optional, Sequence, Set, Tuple import onnx.helper import onnx.shape_inference @@ -85,26 +85,25 @@ def _dfs_search_reachable_nodes( self, node_output_name: str, graph_input_names: Set[str], - reachable_nodes: Set[NodeProto], - unreachable_nodes: Set[NodeProto] + nodes: List[NodeProto], + reachable: Set[int], + unreachable: Set[int], ) -> None: if node_output_name in graph_input_names: return - + nodes_to_search = [ - node - for node in unreachable_nodes - if node_output_name in node.output + index for index in unreachable if node_output_name in nodes[index].output ] - for node in nodes_to_search: - reachable_nodes.add(node) - unreachable_nodes.remove(node) + for node_index in nodes_to_search: + reachable.add(node_index) + unreachable.remove(node_index) - for node in nodes_to_search: - for name in node.input: + for node_index in nodes_to_search: + for name in nodes[node_index].input: self._dfs_search_reachable_nodes( - name, graph_input_names, reachable_nodes, unreachable_nodes + name, graph_input_names, nodes, reachable, unreachable ) def _collect_reachable_nodes( @@ -113,14 +112,16 @@ def _collect_reachable_nodes( output_names: List[str], ) -> List[NodeProto]: input_names = set(input_names) - reachable_nodes = set() # type: ignore - unreachable_nodes = set(self.graph.node) # type: ignore + nodes = [node for node in self.graph.node] + reachable = set() + unreachable = set(range(len(nodes))) for name in output_names: self._dfs_search_reachable_nodes( - name, input_names, reachable_nodes, unreachable_nodes + name, input_names, nodes, reachable, unreachable ) - # needs to be topology sorted. - nodes = [n for n in self.graph.node if n in reachable_nodes] + # needs to be topologically sorted + reachable = sorted(list(reachable)) + nodes = [nodes[node_index] for node_index in reachable] return nodes def _collect_referred_local_functions( From 1d121930e58795509dd83798b0c570e8375db019 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 5 Jul 2024 17:49:43 -0400 Subject: [PATCH 5/5] add comments --- src/deepsparse/utils/extractor.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/src/deepsparse/utils/extractor.py b/src/deepsparse/utils/extractor.py index 0b9979ab80..3361f56b86 100644 --- a/src/deepsparse/utils/extractor.py +++ b/src/deepsparse/utils/extractor.py @@ -89,17 +89,30 @@ def _dfs_search_reachable_nodes( reachable: Set[int], unreachable: Set[int], ) -> None: + """ + Helper function to find nodes which are connected to an output + + :param node_output_name: The name of the output + :param graph_input_names: The names of all inputs of the graph + :param nodes: The list of all nodes of the graph + :param reachable: The set of indexes to reachable nodes in `nodes` + :param unreachable: The set of indexes to unreachable nodes in `nodes` + """ + # finish search at inputs if node_output_name in graph_input_names: return + # find nodes connected to this output nodes_to_search = [ index for index in unreachable if node_output_name in nodes[index].output ] + # add nodes connected to this output to sets for node_index in nodes_to_search: reachable.add(node_index) unreachable.remove(node_index) + # recurse on inputs for node_index in nodes_to_search: for name in nodes[node_index].input: self._dfs_search_reachable_nodes( @@ -110,18 +123,17 @@ def _collect_reachable_nodes( self, input_names: List[str], output_names: List[str], - ) -> List[NodeProto]: - input_names = set(input_names) - nodes = [node for node in self.graph.node] - reachable = set() - unreachable = set(range(len(nodes))) + ) -> list[NodeProto]: + _input_names = set(input_names) + nodes = list(self.graph.node) + reachable: Set[int] = set() + unreachable: Set[int] = set(range(len(nodes))) for name in output_names: self._dfs_search_reachable_nodes( - name, input_names, nodes, reachable, unreachable + name, _input_names, nodes, reachable, unreachable ) # needs to be topologically sorted - reachable = sorted(list(reachable)) - nodes = [nodes[node_index] for node_index in reachable] + nodes = [nodes[node_index] for node_index in sorted(reachable)] return nodes def _collect_referred_local_functions(