Skip to content

Commit

Permalink
Run components in the correct order
Browse files Browse the repository at this point in the history
  • Loading branch information
Willenbrink committed Jan 2, 2025
1 parent 23412d1 commit 553b2c6
Showing 1 changed file with 55 additions and 20 deletions.
75 changes: 55 additions & 20 deletions haystack/core/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,36 +946,71 @@ def _find_next_runnable_component(
:returns: The name and the instance of the next Component that can be run
"""

def is_runnable(name, comp):
# Check that at least one connection provides a value to the component
# This ensures that components with only default inputs do not run
# unless a value is passed to them
at_least_one_conn = False
for input_socket in comp.__haystack_input__._sockets_dict.values(): # type: ignore
if input_socket.name not in components_inputs.get(name, {}) and input_socket.is_mandatory:
return False

for sender, receiver, edge_data in self.graph.edges(data=True):
if receiver != name:
continue
receiver_socket = edge_data["to_socket"].name

if receiver_socket in components_inputs.get(name, {}):
at_least_one_conn = True
return at_least_one_conn
def count_inputs(name, comp):
# Count the number of inputs that are provided
# Returns None if a mandatory input is missing, an integer otherwise
# This can ensure that components with only default inputs are not run
num_inputs = 0
for input_socket in comp.__haystack_input__._sockets_dict.values():
if (
input_socket.name not in components_inputs.get(name, {})
and input_socket.is_mandatory
and not input_socket.is_variadic
):
return None
if input_socket.name in components_inputs.get(name, {}):
num_inputs += 1
return num_inputs

waiting_queue = [(name, comp, _is_lazy_variadic(comp), is_runnable(name, comp)) for name, comp in waiting_queue]
waiting_queue = [
(name, comp, _is_lazy_variadic(comp), count_inputs(name, comp)) for name, comp in waiting_queue
]

first_runnable = next(((name, comp) for (name, comp, _, runnable) in waiting_queue if runnable), None)
first_runnable = next(
(
(name, comp)
for (name, comp, lazy_variadic, num_inputs) in waiting_queue
if not lazy_variadic and num_inputs is not None and num_inputs > 0
),
None,
)
if first_runnable:
return first_runnable

first_lazy_variadic = next(
((name, comp) for (name, comp, lazy_variadic, _) in waiting_queue if lazy_variadic), None
(
(name, comp)
for (name, comp, lazy_variadic, num_inputs) in waiting_queue
if lazy_variadic and num_inputs is not None and num_inputs > 0
),
None,
)
if first_lazy_variadic:
return first_lazy_variadic

# Return components that get no input at all but have default arguments
first_default_lazy_variadic = next(
(
(name, comp)
for (name, comp, lazy_variadic, num_inputs) in waiting_queue
if lazy_variadic and num_inputs is not None
),
None,
)
if first_default_lazy_variadic:
return first_default_lazy_variadic

first_default = next(
(
(name, comp)
for (name, comp, lazy_variadic, num_inputs) in waiting_queue
if not lazy_variadic and num_inputs is not None
),
None,
)
if first_default:
return first_default

return None

def _find_next_runnable_lazy_variadic_or_default_component(
Expand Down

0 comments on commit 553b2c6

Please sign in to comment.