Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions src/onnx_ir/passes/common/identity_elimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,29 @@
logger = logging.getLogger(__name__)


def _merge_shapes(shape1: ir.Shape | None, shape2: ir.Shape | None) -> ir.Shape | None:
def merge_dims(dim1, dim2):
if dim1 == dim2:
return dim1
if not isinstance(dim1, ir.SymbolicDim):
return dim1 # Prefer int value over symbolic dim
if not isinstance(dim2, ir.SymbolicDim):
return dim2
if dim1.value is None:
return dim2
return dim1

if shape1 is None:
return shape2
if shape2 is None:
return shape1
if len(shape1) != len(shape2):
raise ValueError(
f"Shapes must have the same rank, got {len(shape1)} and {len(shape2)}."
)
return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)])


class IdentityEliminationPass(ir.passes.InPlacePass):
"""Pass for eliminating redundant Identity nodes.

Expand Down Expand Up @@ -75,6 +98,11 @@ def _try_eliminate_identity_node(self, node: ir.Node) -> bool:
if output_is_graph_output and input_is_graph_input:
return False

# Copy over shape/type if the output has more complete information
input_value.shape = _merge_shapes(input_value.shape, output_value.shape)
if input_value.type is None:
input_value.type = output_value.type

# Case 1 & 2 (merged): Eliminate the identity node
# Replace all uses of output with input
ir.convenience.replace_all_uses_with(output_value, input_value)
Expand Down