Skip to content

Commit

Permalink
feat: Add a ListJoiner component (#8810)
Browse files Browse the repository at this point in the history
* Add a ListJoiner

* Add tests and release notes
  • Loading branch information
Amnah199 authored Feb 5, 2025
1 parent d2348ad commit b0809b7
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 1 deletion.
3 changes: 2 additions & 1 deletion haystack/components/joiners/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .answer_joiner import AnswerJoiner
from .branch import BranchJoiner
from .document_joiner import DocumentJoiner
from .list_joiner import ListJoiner
from .string_joiner import StringJoiner

__all__ = ["DocumentJoiner", "BranchJoiner", "AnswerJoiner", "StringJoiner"]
__all__ = ["DocumentJoiner", "BranchJoiner", "AnswerJoiner", "StringJoiner", "ListJoiner"]
105 changes: 105 additions & 0 deletions haystack/components/joiners/list_joiner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from itertools import chain
from typing import Any, Dict, Type

from haystack import component, default_from_dict, default_to_dict
from haystack.core.component.types import Variadic
from haystack.utils import deserialize_type, serialize_type


@component
class ListJoiner:
"""
A component that joins multiple lists into a single flat list.
The ListJoiner receives multiple lists of the same type and concatenates them into a single flat list.
The output order respects the pipeline's execution sequence, with earlier inputs being added first.
Usage example:
```python
from haystack.components.builders import ChatPromptBuilder
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack import Pipeline
from haystack.components.joiners import ListJoiner
from typing import List
user_message = [ChatMessage.from_user("Give a brief answer the following question: {{query}}")]
feedback_prompt = \"""
You are given a question and an answer.
Your task is to provide a score and a brief feedback on the answer.
Question: {{query}}
Answer: {{response}}
\"""
feedback_message = [ChatMessage.from_system(feedback_prompt)]
prompt_builder = ChatPromptBuilder(template=user_message)
feedback_prompt_builder = ChatPromptBuilder(template=feedback_message)
llm = OpenAIChatGenerator(model="gpt-4o-mini")
feedback_llm = OpenAIChatGenerator(model="gpt-4o-mini")
pipe = Pipeline()
pipe.add_component("prompt_builder", prompt_builder)
pipe.add_component("llm", llm)
pipe.add_component("feedback_prompt_builder", feedback_prompt_builder)
pipe.add_component("feedback_llm", feedback_llm)
pipe.add_component("list_joiner", ListJoiner(List[ChatMessage]))
pipe.connect("prompt_builder.prompt", "llm.messages")
pipe.connect("prompt_builder.prompt", "list_joiner")
pipe.connect("llm.replies", "list_joiner")
pipe.connect("llm.replies", "feedback_prompt_builder.response")
pipe.connect("feedback_prompt_builder.prompt", "feedback_llm.messages")
pipe.connect("feedback_llm.replies", "list_joiner")
query = "What is nuclear physics?"
ans = pipe.run(data={"prompt_builder": {"template_variables":{"query": query}},
"feedback_prompt_builder": {"template_variables":{"query": query}}})
print(ans["list_joiner"]["values"])
```
"""

def __init__(self, list_type_: Type):
"""
Creates a ListJoiner component.
:param list_type_: The type of list that this joiner will handle (e.g., List[ChatMessage]).
All input lists must be of this type.
"""
self.list_type_ = list_type_
component.set_output_types(self, values=list_type_)

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
:returns: Dictionary with serialized data.
"""
return default_to_dict(self, list_type_=serialize_type(self.list_type_))

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ListJoiner":
"""
Deserializes the component from a dictionary.
:param data: Dictionary to deserialize from.
:returns: Deserialized component.
"""
data["init_parameters"]["list_type_"] = deserialize_type(data["init_parameters"]["list_type_"])
return default_from_dict(cls, data)

def run(self, values: Variadic[Any]) -> Dict[str, Any]:
"""
Joins multiple lists into a single flat list.
:param values:The list to be joined.
:returns: Dictionary with 'values' key containing the joined list.
"""
result = list(chain(*values))
return {"values": result}
4 changes: 4 additions & 0 deletions releasenotes/notes/add-list-joiner-4f0ea84e195fa461.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
features:
- |
Added a new component `ListJoiner` which joins lists of values from different components to a single list.
71 changes: 71 additions & 0 deletions test/components/joiners/test_list_joiner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from typing import List

from haystack import Document
from haystack.dataclasses import ChatMessage
from haystack.dataclasses.answer import GeneratedAnswer
from haystack.components.joiners.list_joiner import ListJoiner


class TestListJoiner:
def test_init(self):
joiner = ListJoiner(List[ChatMessage])
assert isinstance(joiner, ListJoiner)
assert joiner.list_type_ == List[ChatMessage]

def test_to_dict(self):
joiner = ListJoiner(List[ChatMessage])
data = joiner.to_dict()
assert data == {
"type": "haystack.components.joiners.list_joiner.ListJoiner",
"init_parameters": {"list_type_": "typing.List[haystack.dataclasses.chat_message.ChatMessage]"},
}

def test_from_dict(self):
data = {
"type": "haystack.components.joiners.list_joiner.ListJoiner",
"init_parameters": {"list_type_": "typing.List[haystack.dataclasses.chat_message.ChatMessage]"},
}
list_joiner = ListJoiner.from_dict(data)
assert isinstance(list_joiner, ListJoiner)
assert list_joiner.list_type_ == List[ChatMessage]

def test_empty_list(self):
joiner = ListJoiner(List[ChatMessage])
result = joiner.run([])
assert result == {"values": []}

def test_list_of_empty_lists(self):
joiner = ListJoiner(List[ChatMessage])
result = joiner.run([[], []])
assert result == {"values": []}

def test_single_list_of_chat_messages(self):
joiner = ListJoiner(List[ChatMessage])
messages = [ChatMessage.from_user("Hello"), ChatMessage.from_assistant("Hi there")]
result = joiner.run([messages])
assert result == {"values": messages}

def test_multiple_lists_of_chat_messages(self):
joiner = ListJoiner(List[ChatMessage])
messages1 = [ChatMessage.from_user("Hello")]
messages2 = [ChatMessage.from_assistant("Hi there")]
messages3 = [ChatMessage.from_system("System message")]
result = joiner.run([messages1, messages2, messages3])
assert result == {"values": messages1 + messages2 + messages3}

def test_list_of_generated_answers(self):
joiner = ListJoiner(List[GeneratedAnswer])
answers1 = [GeneratedAnswer(query="q1", data="a1", meta={}, documents=[Document(content="d1")])]
answers2 = [GeneratedAnswer(query="q2", data="a2", meta={}, documents=[Document(content="d2")])]
result = joiner.run([answers1, answers2])
assert result == {"values": answers1 + answers2}

def test_mixed_empty_and_non_empty_lists(self):
joiner = ListJoiner(List[ChatMessage])
messages = [ChatMessage.from_user("Hello")]
result = joiner.run([messages, [], messages])
assert result == {"values": messages + messages}

0 comments on commit b0809b7

Please sign in to comment.