-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add a
ListJoiner
component (#8810)
* Add a ListJoiner * Add tests and release notes
- Loading branch information
Showing
4 changed files
with
182 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from itertools import chain | ||
from typing import Any, Dict, Type | ||
|
||
from haystack import component, default_from_dict, default_to_dict | ||
from haystack.core.component.types import Variadic | ||
from haystack.utils import deserialize_type, serialize_type | ||
|
||
|
||
@component
class ListJoiner:
    """
    A component that joins multiple lists into a single flat list.

    The ListJoiner receives multiple lists of the same type and concatenates them into a single flat list.
    The output order respects the pipeline's execution sequence, with earlier inputs being added first.

    Usage example:
    ```python
    from haystack.components.builders import ChatPromptBuilder
    from haystack.components.generators.chat import OpenAIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack import Pipeline
    from haystack.components.joiners import ListJoiner
    from typing import List

    user_message = [ChatMessage.from_user("Give a brief answer to the following question: {{query}}")]

    feedback_prompt = \"""
        You are given a question and an answer.
        Your task is to provide a score and a brief feedback on the answer.
        Question: {{query}}
        Answer: {{response}}
    \"""
    feedback_message = [ChatMessage.from_system(feedback_prompt)]

    prompt_builder = ChatPromptBuilder(template=user_message)
    feedback_prompt_builder = ChatPromptBuilder(template=feedback_message)
    llm = OpenAIChatGenerator(model="gpt-4o-mini")
    feedback_llm = OpenAIChatGenerator(model="gpt-4o-mini")

    pipe = Pipeline()
    pipe.add_component("prompt_builder", prompt_builder)
    pipe.add_component("llm", llm)
    pipe.add_component("feedback_prompt_builder", feedback_prompt_builder)
    pipe.add_component("feedback_llm", feedback_llm)
    pipe.add_component("list_joiner", ListJoiner(List[ChatMessage]))

    pipe.connect("prompt_builder.prompt", "llm.messages")
    pipe.connect("prompt_builder.prompt", "list_joiner")
    pipe.connect("llm.replies", "list_joiner")
    pipe.connect("llm.replies", "feedback_prompt_builder.response")
    pipe.connect("feedback_prompt_builder.prompt", "feedback_llm.messages")
    pipe.connect("feedback_llm.replies", "list_joiner")

    query = "What is nuclear physics?"
    ans = pipe.run(data={"prompt_builder": {"template_variables":{"query": query}},
        "feedback_prompt_builder": {"template_variables":{"query": query}}})

    print(ans["list_joiner"]["values"])
    ```
    """

    def __init__(self, list_type_: Type):
        """
        Creates a ListJoiner component.

        :param list_type_: The type of list that this joiner will handle (e.g., List[ChatMessage]).
            All input lists must be of this type.
        """
        self.list_type_ = list_type_
        # The declared type of the `values` output is the configured list type.
        component.set_output_types(self, values=list_type_)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        The configured list type is stored in its serialized form under the `list_type_` init parameter.

        :returns: Dictionary with serialized data.
        """
        return default_to_dict(self, list_type_=serialize_type(self.list_type_))

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ListJoiner":
        """
        Deserializes the component from a dictionary.

        The input dictionary is left untouched: the serialized `list_type_` is resolved on a copy of the
        init parameters, so the same dictionary can be deserialized more than once.

        :param data: Dictionary to deserialize from.
        :returns: Deserialized component.
        :raises KeyError: If `init_parameters` or its `list_type_` entry is missing from `data`.
        """
        # Copy before rewriting `list_type_` so the caller's dictionary is not mutated.
        init_parameters = dict(data["init_parameters"])
        init_parameters["list_type_"] = deserialize_type(init_parameters["list_type_"])
        return default_from_dict(cls, {**data, "init_parameters": init_parameters})

    def run(self, values: Variadic[Any]) -> Dict[str, Any]:
        """
        Joins multiple lists into a single flat list.

        :param values: The lists to be joined, in the order in which they should appear in the output.
        :returns: Dictionary with 'values' key containing the joined list.
        """
        # Flatten exactly one level: the elements of each input list are concatenated in input order.
        result = list(chain.from_iterable(values))
        return {"values": result}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
--- | ||
features: | ||
- | | ||
Added a new `ListJoiner` component that joins lists of values from different components into a single list. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from typing import List | ||
|
||
from haystack import Document | ||
from haystack.dataclasses import ChatMessage | ||
from haystack.dataclasses.answer import GeneratedAnswer | ||
from haystack.components.joiners.list_joiner import ListJoiner | ||
|
||
|
||
class TestListJoiner:
    """Tests for ListJoiner covering construction, (de)serialization and `run`."""

    def test_init(self):
        """The constructor keeps the given list type on the instance."""
        list_joiner = ListJoiner(List[ChatMessage])
        assert isinstance(list_joiner, ListJoiner)
        assert list_joiner.list_type_ == List[ChatMessage]

    def test_to_dict(self):
        """`to_dict` emits the component type and the serialized list type."""
        expected = {
            "type": "haystack.components.joiners.list_joiner.ListJoiner",
            "init_parameters": {"list_type_": "typing.List[haystack.dataclasses.chat_message.ChatMessage]"},
        }
        serialized = ListJoiner(List[ChatMessage]).to_dict()
        assert serialized == expected

    def test_from_dict(self):
        """`from_dict` rebuilds a ListJoiner with the deserialized list type."""
        serialized = {
            "type": "haystack.components.joiners.list_joiner.ListJoiner",
            "init_parameters": {"list_type_": "typing.List[haystack.dataclasses.chat_message.ChatMessage]"},
        }
        restored = ListJoiner.from_dict(serialized)
        assert isinstance(restored, ListJoiner)
        assert restored.list_type_ == List[ChatMessage]

    def test_empty_list(self):
        """Running with no input lists yields an empty `values` list."""
        list_joiner = ListJoiner(List[ChatMessage])
        output = list_joiner.run([])
        assert output == {"values": []}

    def test_list_of_empty_lists(self):
        """Running with only empty input lists yields an empty `values` list."""
        list_joiner = ListJoiner(List[ChatMessage])
        output = list_joiner.run([[], []])
        assert output == {"values": []}

    def test_single_list_of_chat_messages(self):
        """A single input list comes back unchanged under `values`."""
        list_joiner = ListJoiner(List[ChatMessage])
        conversation = [ChatMessage.from_user("Hello"), ChatMessage.from_assistant("Hi there")]
        output = list_joiner.run([conversation])
        assert output == {"values": conversation}

    def test_multiple_lists_of_chat_messages(self):
        """Several input lists are concatenated in the order they were passed."""
        list_joiner = ListJoiner(List[ChatMessage])
        user_part = [ChatMessage.from_user("Hello")]
        assistant_part = [ChatMessage.from_assistant("Hi there")]
        system_part = [ChatMessage.from_system("System message")]
        expected = user_part + assistant_part + system_part
        output = list_joiner.run([user_part, assistant_part, system_part])
        assert output == {"values": expected}

    def test_list_of_generated_answers(self):
        """Lists of GeneratedAnswer objects are joined the same way as chat messages."""
        list_joiner = ListJoiner(List[GeneratedAnswer])
        first_batch = [GeneratedAnswer(query="q1", data="a1", meta={}, documents=[Document(content="d1")])]
        second_batch = [GeneratedAnswer(query="q2", data="a2", meta={}, documents=[Document(content="d2")])]
        output = list_joiner.run([first_batch, second_batch])
        assert output == {"values": first_batch + second_batch}

    def test_mixed_empty_and_non_empty_lists(self):
        """Empty lists between non-empty ones contribute nothing to the result."""
        list_joiner = ListJoiner(List[ChatMessage])
        greeting = [ChatMessage.from_user("Hello")]
        output = list_joiner.run([greeting, [], greeting])
        assert output == {"values": greeting + greeting}