From 553b2c652bd7faab861f4e514d7ca9b1134f8091 Mon Sep 17 00:00:00 2001
From: Sebastian Willenbrink
Date: Thu, 2 Jan 2025 17:31:21 +0100
Subject: [PATCH] Run components in the correct order

---
 haystack/core/pipeline/base.py | 75 +++++++++++++++++++++++++---------
 1 file changed, 55 insertions(+), 20 deletions(-)

diff --git a/haystack/core/pipeline/base.py b/haystack/core/pipeline/base.py
index 3405463250..b9708ff950 100644
--- a/haystack/core/pipeline/base.py
+++ b/haystack/core/pipeline/base.py
@@ -946,36 +946,71 @@ def _find_next_runnable_component(
         :returns: The name and the instance of the next Component that can be run
         """
 
-        def is_runnable(name, comp):
-            # Check that at least one connection provides a value to the component
-            # This ensures that components with only default inputs do not run
-            # unless a value is passed to them
-            at_least_one_conn = False
-            for input_socket in comp.__haystack_input__._sockets_dict.values():  # type: ignore
-                if input_socket.name not in components_inputs.get(name, {}) and input_socket.is_mandatory:
-                    return False
-
-            for sender, receiver, edge_data in self.graph.edges(data=True):
-                if receiver != name:
-                    continue
-                receiver_socket = edge_data["to_socket"].name
-
-                if receiver_socket in components_inputs.get(name, {}):
-                    at_least_one_conn = True
-            return at_least_one_conn
+        def count_inputs(name, comp):
+            # Count the number of inputs that are provided
+            # Returns None if a mandatory input is missing, an integer otherwise
+            # This can ensure that components with only default inputs are not run
+            num_inputs = 0
+            for input_socket in comp.__haystack_input__._sockets_dict.values():
+                if (
+                    input_socket.name not in components_inputs.get(name, {})
+                    and input_socket.is_mandatory
+                    and not input_socket.is_variadic
+                ):
+                    return None
+                if input_socket.name in components_inputs.get(name, {}):
+                    num_inputs += 1
+            return num_inputs
 
-        waiting_queue = [(name, comp, _is_lazy_variadic(comp), is_runnable(name, comp)) for name, comp in waiting_queue]
+        waiting_queue = [
+            (name, comp, _is_lazy_variadic(comp), count_inputs(name, comp)) for name, comp in waiting_queue
+        ]
 
-        first_runnable = next(((name, comp) for (name, comp, _, runnable) in waiting_queue if runnable), None)
+        first_runnable = next(
+            (
+                (name, comp)
+                for (name, comp, lazy_variadic, num_inputs) in waiting_queue
+                if not lazy_variadic and num_inputs is not None and num_inputs > 0
+            ),
+            None,
+        )
         if first_runnable:
             return first_runnable
 
         first_lazy_variadic = next(
-            ((name, comp) for (name, comp, lazy_variadic, _) in waiting_queue if lazy_variadic), None
+            (
+                (name, comp)
+                for (name, comp, lazy_variadic, num_inputs) in waiting_queue
+                if lazy_variadic and num_inputs is not None and num_inputs > 0
+            ),
+            None,
         )
         if first_lazy_variadic:
             return first_lazy_variadic
 
+        # Return components that get no input at all but have default arguments
+        first_default_lazy_variadic = next(
+            (
+                (name, comp)
+                for (name, comp, lazy_variadic, num_inputs) in waiting_queue
+                if lazy_variadic and num_inputs is not None
+            ),
+            None,
+        )
+        if first_default_lazy_variadic:
+            return first_default_lazy_variadic
+
+        first_default = next(
+            (
+                (name, comp)
+                for (name, comp, lazy_variadic, num_inputs) in waiting_queue
+                if not lazy_variadic and num_inputs is not None
+            ),
+            None,
+        )
+        if first_default:
+            return first_default
+
         return None
 
     def _find_next_runnable_lazy_variadic_or_default_component(