Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add a ListJoiner component #8810

Merged
merged 6 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion haystack/components/joiners/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .answer_joiner import AnswerJoiner
from .branch import BranchJoiner
from .document_joiner import DocumentJoiner
from .list_joiner import ListJoiner
from .string_joiner import StringJoiner

# NOTE(review): this looks like the pre-change export list (the line the diff removes);
# it is immediately superseded by the assignment below.
__all__ = ["DocumentJoiner", "BranchJoiner", "AnswerJoiner", "StringJoiner"]
# Export list including "ListJoiner" alongside the existing joiners.
__all__ = ["DocumentJoiner", "BranchJoiner", "AnswerJoiner", "StringJoiner", "ListJoiner"]
105 changes: 105 additions & 0 deletions haystack/components/joiners/list_joiner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from itertools import chain
from typing import Any, Dict, Type

from haystack import component, default_from_dict, default_to_dict
from haystack.core.component.types import Variadic
from haystack.utils import deserialize_type, serialize_type


@component
class ListJoiner:
    """
    A component that joins multiple lists into a single flat list.

    The ListJoiner receives multiple lists of the same type and concatenates them into a single flat list.
    The output order respects the pipeline's execution sequence, with earlier inputs being added first.

    Usage example:
    ```python
    from haystack.components.builders import ChatPromptBuilder
    from haystack.components.generators.chat import OpenAIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack import Pipeline
    from haystack.components.joiners import ListJoiner
    from typing import List

    user_message = [ChatMessage.from_user("Give a brief answer to the following question: {{query}}")]

    feedback_prompt = \"""
        You are given a question and an answer.
        Your task is to provide a score and a brief feedback on the answer.
        Question: {{query}}
        Answer: {{response}}
        \"""
    feedback_message = [ChatMessage.from_system(feedback_prompt)]

    prompt_builder = ChatPromptBuilder(template=user_message)
    feedback_prompt_builder = ChatPromptBuilder(template=feedback_message)
    llm = OpenAIChatGenerator(model="gpt-4o-mini")
    feedback_llm = OpenAIChatGenerator(model="gpt-4o-mini")

    pipe = Pipeline()
    pipe.add_component("prompt_builder", prompt_builder)
    pipe.add_component("llm", llm)
    pipe.add_component("feedback_prompt_builder", feedback_prompt_builder)
    pipe.add_component("feedback_llm", feedback_llm)
    pipe.add_component("list_joiner", ListJoiner(List[ChatMessage]))

    pipe.connect("prompt_builder.prompt", "llm.messages")
    pipe.connect("prompt_builder.prompt", "list_joiner")
    pipe.connect("llm.replies", "list_joiner")
    pipe.connect("llm.replies", "feedback_prompt_builder.response")
    pipe.connect("feedback_prompt_builder.prompt", "feedback_llm.messages")
    pipe.connect("feedback_llm.replies", "list_joiner")

    query = "What is nuclear physics?"
    ans = pipe.run(data={"prompt_builder": {"template_variables":{"query": query}},
        "feedback_prompt_builder": {"template_variables":{"query": query}}})

    print(ans["list_joiner"]["values"])
    ```
    """

    def __init__(self, list_type_: Type):
        """
        Creates a ListJoiner component.

        :param list_type_: The type of list that this joiner will handle (e.g., List[ChatMessage]).
            It is stored on the instance and declared as the type of the `values` output.
            Input lists are expected to be of this type.
        """
        self.list_type_ = list_type_
        component.set_output_types(self, values=list_type_)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns: Dictionary with serialized data; `list_type_` is stored in its serialized form.
        """
        return default_to_dict(self, list_type_=serialize_type(self.list_type_))

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ListJoiner":
        """
        Deserializes the component from a dictionary.

        The input dictionary is left untouched, so the same serialized data can be
        deserialized more than once.

        :param data: Dictionary to deserialize from. Must contain an `init_parameters`
            mapping with a serialized `list_type_` entry.
        :returns: Deserialized component.
        :raises KeyError: If `init_parameters` or `list_type_` is missing from `data`.
        """
        # Deserialize into a copy: writing the type object back into the caller's
        # dictionary would replace the serialized string and break any later reuse.
        init_parameters = dict(data["init_parameters"])
        init_parameters["list_type_"] = deserialize_type(init_parameters["list_type_"])
        return default_from_dict(cls, {**data, "init_parameters": init_parameters})

    def run(self, values: Variadic[Any]) -> Dict[str, Any]:
        """
        Joins multiple lists into a single flat list.

        :param values: The lists to be joined. They are concatenated in the order given.
        :returns: Dictionary with 'values' key containing the joined list.
        """
        result = list(chain.from_iterable(values))
        return {"values": result}
4 changes: 4 additions & 0 deletions releasenotes/notes/add-list-joiner-4f0ea84e195fa461.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
features:
- |
Added a new component `ListJoiner` which joins lists of values from different components into a single list.
71 changes: 71 additions & 0 deletions test/components/joiners/test_list_joiner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import List

from haystack import Document
from haystack.dataclasses import ChatMessage
from haystack.dataclasses.answer import GeneratedAnswer
from haystack.components.joiners.list_joiner import ListJoiner


class TestListJoiner:
    """Tests for ListJoiner: construction, (de)serialization and list joining."""

    def test_init(self):
        """The constructor keeps the list type it was given."""
        list_joiner = ListJoiner(List[ChatMessage])
        assert isinstance(list_joiner, ListJoiner)
        assert list_joiner.list_type_ == List[ChatMessage]

    def test_to_dict(self):
        """Serialization emits the component type and the serialized list type."""
        expected = {
            "type": "haystack.components.joiners.list_joiner.ListJoiner",
            "init_parameters": {"list_type_": "typing.List[haystack.dataclasses.chat_message.ChatMessage]"},
        }
        serialized = ListJoiner(List[ChatMessage]).to_dict()
        assert serialized == expected

    def test_from_dict(self):
        """Deserialization rebuilds a joiner with the original list type."""
        serialized = {
            "type": "haystack.components.joiners.list_joiner.ListJoiner",
            "init_parameters": {"list_type_": "typing.List[haystack.dataclasses.chat_message.ChatMessage]"},
        }
        restored = ListJoiner.from_dict(serialized)
        assert isinstance(restored, ListJoiner)
        assert restored.list_type_ == List[ChatMessage]

    def test_empty_list(self):
        """No input lists yields an empty output list."""
        list_joiner = ListJoiner(List[ChatMessage])
        output = list_joiner.run([])
        assert output == {"values": []}

    def test_list_of_empty_lists(self):
        """Only empty input lists yields an empty output list."""
        list_joiner = ListJoiner(List[ChatMessage])
        output = list_joiner.run([[], []])
        assert output == {"values": []}

    def test_single_list_of_chat_messages(self):
        """A single input list is returned as-is under 'values'."""
        list_joiner = ListJoiner(List[ChatMessage])
        conversation = [ChatMessage.from_user("Hello"), ChatMessage.from_assistant("Hi there")]
        output = list_joiner.run([conversation])
        assert output == {"values": conversation}

    def test_multiple_lists_of_chat_messages(self):
        """Several input lists are concatenated in the order given."""
        list_joiner = ListJoiner(List[ChatMessage])
        user_part = [ChatMessage.from_user("Hello")]
        assistant_part = [ChatMessage.from_assistant("Hi there")]
        system_part = [ChatMessage.from_system("System message")]
        output = list_joiner.run([user_part, assistant_part, system_part])
        expected_values = [*user_part, *assistant_part, *system_part]
        assert output == {"values": expected_values}

    def test_list_of_generated_answers(self):
        """Joining also works for lists of GeneratedAnswer objects."""
        list_joiner = ListJoiner(List[GeneratedAnswer])
        first_answers = [GeneratedAnswer(query="q1", data="a1", meta={}, documents=[Document(content="d1")])]
        second_answers = [GeneratedAnswer(query="q2", data="a2", meta={}, documents=[Document(content="d2")])]
        output = list_joiner.run([first_answers, second_answers])
        expected_values = [*first_answers, *second_answers]
        assert output == {"values": expected_values}

    def test_mixed_empty_and_non_empty_lists(self):
        """Empty lists between non-empty ones contribute nothing to the output."""
        list_joiner = ListJoiner(List[ChatMessage])
        greeting = [ChatMessage.from_user("Hello")]
        output = list_joiner.run([greeting, [], greeting])
        assert output == {"values": greeting * 2}