Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: prevents exception when the pipeline contains multiple nested loops #8677

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 26 additions & 21 deletions haystack/core/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1209,27 +1209,32 @@ def _break_supported_cycles_in_graph(self) -> Tuple[networkx.MultiDiGraph, Dict[
# sender_comp will be the last element of cycle and receiver_comp will be the first.
# So if cycle is [1, 2, 3, 4] we would call zip([1, 2, 3, 4], [2, 3, 4, 1]).
for sender_comp, receiver_comp in zip(cycle, cycle[1:] + cycle[:1]):
# We get the key and iterate those as we want to edit the graph data while
# iterating the edges and that would raise.
# Even though the connection key set in Pipeline.connect() uses only the
# sockets name we don't have clashes since it's only used to differentiate
# multiple edges between two nodes.
edge_keys = list(temp_graph.get_edge_data(sender_comp, receiver_comp).keys())
for edge_key in edge_keys:
edge_data = temp_graph.get_edge_data(sender_comp, receiver_comp)[edge_key]
receiver_socket = edge_data["to_socket"]
if not receiver_socket.is_variadic and receiver_socket.is_mandatory:
continue

# We found a breakable edge
sender_socket = edge_data["from_socket"]
edges_removed[sender_comp].append(sender_socket.name)
temp_graph.remove_edge(sender_comp, receiver_comp, edge_key)

graph_has_cycles = not networkx.is_directed_acyclic_graph(temp_graph)
if not graph_has_cycles:
# We removed all the cycles, we can stop
break
# for graphs with multiple nested cycles, we need to check if the edge hasn't
# been previously removed before we try to remove it again
edge_data = temp_graph.get_edge_data(sender_comp, receiver_comp)
if edge_data is not None:
# We get the key and iterate those as we want to edit the graph data while
# iterating the edges and that would raise.
# Even though the connection key set in Pipeline.connect() uses only the
# sockets name we don't have clashes since it's only used to differentiate
# multiple edges between two nodes.
edge_keys = list(edge_data.keys())

for edge_key in edge_keys:
edge_data = temp_graph.get_edge_data(sender_comp, receiver_comp)[edge_key]
receiver_socket = edge_data["to_socket"]
if not receiver_socket.is_variadic and receiver_socket.is_mandatory:
continue

# We found a breakable edge
sender_socket = edge_data["from_socket"]
edges_removed[sender_comp].append(sender_socket.name)
temp_graph.remove_edge(sender_comp, receiver_comp, edge_key)

graph_has_cycles = not networkx.is_directed_acyclic_graph(temp_graph)
if not graph_has_cycles:
# We removed all the cycles, we can stop
break

if not graph_has_cycles:
# We removed all the cycles, nice
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
fixes:
- |
Prevents the pipeline from raising an exception when there are multiple nested cycles in the graph.
25 changes: 25 additions & 0 deletions test/core/pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1581,3 +1581,28 @@ def test__find_receivers_from(self):
),
)
]

def test__break_supported_cycles_in_graph(self):
# the following pipeline has a nested cycle, which is supported by Haystack
# but was causing an exception to be raised in the _break_supported_cycles_in_graph method
comp1 = component_class("Comp1", input_types={"value": int}, output_types={"value": int})()
comp2 = component_class("Comp2", input_types={"value": Variadic[int]}, output_types={"value": int})()
comp3 = component_class("Comp3", input_types={"value": Variadic[int]}, output_types={"value": int})()
comp4 = component_class("Comp4", input_types={"value": Optional[int]}, output_types={"value": int})()
comp5 = component_class("Comp5", input_types={"value": Variadic[int]}, output_types={"value": int})()
pipe = Pipeline()
pipe.add_component("comp1", comp1)
pipe.add_component("comp2", comp2)
pipe.add_component("comp3", comp3)
pipe.add_component("comp4", comp4)
pipe.add_component("comp5", comp5)
pipe.connect("comp1.value", "comp2.value")
pipe.connect("comp2.value", "comp3.value")
pipe.connect("comp3.value", "comp4.value")
pipe.connect("comp3.value", "comp5.value")
pipe.connect("comp4.value", "comp5.value")
pipe.connect("comp4.value", "comp3.value")
pipe.connect("comp5.value", "comp2.value")

# the following call should not raise an exception
pipe._break_supported_cycles_in_graph()
Loading