diff --git a/haystack/__init__.py b/haystack/__init__.py index 2f275d3a46..7d375c9a4a 100644 --- a/haystack/__init__.py +++ b/haystack/__init__.py @@ -6,7 +6,7 @@ import haystack.tracing from haystack.core.component import component from haystack.core.errors import ComponentError, DeserializationError -from haystack.core.pipeline import Pipeline, PredefinedPipeline +from haystack.core.pipeline import AsyncPipeline, Pipeline, PredefinedPipeline from haystack.core.serialization import default_from_dict, default_to_dict from haystack.dataclasses import Answer, Document, ExtractedAnswer, GeneratedAnswer @@ -18,15 +18,16 @@ haystack.tracing.auto_enable_tracing() __all__ = [ - "component", - "default_from_dict", - "default_to_dict", - "DeserializationError", + "Answer", + "AsyncPipeline", "ComponentError", - "Pipeline", - "PredefinedPipeline", + "DeserializationError", "Document", - "Answer", - "GeneratedAnswer", "ExtractedAnswer", + "GeneratedAnswer", + "Pipeline", + "PredefinedPipeline", + "component", + "default_from_dict", + "default_to_dict", ] diff --git a/haystack/core/pipeline/__init__.py b/haystack/core/pipeline/__init__.py index 744c282f8c..1d84ac1c74 100644 --- a/haystack/core/pipeline/__init__.py +++ b/haystack/core/pipeline/__init__.py @@ -2,7 +2,8 @@ # # SPDX-License-Identifier: Apache-2.0 +from .async_pipeline import AsyncPipeline from .pipeline import Pipeline from .template import PredefinedPipeline -__all__ = ["Pipeline", "PredefinedPipeline"] +__all__ = ["AsyncPipeline", "Pipeline", "PredefinedPipeline"] diff --git a/haystack/core/pipeline/async_pipeline.py b/haystack/core/pipeline/async_pipeline.py new file mode 100644 index 0000000000..da2c21414a --- /dev/null +++ b/haystack/core/pipeline/async_pipeline.py @@ -0,0 +1,535 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +from copy import deepcopy +from typing import Any, AsyncIterator, Dict, List, Optional, Set + +from haystack import logging, tracing +from haystack.core.component import Component +from haystack.core.errors import PipelineMaxComponentRuns, PipelineRuntimeError +from haystack.core.pipeline.base import ComponentPriority, PipelineBase +from haystack.telemetry import pipeline_running + +logger = logging.getLogger(__name__) + + +class AsyncPipeline(PipelineBase): + """ + Asynchronous version of the orchestration engine. + + Orchestrates component execution and runs components concurrently if the execution graph allows it. + """ + + async def run_async_generator( # noqa: PLR0915,C901 + self, data: Dict[str, Any], include_outputs_from: Optional[Set[str]] = None, concurrency_limit: int = 4 + ) -> AsyncIterator[Dict[str, Any]]: + """ + Execute this pipeline asynchronously, yielding partial outputs when any component finishes. + + :param data: Initial input data to the pipeline. + :param concurrency_limit: The maximum number of components that are allowed to run concurrently. + :param include_outputs_from: + Set of component names whose individual outputs are to be + included in the pipeline's output. For components that are + invoked multiple times (in a loop), only the last-produced + output is included. + :return: An async iterator of partial (and final) outputs. 
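+
+        Usage example (a minimal sketch; assumes `pipeline` is an `AsyncPipeline` whose components
+        and connections are already set up, with a component named "comp" that takes an input "value"):
+        ```python
+        import asyncio
+
+        async def consume():
+            # Each yielded item maps a component name to its outputs as soon as that
+            # component finishes; the last yielded item is the aggregated final result.
+            async for partial_output in pipeline.run_async_generator(data={"comp": {"value": 1}}):
+                print(partial_output)
+
+        asyncio.run(consume())
+        ```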
+        """
+        if include_outputs_from is None:
+            include_outputs_from = set()
+
+        # 0) Basic pipeline init
+        pipeline_running(self)  # telemetry
+        self.warm_up()  # optional warm-up (if needed)
+
+        # 1) Prepare ephemeral state
+        ready_sem = asyncio.Semaphore(max(1, concurrency_limit))
+        inputs_state: Dict[str, Dict[str, List[Dict[str, Any]]]] = {}
+        pipeline_outputs: Dict[str, Any] = {}
+        running_tasks: Dict[asyncio.Task, str] = {}
+
+        # A set of component names that have been scheduled but not finished:
+        scheduled_components: Set[str] = set()
+
+        # 2) Convert input data
+        prepared_data = self._prepare_component_input_data(data)
+        self._validate_input(prepared_data)
+        inputs_state = self._convert_to_internal_format(prepared_data)
+
+        # For quick lookup of downstream receivers
+        ordered_names = sorted(self.graph.nodes.keys())
+        cached_receivers = {n: self._find_receivers_from(n) for n in ordered_names}
+        component_visits = {component_name: 0 for component_name in ordered_names}
+
+        # We fill the queue once and raise if all components are BLOCKED
+        self.validate_pipeline(self._fill_queue(ordered_names, inputs_state, component_visits))
+
+        # Single parent span for entire pipeline execution
+        with tracing.tracer.trace(
+            "haystack.async_pipeline.run",
+            tags={
+                "haystack.pipeline.input_data": data,
+                "haystack.pipeline.output_data": pipeline_outputs,
+                "haystack.pipeline.metadata": self.metadata,
+                "haystack.pipeline.max_runs_per_component": self._max_runs_per_component,
+            },
+        ) as parent_span:
+            # -------------------------------------------------
+            # We define some functions here so that they have access to local runtime state
+            # (inputs, tasks, scheduled components) via closures.
+            # -------------------------------------------------
+            async def _run_component_async(component_name: str, component_inputs: Dict[str, Any]) -> Dict[str, Any]:
+                """
+                Runs one component.
+
+                If the component supports async execution, it is awaited directly;
+                otherwise it is offloaded to an executor.
+                Updates the visit count, writes outputs to `inputs_state`,
+                and returns pruned outputs that get stored in `pipeline_outputs`.
+
+                :param component_name: The name of the component.
+                :param component_inputs: Inputs for the component.
+                :returns: Outputs from the component that can be yielded from run_async_generator.
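+
+                Note: synchronous components are offloaded via `loop.run_in_executor`, so a blocking
+                `run()` does not stall other components scheduled on the event loop.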
+                """
+                if component_visits[component_name] > self._max_runs_per_component:
+                    raise PipelineMaxComponentRuns(f"Max runs for '{component_name}' reached.")
+
+                instance: Component = self.get_component(component_name)
+                with tracing.tracer.trace(
+                    "haystack.component.run",
+                    tags={
+                        "haystack.component.name": component_name,
+                        "haystack.component.type": instance.__class__.__name__,
+                        "haystack.component.input_types": {k: type(v).__name__ for k, v in component_inputs.items()},
+                        "haystack.component.input_spec": {
+                            key: {
+                                "type": (value.type.__name__ if isinstance(value.type, type) else str(value.type)),
+                                "senders": value.senders,
+                            }
+                            for key, value in instance.__haystack_input__._sockets_dict.items()  # type: ignore
+                        },
+                        "haystack.component.output_spec": {
+                            key: {
+                                "type": (value.type.__name__ if isinstance(value.type, type) else str(value.type)),
+                                "receivers": value.receivers,
+                            }
+                            for key, value in instance.__haystack_output__._sockets_dict.items()  # type: ignore
+                        },
+                    },
+                    parent_span=parent_span,
+                ) as span:
+                    span.set_content_tag("haystack.component.input", deepcopy(component_inputs))
+                    logger.info("Running component {name}", name=component_name)
+
+                    if getattr(instance, "__haystack_supports_async__", False):
+                        outputs = await instance.run_async(**component_inputs)  # type: ignore
+                    else:
+                        loop = asyncio.get_running_loop()
+                        outputs = await loop.run_in_executor(None, lambda: instance.run(**component_inputs))
+
+                    component_visits[component_name] += 1
+
+                    if not isinstance(outputs, dict):
+                        raise PipelineRuntimeError(
+                            f"Component '{component_name}' returned an invalid output type. "
+                            f"Expected a dict, but got {type(outputs).__name__} instead. "
+                        )
+
+                    span.set_tag("haystack.component.visits", component_visits[component_name])
+                    span.set_content_tag("haystack.component.outputs", deepcopy(outputs))
+
+                    # Distribute outputs to downstream inputs; also prune outputs based on `include_outputs_from`
+                    pruned = self._write_component_outputs(
+                        component_name=component_name,
+                        component_outputs=outputs,
+                        inputs=inputs_state,
+                        receivers=cached_receivers[component_name],
+                        include_outputs_from=include_outputs_from,
+                    )
+                    if pruned:
+                        pipeline_outputs[component_name] = pruned
+
+                    return pruned
+
+            async def _run_highest_in_isolation(component_name: str) -> AsyncIterator[Dict[str, Any]]:
+                """
+                Runs a component with HIGHEST priority in isolation.
+
+                We need to run components with HIGHEST priority (i.e. components with a GreedyVariadic input
+                socket) in isolation because otherwise, downstream components could produce additional inputs
+                for the GreedyVariadic socket while it runs.
+
+                :param component_name: The name of the component.
+                :return: An async iterator of partial outputs.
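+
+                Note: all in-flight tasks are drained before the component runs, so its greedy
+                variadic socket sees every input that is available at that point.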
+                """
+                # 1) Wait for all in-flight tasks to finish
+                while running_tasks:
+                    done, _pending = await asyncio.wait(running_tasks.keys(), return_when=asyncio.ALL_COMPLETED)
+                    for finished in done:
+                        finished_component_name = running_tasks.pop(finished)
+                        partial_result = finished.result()
+                        scheduled_components.discard(finished_component_name)
+                        if partial_result:
+                            yield_dict = {finished_component_name: deepcopy(partial_result)}
+                            yield yield_dict  # partial outputs
+
+                if component_name in scheduled_components:
+                    # If it's already scheduled for some reason, skip
+                    return
+
+                # 2) Run the HIGHEST component by itself
+                scheduled_components.add(component_name)
+                comp_dict = self._get_component_with_graph_metadata_and_visits(
+                    component_name, component_visits[component_name]
+                )
+                component_inputs = self._consume_component_inputs(component_name, comp_dict, inputs_state)
+                component_inputs = self._add_missing_input_defaults(component_inputs, comp_dict["input_sockets"])
+                result = await _run_component_async(component_name, component_inputs)
+                scheduled_components.remove(component_name)
+                if result:
+                    yield {component_name: deepcopy(result)}
+
+            async def _schedule_task(component_name: str) -> None:
+                """
+                Schedule a component to run.
+
+                We do NOT wait for it to finish here. This allows us to run other components concurrently.
+
+                :param component_name: The name of the component.
+                """
+                if component_name in scheduled_components:
+                    return  # already scheduled, do nothing
+
+                scheduled_components.add(component_name)
+
+                comp_dict = self._get_component_with_graph_metadata_and_visits(
+                    component_name, component_visits[component_name]
+                )
+                component_inputs = self._consume_component_inputs(component_name, comp_dict, inputs_state)
+                component_inputs = self._add_missing_input_defaults(component_inputs, comp_dict["input_sockets"])
+
+                async def _runner():
+                    async with ready_sem:
+                        result = await _run_component_async(component_name, component_inputs)
+
+                    scheduled_components.remove(component_name)
+                    return result
+
+                task = asyncio.create_task(_runner())
+                running_tasks[task] = component_name
+
+            async def _wait_for_one_task_to_complete() -> AsyncIterator[Dict[str, Any]]:
+                """
+                Wait until at least one running task finishes, then yield partial outputs.
+
+                `asyncio.wait` may report several tasks as done at once; all of them are drained.
+                If no tasks are running, does nothing.
+                """
+                if running_tasks:
+                    done, _ = await asyncio.wait(running_tasks.keys(), return_when=asyncio.FIRST_COMPLETED)
+                    for finished in done:
+                        finished_component_name = running_tasks.pop(finished)
+                        partial_result = finished.result()
+                        scheduled_components.discard(finished_component_name)
+                        if partial_result:
+                            yield {finished_component_name: deepcopy(partial_result)}
+
+            async def _wait_for_all_tasks_to_complete() -> AsyncIterator[Dict[str, Any]]:
+                """
+                Wait for all running tasks to finish, then yield partial outputs.
+ """ + if running_tasks: + done, _ = await asyncio.wait(running_tasks.keys(), return_when=asyncio.ALL_COMPLETED) + for finished in done: + finished_component_name = running_tasks.pop(finished) + partial_result = finished.result() + scheduled_components.discard(finished_component_name) + if partial_result: + yield {finished_component_name: deepcopy(partial_result)} + + # ------------------------------------------------- + # MAIN SCHEDULING LOOP + # ------------------------------------------------- + while True: + # 2) Build the priority queue of candidates + priority_queue = self._fill_queue(ordered_names, inputs_state, component_visits) + candidate = self._get_next_runnable_component(priority_queue, component_visits) + if candidate is None and running_tasks: + # We need to wait for one task to finish to make progress and potentially unblock the priority_queue + async for partial_result in _wait_for_one_task_to_complete(): + yield partial_result + continue + + if candidate is None and not running_tasks: + # done + break + + priority, component_name, _ = candidate # type: ignore + + if component_name in scheduled_components: + # We need to wait for one task to finish to make progress + async for partial_result in _wait_for_one_task_to_complete(): + yield partial_result + continue + + if priority == ComponentPriority.HIGHEST: + # 1) run alone + async for partial_result in _run_highest_in_isolation(component_name): + yield partial_result + # then continue the loop + continue + + if priority == ComponentPriority.READY: + # 1) schedule this one + await _schedule_task(component_name) + + # 2) Possibly schedule more READY tasks if concurrency not fully used + while len(priority_queue) > 0 and not ready_sem.locked(): + peek_prio, peek_name = priority_queue.peek() + if peek_prio in (ComponentPriority.BLOCKED, ComponentPriority.HIGHEST): + # can't run or must run alone => skip + break + if peek_prio == ComponentPriority.READY: + priority_queue.pop() + await _schedule_task(peek_name) + # keep adding while concurrency is not locked + continue + + # The next is DEFER/DEFER_LAST => we only schedule it if it "becomes READY" + # We'll handle it in the next iteration or with incremental waiting + break + + # We only schedule components with priority DEFER or DEFER_LAST when no other tasks are running + elif priority in (ComponentPriority.DEFER, ComponentPriority.DEFER_LAST) and not running_tasks: + await _schedule_task(component_name) + + # To make progress, we wait for one task to complete before re-starting the loop + async for partial_result in _wait_for_one_task_to_complete(): + yield partial_result + + # End main loop + + # 3) Drain leftover tasks + async for partial_result in _wait_for_all_tasks_to_complete(): + yield partial_result + + # 4) Yield final pipeline outputs + yield deepcopy(pipeline_outputs) + + async def run_async( + self, data: Dict[str, Any], include_outputs_from: Optional[Set[str]] = None, concurrency_limit: int = 4 + ) -> Dict[str, Any]: + """ + Runs the Pipeline with given input data. 
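+
+        This method is a convenience wrapper around `run_async_generator`: it consumes all partial
+        outputs and returns only the final pipeline result.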
+
+        Usage:
+        ```python
+        from haystack import Document
+        from haystack.utils import Secret
+        from haystack.document_stores.in_memory import InMemoryDocumentStore
+        from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
+        from haystack.components.generators import OpenAIGenerator
+        from haystack.components.builders.answer_builder import AnswerBuilder
+        from haystack.components.builders.prompt_builder import PromptBuilder
+
+        from haystack import AsyncPipeline
+
+        import asyncio
+
+        # Write documents to InMemoryDocumentStore
+        document_store = InMemoryDocumentStore()
+        document_store.write_documents([
+            Document(content="My name is Jean and I live in Paris."),
+            Document(content="My name is Mark and I live in Berlin."),
+            Document(content="My name is Giorgio and I live in Rome.")
+        ])
+
+        prompt_template = \"\"\"
+        Given these documents, answer the question.
+        Documents:
+        {% for doc in documents %}
+            {{ doc.content }}
+        {% endfor %}
+        Question: {{question}}
+        Answer:
+        \"\"\"
+
+        retriever = InMemoryBM25Retriever(document_store=document_store)
+        prompt_builder = PromptBuilder(template=prompt_template)
+        llm = OpenAIGenerator(api_key=Secret.from_env_var("OPENAI_API_KEY"))
+
+        rag_pipeline = AsyncPipeline()
+        rag_pipeline.add_component("retriever", retriever)
+        rag_pipeline.add_component("prompt_builder", prompt_builder)
+        rag_pipeline.add_component("llm", llm)
+        rag_pipeline.connect("retriever", "prompt_builder.documents")
+        rag_pipeline.connect("prompt_builder", "llm")
+
+        # Ask a question
+        question = "Who lives in Paris?"
+
+
+        async def run_inner(data, include_outputs_from):
+            return await rag_pipeline.run_async(data=data, include_outputs_from=include_outputs_from)
+
+        data = {
+            "retriever": {"query": question},
+            "prompt_builder": {"question": question},
+        }
+
+        results = asyncio.run(run_inner(data, include_outputs_from=None))
+
+        print(results["llm"]["replies"])
+        # Jean lives in Paris
+        ```
+
+        :param data:
+            A dictionary of inputs for the pipeline's components. Each key is a component name
+            and its value is a dictionary of that component's input parameters:
+            ```
+            data = {
+                "comp1": {"input1": 1, "input2": 2},
+            }
+            ```
+            For convenience, this format is also supported when input names are unique:
+            ```
+            data = {
+                "input1": 1, "input2": 2,
+            }
+            ```
+        :param include_outputs_from:
+            Set of component names whose individual outputs are to be
+            included in the pipeline's output. For components that are
+            invoked multiple times (in a loop), only the last-produced
+            output is included.
+        :param concurrency_limit: The maximum number of components that should be allowed to run concurrently.
+        :returns:
+            A dictionary where each entry corresponds to a component name
+            and its output. If `include_outputs_from` is `None`, this dictionary
+            will only contain the outputs of leaf components, i.e., components
+            without outgoing connections.
+
+        :raises ValueError:
+            If invalid inputs are provided to the pipeline.
+        :raises PipelineRuntimeError:
+            If the Pipeline contains cycles with unsupported connections that would cause
+            it to get stuck and fail running.
+            Or if a Component fails or returns output in an unsupported type.
+        :raises PipelineMaxComponentRuns:
+            If a Component reaches the maximum number of times it can be run in this Pipeline.
+        """
+        final: Dict[str, Any] = {}
+        async for partial in self.run_async_generator(
+            data=data, concurrency_limit=concurrency_limit, include_outputs_from=include_outputs_from
+        ):
+            final = partial
+        return final or {}
+
+    def run(
+        self, data: Dict[str, Any], include_outputs_from: Optional[Set[str]] = None, concurrency_limit: int = 4
+    ) -> Dict[str, Any]:
+        """
+        Runs the pipeline with given input data.
+
+        This method is synchronous, but it runs components asynchronously internally.
+        Check out `run_async` or `run_async_generator` if you are looking for async methods.
+
+        Usage:
+        ```python
+        from haystack import Document
+        from haystack.utils import Secret
+        from haystack.document_stores.in_memory import InMemoryDocumentStore
+        from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
+        from haystack.components.generators import OpenAIGenerator
+        from haystack.components.builders.answer_builder import AnswerBuilder
+        from haystack.components.builders.prompt_builder import PromptBuilder
+
+        from haystack import AsyncPipeline
+
+        # Write documents to InMemoryDocumentStore
+        document_store = InMemoryDocumentStore()
+        document_store.write_documents([
+            Document(content="My name is Jean and I live in Paris."),
+            Document(content="My name is Mark and I live in Berlin."),
+            Document(content="My name is Giorgio and I live in Rome.")
+        ])
+
+        prompt_template = \"\"\"
+        Given these documents, answer the question.
+        Documents:
+        {% for doc in documents %}
+            {{ doc.content }}
+        {% endfor %}
+        Question: {{question}}
+        Answer:
+        \"\"\"
+
+        retriever = InMemoryBM25Retriever(document_store=document_store)
+        prompt_builder = PromptBuilder(template=prompt_template)
+        llm = OpenAIGenerator(api_key=Secret.from_env_var("OPENAI_API_KEY"))
+
+        rag_pipeline = AsyncPipeline()
+        rag_pipeline.add_component("retriever", retriever)
+        rag_pipeline.add_component("prompt_builder", prompt_builder)
+        rag_pipeline.add_component("llm", llm)
+        rag_pipeline.connect("retriever", "prompt_builder.documents")
+        rag_pipeline.connect("prompt_builder", "llm")
+
+        # Ask a question
+        question = "Who lives in Paris?"
+
+        data = {
+            "retriever": {"query": question},
+            "prompt_builder": {"question": question},
+        }
+
+        results = rag_pipeline.run(data)
+
+        print(results["llm"]["replies"])
+        # Jean lives in Paris
+        ```
+
+        :param data:
+            A dictionary of inputs for the pipeline's components. Each key is a component name
+            and its value is a dictionary of that component's input parameters:
+            ```
+            data = {
+                "comp1": {"input1": 1, "input2": 2},
+            }
+            ```
+            For convenience, this format is also supported when input names are unique:
+            ```
+            data = {
+                "input1": 1, "input2": 2,
+            }
+            ```
+        :param include_outputs_from:
+            Set of component names whose individual outputs are to be
+            included in the pipeline's output. For components that are
+            invoked multiple times (in a loop), only the last-produced
+            output is included.
+        :param concurrency_limit: The maximum number of components that should be allowed to run concurrently.
+        :returns:
+            A dictionary where each entry corresponds to a component name
+            and its output. If `include_outputs_from` is `None`, this dictionary
+            will only contain the outputs of leaf components, i.e., components
+            without outgoing connections.
+
+        :raises ValueError:
+            If invalid inputs are provided to the pipeline.
+ :raises PipelineRuntimeError: + If the Pipeline contains cycles with unsupported connections that would cause + it to get stuck and fail running. + Or if a Component fails or returns output in an unsupported type. + :raises PipelineMaxComponentRuns: + If a Component reaches the maximum number of times it can be run in this Pipeline. + """ + return asyncio.run( + self.run_async(data=data, include_outputs_from=include_outputs_from, concurrency_limit=concurrency_limit) + ) diff --git a/haystack/telemetry/_telemetry.py b/haystack/telemetry/_telemetry.py index f0cfbf269b..4ba3a0c765 100644 --- a/haystack/telemetry/_telemetry.py +++ b/haystack/telemetry/_telemetry.py @@ -8,7 +8,7 @@ import uuid from collections import defaultdict from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import posthog import yaml @@ -18,7 +18,7 @@ from haystack.telemetry._environment import collect_system_specs if TYPE_CHECKING: - from haystack.core.pipeline import Pipeline + from haystack.core.pipeline import AsyncPipeline, Pipeline HAYSTACK_TELEMETRY_ENABLED = "HAYSTACK_TELEMETRY_ENABLED" @@ -135,7 +135,7 @@ def send_telemetry_wrapper(*args, **kwargs): @send_telemetry -def pipeline_running(pipeline: "Pipeline") -> Optional[Tuple[str, Dict[str, Any]]]: +def pipeline_running(pipeline: Union["Pipeline", "AsyncPipeline"]) -> Optional[Tuple[str, Dict[str, Any]]]: """ Collects telemetry data for a pipeline run and sends it to Posthog. @@ -170,6 +170,7 @@ def pipeline_running(pipeline: "Pipeline") -> Optional[Tuple[str, Dict[str, Any] # Data sent to Posthog return "Pipeline run (2.x)", { "pipeline_id": str(id(pipeline)), + "pipeline_type": generate_qualified_class_name(type(pipeline)), "runs": pipeline._telemetry_runs, "components": components, } diff --git a/releasenotes/notes/feat-async-pipeline-338856a142e1318c.yaml b/releasenotes/notes/feat-async-pipeline-338856a142e1318c.yaml new file mode 100644 index 0000000000..323c78c266 --- /dev/null +++ b/releasenotes/notes/feat-async-pipeline-338856a142e1318c.yaml @@ -0,0 +1,8 @@ +--- +highlights: > + We are introducing the `AsyncPipeline`: Supports running pipelines asynchronously. Schedules components concurrently + whenever possible. Leads to major speed improvements for any pipelines that may run workloads in parallel. +features: + - | + Added a new `AsyncPipeline` implementation that allows pipelines to be executed from async code, + supporting concurrent scheduling of pipeline components for faster processing. diff --git a/test/core/pipeline/features/conftest.py b/test/core/pipeline/features/conftest.py index e1e411df1e..d4139968aa 100644 --- a/test/core/pipeline/features/conftest.py +++ b/test/core/pipeline/features/conftest.py @@ -2,14 +2,25 @@ from typing import Tuple, List, Dict, Any, Set, Union from pathlib import Path import re +import pytest +import asyncio from pytest_bdd import when, then, parsers -from haystack import Pipeline +from haystack import Pipeline, AsyncPipeline PIPELINE_NAME_REGEX = re.compile(r"\[(.*)\]") +@pytest.fixture(params=[AsyncPipeline, Pipeline]) +def pipeline_class(request): + """ + A parametrized fixture that will yield AsyncPipeline for one test run + and Pipeline for the next test run. 
+    """
+    return request.param
+
+
 @dataclass
 class PipelineRunData:
     """
@@ -34,6 +45,54 @@ class _PipelineResult:
 @when("I run the Pipeline", target_fixture="pipeline_result")
 def run_pipeline(
+    pipeline_data: Tuple[Union[AsyncPipeline, Pipeline], List[PipelineRunData]], spying_tracer
+) -> Union[List[Tuple[_PipelineResult, PipelineRunData]], Exception]:
+    if isinstance(pipeline_data[0], AsyncPipeline):
+        return run_async_pipeline(pipeline_data, spying_tracer)
+    else:
+        return run_sync_pipeline(pipeline_data, spying_tracer)
+
+
+def run_async_pipeline(
+    pipeline_data: Tuple[AsyncPipeline, List[PipelineRunData]], spying_tracer
+) -> Union[List[Tuple[_PipelineResult, PipelineRunData]], Exception]:
+    """
+    Attempts to run a pipeline with the given inputs.
+    `pipeline_data` is a tuple that must contain:
+    * An AsyncPipeline instance
+    * The data to run the pipeline with
+
+    If successful, returns a list of (run output, expected output) tuples.
+    If an exception is raised, returns it instead.
+    """
+    pipeline, pipeline_run_data = pipeline_data[0], pipeline_data[1]
+
+    results: List[_PipelineResult] = []
+
+    async def run_inner(data, include_outputs_from):
+        """Wrapper function to call pipeline.run_async method with required params."""
+        return await pipeline.run_async(data=data.inputs, include_outputs_from=include_outputs_from)
+
+    for data in pipeline_run_data:
+        try:
+            outputs = asyncio.run(run_inner(data, data.include_outputs_from))
+
+            component_calls = {
+                (span.tags["haystack.component.name"], span.tags["haystack.component.visits"]): span.tags[
+                    "haystack.component.input"
+                ]
+                for span in spying_tracer.spans
+                if "haystack.component.name" in span.tags and "haystack.component.visits" in span.tags
+            }
+            results.append(_PipelineResult(outputs=outputs, component_calls=component_calls))
+            spying_tracer.spans.clear()
+        except Exception as e:
+            return e
+
+    return list(zip(results, pipeline_run_data))
+
+
+def run_sync_pipeline(
     pipeline_data: Tuple[Pipeline, List[PipelineRunData]], spying_tracer
 ) -> Union[List[Tuple[_PipelineResult, PipelineRunData]], Exception]:
     """
@@ -61,7 +120,6 @@ def run_pipeline(
                 if "haystack.component.name" in span.tags and "haystack.component.visits" in span.tags
             }
             results.append(_PipelineResult(outputs=outputs, component_calls=component_calls))
-            spying_tracer.spans.clear()
         except Exception as e:
             return e
diff --git a/test/core/pipeline/features/pipeline_run.feature b/test/core/pipeline/features/pipeline_run.feature
index 8724dd3e1e..380459c947 100644
--- a/test/core/pipeline/features/pipeline_run.feature
+++ b/test/core/pipeline/features/pipeline_run.feature
@@ -3,8 +3,8 @@ Feature: Pipeline running
     Scenario Outline: Running a correct Pipeline
         Given a pipeline
         When I run the Pipeline
-        Then it should return the expected result
-        And components are called with the expected inputs
+        Then components are called with the expected inputs
+        And it should return the expected result

         Examples:
         | kind |
diff --git a/test/core/pipeline/features/test_run.py b/test/core/pipeline/features/test_run.py
index 30c415f68d..1aef887ba8 100644
--- a/test/core/pipeline/features/test_run.py
+++ b/test/core/pipeline/features/test_run.py
@@ -5,7 +5,7 @@
 from pytest_bdd import scenarios, given
 import pytest

-from haystack import Pipeline, Document, component
+from haystack import Document, component
 from haystack.document_stores.types import DuplicatePolicy
 from haystack.dataclasses import ChatMessage, GeneratedAnswer, TextContent, ByteStream
 from haystack.components.routers
import ConditionalRouter, FileTypeRouter @@ -36,22 +36,22 @@ from test.core.pipeline.features.conftest import PipelineRunData -pytestmark = pytest.mark.integration +pytestmark = [pytest.mark.usefixtures("pipeline_class"), pytest.mark.integration] scenarios("pipeline_run.feature") @given("a pipeline that has no components", target_fixture="pipeline_data") -def pipeline_that_has_no_components(): - pipeline = Pipeline(max_runs_per_component=1) +def pipeline_that_has_no_components(pipeline_class): + pipeline = pipeline_class(max_runs_per_component=1) inputs = {} expected_outputs = {} return pipeline, [PipelineRunData(inputs=inputs, expected_outputs=expected_outputs)] @given("a pipeline that is linear", target_fixture="pipeline_data") -def pipeline_that_is_linear(): - pipeline = Pipeline(max_runs_per_component=1) +def pipeline_that_is_linear(pipeline_class): + pipeline = pipeline_class(max_runs_per_component=1) pipeline.add_component("first_addition", AddFixedValue(add=2)) pipeline.add_component("second_addition", AddFixedValue()) pipeline.add_component("double", Double()) @@ -75,7 +75,7 @@ def pipeline_that_is_linear(): @given("a pipeline that has an infinite loop", target_fixture="pipeline_data") -def pipeline_that_has_an_infinite_loop(): +def pipeline_that_has_an_infinite_loop(pipeline_class): routes = [ {"condition": "{{number > 2}}", "output": "{{number}}", "output_name": "big_number", "output_type": int}, {"condition": "{{number <= 2}}", "output": "{{number + 2}}", "output_name": "small_number", "output_type": int}, @@ -85,7 +85,7 @@ def pipeline_that_has_an_infinite_loop(): first_router = ConditionalRouter(routes=routes) second_router = ConditionalRouter(routes=routes) - pipe = Pipeline(max_runs_per_component=1) + pipe = pipeline_class(max_runs_per_component=1) pipe.add_component("main_input", main_input) pipe.add_component("first_router", first_router) pipe.add_component("second_router", second_router) @@ -98,8 +98,8 @@ def pipeline_that_has_an_infinite_loop(): @given("a pipeline that is really complex with lots of components, forks, and loops", target_fixture="pipeline_data") -def pipeline_complex(): - pipeline = Pipeline(max_runs_per_component=2) +def pipeline_complex(pipeline_class): + pipeline = pipeline_class(max_runs_per_component=2) pipeline.add_component("greet_first", Greet(message="Hello, the value is {value}.")) pipeline.add_component("accumulate_1", Accumulate()) pipeline.add_component("add_two", AddFixedValue(add=2)) @@ -192,14 +192,14 @@ def pipeline_complex(): @given("a pipeline that has a single component with a default input", target_fixture="pipeline_data") -def pipeline_that_has_a_single_component_with_a_default_input(): +def pipeline_that_has_a_single_component_with_a_default_input(pipeline_class): @component class WithDefault: @component.output_types(b=int) def run(self, a: int, b: int = 2): return {"c": a + b} - pipeline = Pipeline(max_runs_per_component=1) + pipeline = pipeline_class(max_runs_per_component=1) pipeline.add_component("with_defaults", WithDefault()) return ( @@ -220,8 +220,8 @@ def run(self, a: int, b: int = 2): @given("a pipeline that has two loops of identical lengths", target_fixture="pipeline_data") -def pipeline_that_has_two_loops_of_identical_lengths(): - pipeline = Pipeline(max_runs_per_component=10) +def pipeline_that_has_two_loops_of_identical_lengths(pipeline_class): + pipeline = pipeline_class(max_runs_per_component=10) pipeline.add_component("branch_joiner", BranchJoiner(type_=int)) pipeline.add_component("remainder", 
Remainder(divisor=3)) pipeline.add_component("add_one", AddFixedValue(add=1)) @@ -277,8 +277,8 @@ def pipeline_that_has_two_loops_of_identical_lengths(): @given("a pipeline that has two loops of different lengths", target_fixture="pipeline_data") -def pipeline_that_has_two_loops_of_different_lengths(): - pipeline = Pipeline(max_runs_per_component=10) +def pipeline_that_has_two_loops_of_different_lengths(pipeline_class): + pipeline = pipeline_class(max_runs_per_component=10) pipeline.add_component("branch_joiner", BranchJoiner(type_=int)) pipeline.add_component("remainder", Remainder(divisor=3)) pipeline.add_component("add_one", AddFixedValue(add=1)) @@ -338,9 +338,9 @@ def pipeline_that_has_two_loops_of_different_lengths(): @given("a pipeline that has a single loop with two conditional branches", target_fixture="pipeline_data") -def pipeline_that_has_a_single_loop_with_two_conditional_branches(): +def pipeline_that_has_a_single_loop_with_two_conditional_branches(pipeline_class): accumulator = Accumulate() - pipeline = Pipeline(max_runs_per_component=10) + pipeline = pipeline_class(max_runs_per_component=10) pipeline.add_component("add_one", AddFixedValue(add=1)) pipeline.add_component("branch_joiner", BranchJoiner(type_=int)) @@ -386,8 +386,8 @@ def pipeline_that_has_a_single_loop_with_two_conditional_branches(): @given("a pipeline that has a component with dynamic inputs defined in init", target_fixture="pipeline_data") -def pipeline_that_has_a_component_with_dynamic_inputs_defined_in_init(): - pipeline = Pipeline(max_runs_per_component=1) +def pipeline_that_has_a_component_with_dynamic_inputs_defined_in_init(pipeline_class): + pipeline = pipeline_class(max_runs_per_component=1) pipeline.add_component("hello", Hello()) pipeline.add_component("fstring", FString(template="This is the greeting: {greeting}!", variables=["greeting"])) pipeline.add_component("splitter", TextSplitter()) @@ -420,8 +420,8 @@ def pipeline_that_has_a_component_with_dynamic_inputs_defined_in_init(): @given("a pipeline that has two branches that don't merge", target_fixture="pipeline_data") -def pipeline_that_has_two_branches_that_dont_merge(): - pipeline = Pipeline(max_runs_per_component=1) +def pipeline_that_has_two_branches_that_dont_merge(pipeline_class): + pipeline = pipeline_class(max_runs_per_component=1) pipeline.add_component("add_one", AddFixedValue(add=1)) pipeline.add_component("parity", Parity()) pipeline.add_component("add_ten", AddFixedValue(add=10)) @@ -460,8 +460,8 @@ def pipeline_that_has_two_branches_that_dont_merge(): @given("a pipeline that has three branches that don't merge", target_fixture="pipeline_data") -def pipeline_that_has_three_branches_that_dont_merge(): - pipeline = Pipeline(max_runs_per_component=1) +def pipeline_that_has_three_branches_that_dont_merge(pipeline_class): + pipeline = pipeline_class(max_runs_per_component=1) pipeline.add_component("add_one", AddFixedValue(add=1)) pipeline.add_component("repeat", Repeat(outputs=["first", "second"])) pipeline.add_component("add_ten", AddFixedValue(add=10)) @@ -495,8 +495,8 @@ def pipeline_that_has_three_branches_that_dont_merge(): @given("a pipeline that has two branches that merge", target_fixture="pipeline_data") -def pipeline_that_has_two_branches_that_merge(): - pipeline = Pipeline(max_runs_per_component=1) +def pipeline_that_has_two_branches_that_merge(pipeline_class): + pipeline = pipeline_class(max_runs_per_component=1) pipeline.add_component("first_addition", AddFixedValue(add=2)) pipeline.add_component("second_addition", 
AddFixedValue(add=2)) pipeline.add_component("third_addition", AddFixedValue(add=2)) @@ -528,8 +528,8 @@ def pipeline_that_has_two_branches_that_merge(): @given( "a pipeline that has different combinations of branches that merge and do not merge", target_fixture="pipeline_data" ) -def pipeline_that_has_different_combinations_of_branches_that_merge_and_do_not_merge(): - pipeline = Pipeline(max_runs_per_component=1) +def pipeline_that_has_different_combinations_of_branches_that_merge_and_do_not_merge(pipeline_class): + pipeline = pipeline_class(max_runs_per_component=1) pipeline.add_component("add_one", AddFixedValue()) pipeline.add_component("parity", Parity()) pipeline.add_component("add_ten", AddFixedValue(add=10)) @@ -578,8 +578,8 @@ def pipeline_that_has_different_combinations_of_branches_that_merge_and_do_not_m @given("a pipeline that has two branches, one of which loops back", target_fixture="pipeline_data") -def pipeline_that_has_two_branches_one_of_which_loops_back(): - pipeline = Pipeline(max_runs_per_component=10) +def pipeline_that_has_two_branches_one_of_which_loops_back(pipeline_class): + pipeline = pipeline_class(max_runs_per_component=10) pipeline.add_component("add_zero", AddFixedValue(add=0)) pipeline.add_component("branch_joiner", BranchJoiner(type_=int)) pipeline.add_component("sum", Sum()) @@ -623,7 +623,7 @@ def pipeline_that_has_two_branches_one_of_which_loops_back(): @given("a pipeline that has a component with mutable input", target_fixture="pipeline_data") -def pipeline_that_has_a_component_with_mutable_input(): +def pipeline_that_has_a_component_with_mutable_input(pipeline_class): @component class InputMangler: @component.output_types(mangled_list=List[str]) @@ -631,7 +631,7 @@ def run(self, input_list: List[str]): input_list.append("extra_item") return {"mangled_list": input_list} - pipe = Pipeline(max_runs_per_component=1) + pipe = pipeline_class(max_runs_per_component=1) pipe.add_component("mangler1", InputMangler()) pipe.add_component("mangler2", InputMangler()) pipe.add_component("concat1", StringListJoiner()) @@ -662,7 +662,7 @@ def run(self, input_list: List[str]): @given("a pipeline that has a component with mutable output sent to multiple inputs", target_fixture="pipeline_data") -def pipeline_that_has_a_component_with_mutable_output_sent_to_multiple_inputs(): +def pipeline_that_has_a_component_with_mutable_output_sent_to_multiple_inputs(pipeline_class): @component class PassThroughPromptBuilder: # This is a pass-through component that returns the same input @@ -688,7 +688,7 @@ def run(self, messages: List[ChatMessage]): mm1 = MessageMerger() mm2 = MessageMerger() - pipe = Pipeline(max_runs_per_component=1) + pipe = pipeline_class(max_runs_per_component=1) pipe.add_component("prompt_builder", prompt_builder) pipe.add_component("llm", llm) pipe.add_component("mm1", mm1) @@ -792,7 +792,7 @@ def run(self, messages: List[ChatMessage]): "a pipeline that has a greedy and variadic component after a component with default input", target_fixture="pipeline_data", ) -def pipeline_that_has_a_greedy_and_variadic_component_after_a_component_with_default_input(): +def pipeline_that_has_a_greedy_and_variadic_component_after_a_component_with_default_input(pipeline_class): """ This test verifies that `Pipeline.run()` executes the components in the correct order when there's a greedy Component with variadic input right before a Component with at least one default input. 
@@ -805,7 +805,7 @@ def pipeline_that_has_a_greedy_and_variadic_component_after_a_component_with_def document_store = InMemoryDocumentStore() document_store.write_documents([Document(content="This is a simple document")]) - pipeline = Pipeline(max_runs_per_component=1) + pipeline = pipeline_class(max_runs_per_component=1) template = "Given this documents: {{ documents|join(', ', attribute='content') }} Answer this question: {{ query }}" pipeline.add_component("retriever", InMemoryBM25Retriever(document_store=document_store)) pipeline.add_component("prompt_builder", PromptBuilder(template=template)) @@ -868,7 +868,7 @@ def pipeline_that_has_a_greedy_and_variadic_component_after_a_component_with_def @given("a pipeline that has a component that doesn't return a dictionary", target_fixture="pipeline_data") -def pipeline_that_has_a_component_that_doesnt_return_a_dictionary(): +def pipeline_that_has_a_component_that_doesnt_return_a_dictionary(pipeline_class): BrokenComponent = component_class( "BrokenComponent", input_types={"a": int}, @@ -876,13 +876,13 @@ def pipeline_that_has_a_component_that_doesnt_return_a_dictionary(): output=1, # type:ignore ) - pipe = Pipeline(max_runs_per_component=10) + pipe = pipeline_class(max_runs_per_component=10) pipe.add_component("comp", BrokenComponent()) return pipe, [PipelineRunData({"comp": {"a": 1}})] @given("a pipeline that has a component with only default inputs", target_fixture="pipeline_data") -def pipeline_that_has_a_component_with_only_default_inputs(): +def pipeline_that_has_a_component_with_only_default_inputs(pipeline_class): FakeGenerator = component_class( "FakeGenerator", input_types={"prompt": str}, output_types={"replies": List[str]}, output={"replies": ["Paris"]} ) @@ -898,7 +898,7 @@ def pipeline_that_has_a_component_with_only_default_inputs(): "Question: {{ query }}" ) - pipe = Pipeline(max_runs_per_component=1) + pipe = pipeline_class(max_runs_per_component=1) pipe.add_component("retriever", InMemoryBM25Retriever(document_store=doc_store)) pipe.add_component("prompt_builder", PromptBuilder(template=template)) @@ -999,7 +999,9 @@ def pipeline_that_has_a_component_with_only_default_inputs(): "a pipeline that has a component with only default inputs as first to run and receives inputs from a loop", target_fixture="pipeline_data", ) -def pipeline_that_has_a_component_with_only_default_inputs_as_first_to_run_and_receives_inputs_from_a_loop(): +def pipeline_that_has_a_component_with_only_default_inputs_as_first_to_run_and_receives_inputs_from_a_loop( + pipeline_class, +): """ This tests verifies that a Pipeline doesn't get stuck running in a loop if it has all the following characterics: @@ -1050,7 +1052,7 @@ def fake_generator_run(self, generation_kwargs: Optional[Dict[str, Any]] = None, ] ) - pipe = Pipeline(max_runs_per_component=1) + pipe = pipeline_class(max_runs_per_component=1) pipe.add_component("prompt_builder", PromptBuilder(template=template)) pipe.add_component("generator", FakeGenerator()) @@ -1106,8 +1108,8 @@ def fake_generator_run(self, generation_kwargs: Optional[Dict[str, Any]] = None, "a pipeline that has multiple branches that merge into a component with a single variadic input", target_fixture="pipeline_data", ) -def pipeline_that_has_multiple_branches_that_merge_into_a_component_with_a_single_variadic_input(): - pipeline = Pipeline(max_runs_per_component=1) +def pipeline_that_has_multiple_branches_that_merge_into_a_component_with_a_single_variadic_input(pipeline_class): + pipeline = pipeline_class(max_runs_per_component=1) 
pipeline.add_component("add_one", AddFixedValue()) pipeline.add_component("parity", Remainder(divisor=2)) pipeline.add_component("add_ten", AddFixedValue(add=10)) @@ -1159,8 +1161,10 @@ def pipeline_that_has_multiple_branches_that_merge_into_a_component_with_a_singl "a pipeline that has multiple branches of different lengths that merge into a component with a single variadic input", target_fixture="pipeline_data", ) -def pipeline_that_has_multiple_branches_of_different_lengths_that_merge_into_a_component_with_a_single_variadic_input(): - pipeline = Pipeline(max_runs_per_component=1) +def pipeline_that_has_multiple_branches_of_different_lengths_that_merge_into_a_component_with_a_single_variadic_input( + pipeline_class, +): + pipeline = pipeline_class(max_runs_per_component=1) pipeline.add_component("first_addition", AddFixedValue(add=2)) pipeline.add_component("second_addition", AddFixedValue(add=2)) pipeline.add_component("third_addition", AddFixedValue(add=2)) @@ -1192,8 +1196,8 @@ def pipeline_that_has_multiple_branches_of_different_lengths_that_merge_into_a_c @given("a pipeline that is linear and returns intermediate outputs", target_fixture="pipeline_data") -def pipeline_that_is_linear_and_returns_intermediate_outputs(): - pipeline = Pipeline(max_runs_per_component=1) +def pipeline_that_is_linear_and_returns_intermediate_outputs(pipeline_class): + pipeline = pipeline_class(max_runs_per_component=1) pipeline.add_component("first_addition", AddFixedValue(add=2)) pipeline.add_component("second_addition", AddFixedValue()) pipeline.add_component("double", Double()) @@ -1232,8 +1236,8 @@ def pipeline_that_is_linear_and_returns_intermediate_outputs(): @given("a pipeline that has a loop and returns intermediate outputs from it", target_fixture="pipeline_data") -def pipeline_that_has_a_loop_and_returns_intermediate_outputs_from_it(): - pipeline = Pipeline(max_runs_per_component=10) +def pipeline_that_has_a_loop_and_returns_intermediate_outputs_from_it(pipeline_class): + pipeline = pipeline_class(max_runs_per_component=10) pipeline.add_component("add_one", AddFixedValue(add=1)) pipeline.add_component("branch_joiner", BranchJoiner(type_=int)) pipeline.add_component("below_10", Threshold(threshold=10)) @@ -1297,7 +1301,7 @@ def pipeline_that_has_a_loop_and_returns_intermediate_outputs_from_it(): @given( "a pipeline that is linear and returns intermediate outputs from multiple sockets", target_fixture="pipeline_data" ) -def pipeline_that_is_linear_and_returns_intermediate_outputs_from_multiple_sockets(): +def pipeline_that_is_linear_and_returns_intermediate_outputs_from_multiple_sockets(pipeline_class): @component class DoubleWithOriginal: """ @@ -1308,7 +1312,7 @@ class DoubleWithOriginal: def run(self, value: int): return {"value": value * 2, "original": value} - pipeline = Pipeline(max_runs_per_component=1) + pipeline = pipeline_class(max_runs_per_component=1) pipeline.add_component("first_addition", AddFixedValue(add=2)) pipeline.add_component("second_addition", AddFixedValue()) pipeline.add_component("double", DoubleWithOriginal()) @@ -1350,14 +1354,14 @@ def run(self, value: int): "a pipeline that has a component with default inputs that doesn't receive anything from its sender", target_fixture="pipeline_data", ) -def pipeline_that_has_a_component_with_default_inputs_that_doesnt_receive_anything_from_its_sender(): +def pipeline_that_has_a_component_with_default_inputs_that_doesnt_receive_anything_from_its_sender(pipeline_class): routes = [ {"condition": "{{'reisen' in sentence}}", "output": 
"German", "output_name": "language_1", "output_type": str}, {"condition": "{{'viajar' in sentence}}", "output": "Spanish", "output_name": "language_2", "output_type": str}, ] router = ConditionalRouter(routes) - pipeline = Pipeline(max_runs_per_component=1) + pipeline = pipeline_class(max_runs_per_component=1) pipeline.add_component("router", router) pipeline.add_component("pb", PromptBuilder(template="Ok, I know, that's {{language}}")) pipeline.connect("router.language_2", "pb.language") @@ -1386,7 +1390,9 @@ def pipeline_that_has_a_component_with_default_inputs_that_doesnt_receive_anythi "a pipeline that has a component with default inputs that doesn't receive anything from its sender but receives input from user", target_fixture="pipeline_data", ) -def pipeline_that_has_a_component_with_default_inputs_that_doesnt_receive_anything_from_its_sender_but_receives_input_from_user(): +def pipeline_that_has_a_component_with_default_inputs_that_doesnt_receive_anything_from_its_sender_but_receives_input_from_user( + pipeline_class, +): prompt = PromptBuilder( template="""Please generate an SQL query. The query should answer the following Question: {{ question }}; If the question cannot be answered given the provided table and columns, return 'no_answer' @@ -1436,7 +1442,7 @@ def run(self, query: str): ) fallback_llm = FakeGenerator() - pipeline = Pipeline(max_runs_per_component=1) + pipeline = pipeline_class(max_runs_per_component=1) pipeline.add_component("prompt", prompt) pipeline.add_component("llm", llm) pipeline.add_component("router", router) @@ -1615,7 +1621,9 @@ def run(self, query: str): "a pipeline that has a loop and a component with default inputs that doesn't receive anything from its sender but receives input from user", target_fixture="pipeline_data", ) -def pipeline_that_has_a_loop_and_a_component_with_default_inputs_that_doesnt_receive_anything_from_its_sender_but_receives_input_from_user(): +def pipeline_that_has_a_loop_and_a_component_with_default_inputs_that_doesnt_receive_anything_from_its_sender_but_receives_input_from_user( + pipeline_class, +): template = """ You are an experienced and accurate Turkish CX speacialist that classifies customer comments into pre-defined categories below:\n Negative experience labels: @@ -1664,7 +1672,7 @@ def run(self, prompt: str): llm = FakeGenerator() validator = FakeOutputValidator() - pipeline = Pipeline(max_runs_per_component=1) + pipeline = pipeline_class(max_runs_per_component=1) pipeline.add_component("prompt_builder", prompt_builder) pipeline.add_component("llm", llm) @@ -1790,7 +1798,9 @@ def run(self, prompt: str): "a pipeline that has multiple components with only default inputs and are added in a different order from the order of execution", target_fixture="pipeline_data", ) -def pipeline_that_has_multiple_components_with_only_default_inputs_and_are_added_in_a_different_order_from_the_order_of_execution(): +def pipeline_that_has_multiple_components_with_only_default_inputs_and_are_added_in_a_different_order_from_the_order_of_execution( + pipeline_class, +): prompt_builder1 = PromptBuilder( template=""" You are a spellchecking system. Check the given query and fill in the corrected query. 
@@ -1851,7 +1861,7 @@ class FakeGenerator: def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): return {"replies": ["This is a reply"], "meta": {"meta_key": "meta_value"}} - pipeline = Pipeline(max_runs_per_component=1) + pipeline = pipeline_class(max_runs_per_component=1) pipeline.add_component(name="retriever", instance=FakeRetriever()) pipeline.add_component(name="ranker", instance=FakeRanker()) pipeline.add_component(name="prompt_builder2", instance=prompt_builder2) @@ -1954,8 +1964,8 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): @given("a pipeline that is linear with conditional branching and multiple joins", target_fixture="pipeline_data") -def that_is_linear_with_conditional_branching_and_multiple_joins(): - pipeline = Pipeline() +def that_is_linear_with_conditional_branching_and_multiple_joins(pipeline_class): + pipeline = pipeline_class() @component class FakeRouter: @@ -2065,7 +2075,7 @@ def run(self, query_embedding: List[float]): @given("a pipeline that is a simple agent", target_fixture="pipeline_data") -def that_is_a_simple_agent(): +def that_is_a_simple_agent(pipeline_class): search_message_template = """ Given these web search results: @@ -2170,7 +2180,7 @@ def run(self, query: str): } # main part - pipeline = Pipeline() + pipeline = pipeline_class() pipeline.add_component("main_input", BranchJoiner(List[ChatMessage])) pipeline.add_component("prompt_builder", ChatPromptBuilder(variables=["query"])) pipeline.add_component("llm", FakeThoughtActionOpenAIChatGenerator()) @@ -2507,7 +2517,7 @@ def run(self, replies: List[ChatMessage]): @given("a pipeline that has a variadic component that receives partial inputs", target_fixture="pipeline_data") -def that_has_a_variadic_component_that_receives_partial_inputs(): +def that_has_a_variadic_component_that_receives_partial_inputs(pipeline_class): @component class ConditionalDocumentCreator: def __init__(self, content: str): @@ -2519,7 +2529,7 @@ def run(self, create_document: bool = False): return {"documents": [Document(id=self._content, content=self._content)]} return {"noop": None} - pipeline = Pipeline(max_runs_per_component=1) + pipeline = pipeline_class(max_runs_per_component=1) pipeline.add_component("first_creator", ConditionalDocumentCreator(content="First document")) pipeline.add_component("second_creator", ConditionalDocumentCreator(content="Second document")) pipeline.add_component("third_creator", ConditionalDocumentCreator(content="Third document")) @@ -2588,7 +2598,7 @@ def run(self, create_document: bool = False): "a pipeline that has a variadic component that receives partial inputs in a different order", target_fixture="pipeline_data", ) -def that_has_a_variadic_component_that_receives_partial_inputs_different_order(): +def that_has_a_variadic_component_that_receives_partial_inputs_different_order(pipeline_class): @component class ConditionalDocumentCreator: def __init__(self, content: str): @@ -2600,7 +2610,7 @@ def run(self, create_document: bool = False): return {"documents": [Document(id=self._content, content=self._content)]} return {"noop": None} - pipeline = Pipeline(max_runs_per_component=1) + pipeline = pipeline_class(max_runs_per_component=1) pipeline.add_component("third_creator", ConditionalDocumentCreator(content="Third document")) pipeline.add_component("first_creator", ConditionalDocumentCreator(content="First document")) pipeline.add_component("second_creator", ConditionalDocumentCreator(content="Second document")) @@ -2666,10 +2676,10 @@ 
def run(self, create_document: bool = False): @given("a pipeline that has an answer joiner variadic component", target_fixture="pipeline_data") -def that_has_an_answer_joiner_variadic_component(): +def that_has_an_answer_joiner_variadic_component(pipeline_class): query = "What's Natural Language Processing?" - pipeline = Pipeline(max_runs_per_component=1) + pipeline = pipeline_class(max_runs_per_component=1) pipeline.add_component("answer_builder_1", AnswerBuilder()) pipeline.add_component("answer_builder_2", AnswerBuilder()) pipeline.add_component("answer_joiner", AnswerJoiner()) @@ -2751,7 +2761,9 @@ def that_has_an_answer_joiner_variadic_component(): "a pipeline that is linear and a component in the middle receives optional input from other components and input from the user", target_fixture="pipeline_data", ) -def that_is_linear_and_a_component_in_the_middle_receives_optional_input_from_other_components_and_input_from_the_user(): +def that_is_linear_and_a_component_in_the_middle_receives_optional_input_from_other_components_and_input_from_the_user( + pipeline_class, +): @component class QueryMetadataExtractor: @component.output_types(filters=Dict[str, str]) @@ -2788,7 +2800,7 @@ def run(self, prompt: str): document_store = InMemoryDocumentStore(bm25_algorithm="BM25Plus") document_store.write_documents(documents=documents, policy=DuplicatePolicy.OVERWRITE) - pipeline = Pipeline() + pipeline = pipeline_class() pipeline.add_component(instance=PromptBuilder('{"disease": "Alzheimer", "year": 2023}'), name="builder") pipeline.add_component(instance=QueryMetadataExtractor(), name="metadata_extractor") pipeline.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="retriever") @@ -2852,7 +2864,7 @@ def run(self, prompt: str): @given("a pipeline that has a cycle that would get it stuck", target_fixture="pipeline_data") -def that_has_a_cycle_that_would_get_it_stuck(): +def that_has_a_cycle_that_would_get_it_stuck(pipeline_class): template = """ You are an experienced and accurate Turkish CX speacialist that classifies customer comments into pre-defined categories below:\n Negative experience labels: @@ -2903,7 +2915,7 @@ def run(self, prompt: str): llm = FakeGenerator() validator = FakeOutputValidator() - pipeline = Pipeline(max_runs_per_component=1) + pipeline = pipeline_class(max_runs_per_component=1) pipeline.add_component("prompt_builder", prompt_builder) pipeline.add_component("llm", llm) @@ -2920,7 +2932,7 @@ def run(self, prompt: str): @given("a pipeline that has a loop in the middle", target_fixture="pipeline_data") -def that_has_a_loop_in_the_middle(): +def that_has_a_loop_in_the_middle(pipeline_class): @component class FakeGenerator: @component.output_types(replies=List[str]) @@ -2954,7 +2966,7 @@ def run(self, prompt: str): }, ] - pipeline = Pipeline(max_runs_per_component=20) + pipeline = pipeline_class(max_runs_per_component=20) pipeline.add_component("prompt_cleaner", PromptCleaner()) pipeline.add_component("prompt_builder", PromptBuilder(template="", variables=["question", "invalid_replies"])) pipeline.add_component("llm", FakeGenerator()) @@ -3013,8 +3025,8 @@ def run(self, prompt: str): @given("a pipeline that has variadic component that receives a conditional input", target_fixture="pipeline_data") -def that_has_variadic_component_that_receives_a_conditional_input(): - pipe = Pipeline(max_runs_per_component=1) +def that_has_variadic_component_that_receives_a_conditional_input(pipeline_class): + pipe = pipeline_class(max_runs_per_component=1) routes = 
[ { "condition": "{{ documents|length > 1 }}", @@ -3279,11 +3291,11 @@ def run(self, documents: List[Document]): @given("a pipeline that has a string variadic component", target_fixture="pipeline_data") -def that_has_a_string_variadic_component(): +def that_has_a_string_variadic_component(pipeline_class): string_1 = "What's Natural Language Processing?" string_2 = "What's is life?" - pipeline = Pipeline() + pipeline = pipeline_class() pipeline.add_component("prompt_builder_1", PromptBuilder("Builder 1: {{query}}")) pipeline.add_component("prompt_builder_2", PromptBuilder("Builder 2: {{query}}")) pipeline.add_component("string_joiner", StringJoiner()) @@ -3318,7 +3330,7 @@ def that_has_a_string_variadic_component(): @given("a pipeline that is an agent that can use RAG", target_fixture="pipeline_data") -def an_agent_that_can_use_RAG(): +def an_agent_that_can_use_RAG(pipeline_class): @component class FixedGenerator: def __init__(self, replies): @@ -3397,7 +3409,7 @@ def run(self, query: str): answer_builder = AnswerBuilder() - pp = Pipeline(max_runs_per_component=2) + pp = pipeline_class(max_runs_per_component=2) pp.add_component("joiner", joiner) pp.add_component("rag_llm", rag_llm) @@ -3566,7 +3578,7 @@ def run(self, query: str): @given("a pipeline that has a feedback loop", target_fixture="pipeline_data") -def has_feedback_loop(): +def has_feedback_loop(pipeline_class): @component class FixedGenerator: def __init__(self, replies): @@ -3627,7 +3639,7 @@ def run(self, prompt: str): answer_builder = AnswerBuilder() - pp = Pipeline(max_runs_per_component=100) + pp = pipeline_class(max_runs_per_component=100) pp.add_component("code_llm", code_llm) pp.add_component("code_prompt", code_prompt) @@ -3730,7 +3742,7 @@ def run(self, prompt: str): @given("a pipeline created in a non-standard order that has a loop", target_fixture="pipeline_data") -def has_non_standard_order_loop(): +def has_non_standard_order_loop(pipeline_class): @component class FixedGenerator: def __init__(self, replies): @@ -3791,7 +3803,7 @@ def run(self, prompt: str): answer_builder = AnswerBuilder() - pp = Pipeline(max_runs_per_component=100) + pp = pipeline_class(max_runs_per_component=100) pp.add_component("concatenator", concatenator) pp.add_component("code_llm", code_llm) @@ -3895,7 +3907,7 @@ def run(self, prompt: str): @given("a pipeline that has an agent with a feedback cycle", target_fixture="pipeline_data") -def agent_with_feedback_cycle(): +def agent_with_feedback_cycle(pipeline_class): @component class FixedGenerator: def __init__(self, replies): @@ -3985,7 +3997,7 @@ def run(self, replies: List[str]): joiner = BranchJoiner(type_=str) agent_concatenator = OutputAdapter(template="{{current_prompt + '\n' + files}}", output_type=str) - pp = Pipeline(max_runs_per_component=100) + pp = pipeline_class(max_runs_per_component=100) pp.add_component("code_prompt", code_prompt) pp.add_component("joiner", joiner) @@ -4629,7 +4641,7 @@ def run(self, replies: List[str]): @given("a pipeline that passes outputs that are consumed in cycle to outside the cycle", target_fixture="pipeline_data") -def passes_outputs_outside_cycle(): +def passes_outputs_outside_cycle(pipeline_class): @component class FixedGenerator: def __init__(self, replies): @@ -4705,7 +4717,7 @@ def generate_santa_sleigh(): answer_builder = AnswerBuilderWithPrompt() - pp = Pipeline(max_runs_per_component=100) + pp = pipeline_class(max_runs_per_component=100) pp.add_component("concatenator", concatenator) pp.add_component("code_llm", code_llm) @@ -4899,7 +4911,7 @@ def 
generate_santa_sleigh(): @given("a pipeline with a component that has dynamic default inputs", target_fixture="pipeline_data") -def pipeline_with_dynamic_defaults(): +def pipeline_with_dynamic_defaults(pipeline_class): @component class ParrotWithDynamicDefaultInputs: def __init__(self, input_variable: str): @@ -4911,7 +4923,7 @@ def run(self, **kwargs): return {"response": kwargs[self.input_variable]} parrot = ParrotWithDynamicDefaultInputs("parrot") - pipeline = Pipeline() + pipeline = pipeline_class() pipeline.add_component("parrot", parrot) return ( pipeline, @@ -4931,7 +4943,7 @@ def run(self, **kwargs): @given("a pipeline with a component that has variadic dynamic default inputs", target_fixture="pipeline_data") -def pipeline_with_variadic_dynamic_defaults(): +def pipeline_with_variadic_dynamic_defaults(pipeline_class): @component class ParrotWithVariadicDynamicDefaultInputs: def __init__(self, input_variable: str): @@ -4943,7 +4955,7 @@ def run(self, **kwargs): return {"response": kwargs[self.input_variable]} parrot = ParrotWithVariadicDynamicDefaultInputs("parrot") - pipeline = Pipeline() + pipeline = pipeline_class() pipeline.add_component("parrot", parrot) return ( pipeline, @@ -4963,7 +4975,7 @@ def run(self, **kwargs): @given("a pipeline that is a file conversion pipeline with two joiners", target_fixture="pipeline_data") -def pipeline_that_converts_files(): +def pipeline_that_converts_files(pipeline_class): csv_data = """ some,header,row 0,1,0 @@ -4988,7 +5000,7 @@ def pipeline_that_converts_files(): b_joiner = DocumentJoiner() a_joiner = DocumentJoiner() - pp = Pipeline(max_runs_per_component=1) + pp = pipeline_class(max_runs_per_component=1) pp.add_component("router", router) pp.add_component("splitter", splitter) diff --git a/test/core/pipeline/test_async_pipeline.py b/test/core/pipeline/test_async_pipeline.py new file mode 100644 index 0000000000..474449f782 --- /dev/null +++ b/test/core/pipeline/test_async_pipeline.py @@ -0,0 +1,20 @@ +import asyncio + +from haystack import AsyncPipeline + + +def test_async_pipeline_reentrance(waiting_component, spying_tracer): + pp = AsyncPipeline() + pp.add_component("wait", waiting_component()) + + run_data = [{"wait_for": 1}, {"wait_for": 2}] + + async def run_all(): + # Create concurrent tasks for each pipeline run + tasks = [pp.run_async(data) for data in run_data] + await asyncio.gather(*tasks) + + asyncio.run(run_all()) + component_spans = [sp for sp in spying_tracer.spans if sp.operation_name == "haystack.component.run_async"] + for span in component_spans: + assert span.tags["haystack.component.visits"] == 1 diff --git a/test/test_telemetry.py b/test/test_telemetry.py index 2b6f9e7e71..18b88c20fe 100644 --- a/test/test_telemetry.py +++ b/test/test_telemetry.py @@ -7,13 +7,15 @@ import pytest -from haystack import Pipeline, component +from haystack import AsyncPipeline, Pipeline, component +from haystack.core.serialization import generate_qualified_class_name from haystack.telemetry._telemetry import pipeline_running from haystack.utils.auth import Secret, TokenSecret +@pytest.mark.parametrize("pipeline_class", [Pipeline, AsyncPipeline]) @patch("haystack.telemetry._telemetry.telemetry") -def test_pipeline_running(telemetry): +def test_pipeline_running(telemetry, pipeline_class): telemetry.send_event = Mock() @component @@ -25,15 +27,17 @@ def _get_telemetry_data(self): def run(self): pass - pipe = Pipeline() + pipe = pipeline_class() pipe.add_component("component", Component()) pipeline_running(pipe) + expected_type = 
generate_qualified_class_name(type(pipe)) # First run is always sent telemetry.send_event.assert_called_once_with( "Pipeline run (2.x)", { "pipeline_id": str(id(pipe)), + "pipeline_type": expected_type, "runs": 1, "components": {"test.test_telemetry.Component": [{"name": "component", "key": "values"}]}, }, @@ -53,6 +57,7 @@ def run(self): "Pipeline run (2.x)", { "pipeline_id": str(id(pipe)), + "pipeline_type": expected_type, "runs": 3, "components": {"test.test_telemetry.Component": [{"name": "component", "key": "values"}]}, }, @@ -82,6 +87,7 @@ def run(self): "Pipeline run (2.x)", { "pipeline_id": str(id(pipe)), + "pipeline_type": "haystack.core.pipeline.pipeline.Pipeline", "runs": 1, "components": {"test.test_telemetry.Component": [{"name": "component", "key": "values"}]}, },