From fe9b1e29d409995204b3ffcf4309711cf0dd5430 Mon Sep 17 00:00:00 2001 From: mathislucka Date: Thu, 9 Jan 2025 17:25:55 +0100 Subject: [PATCH 01/41] CI: fix format after newly introduced formatting rules from ruff release (#8696) --- haystack/components/audio/whisper_local.py | 2 +- haystack/components/converters/openapi_functions.py | 3 +-- .../components/generators/chat/hugging_face_local.py | 2 +- haystack/components/rankers/lost_in_the_middle.py | 4 ++-- haystack/core/component/component.py | 10 +++++----- haystack/core/pipeline/draw.py | 6 +++--- haystack/document_stores/in_memory/document_store.py | 9 +++------ haystack/marshal/yaml.py | 3 +-- haystack/utils/filters.py | 3 +-- haystack/utils/hf.py | 2 +- test/components/audio/test_whisper_local.py | 12 ++++++------ .../converters/test_docx_file_to_document.py | 6 +++--- .../embedders/test_openai_document_embedder.py | 6 +++--- .../embedders/test_openai_text_embedder.py | 6 +++--- test/components/joiners/test_document_joiner.py | 6 +++--- .../preprocessors/test_document_cleaner.py | 8 ++------ test/components/routers/test_conditional_router.py | 6 +++--- test/core/pipeline/features/test_run.py | 4 ++-- 18 files changed, 44 insertions(+), 54 deletions(-) diff --git a/haystack/components/audio/whisper_local.py b/haystack/components/audio/whisper_local.py index 79ac83b144..54ec15c6f8 100644 --- a/haystack/components/audio/whisper_local.py +++ b/haystack/components/audio/whisper_local.py @@ -72,7 +72,7 @@ def __init__( whisper_import.check() if model not in get_args(WhisperLocalModel): raise ValueError( - f"Model name '{model}' not recognized. Choose one among: " f"{', '.join(get_args(WhisperLocalModel))}." + f"Model name '{model}' not recognized. Choose one among: {', '.join(get_args(WhisperLocalModel))}." ) self.model = model self.whisper_params = whisper_params or {} diff --git a/haystack/components/converters/openapi_functions.py b/haystack/components/converters/openapi_functions.py index acc5d2a232..0d13b9c59b 100644 --- a/haystack/components/converters/openapi_functions.py +++ b/haystack/components/converters/openapi_functions.py @@ -249,8 +249,7 @@ def _parse_openapi_spec(self, content: str) -> Dict[str, Any]: open_api_spec_content = yaml.safe_load(content) except yaml.YAMLError: error_message = ( - "Failed to parse the OpenAPI specification. " - "The content does not appear to be valid JSON or YAML.\n\n" + "Failed to parse the OpenAPI specification. The content does not appear to be valid JSON or YAML.\n\n" ) raise RuntimeError(error_message, content) diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py index 1ad152f1e3..a79a6dcfa8 100644 --- a/haystack/components/generators/chat/hugging_face_local.py +++ b/haystack/components/generators/chat/hugging_face_local.py @@ -149,7 +149,7 @@ def __init__( # pylint: disable=too-many-positional-arguments if task not in PIPELINE_SUPPORTED_TASKS: raise ValueError( - f"Task '{task}' is not supported. " f"The supported tasks are: {', '.join(PIPELINE_SUPPORTED_TASKS)}." + f"Task '{task}' is not supported. The supported tasks are: {', '.join(PIPELINE_SUPPORTED_TASKS)}." 
            )
        huggingface_pipeline_kwargs["task"] = task
 
diff --git a/haystack/components/rankers/lost_in_the_middle.py b/haystack/components/rankers/lost_in_the_middle.py
index f757fadddc..01df8fde30 100644
--- a/haystack/components/rankers/lost_in_the_middle.py
+++ b/haystack/components/rankers/lost_in_the_middle.py
@@ -51,7 +51,7 @@ def __init__(self, word_count_threshold: Optional[int] = None, top_k: Optional[i
         """
         if isinstance(word_count_threshold, int) and word_count_threshold <= 0:
             raise ValueError(
-                f"Invalid value for word_count_threshold: {word_count_threshold}. " f"word_count_threshold must be > 0."
+                f"Invalid value for word_count_threshold: {word_count_threshold}. word_count_threshold must be > 0."
             )
         if isinstance(top_k, int) and top_k <= 0:
             raise ValueError(f"top_k must be > 0, but got {top_k}")
@@ -78,7 +78,7 @@ def run(
         """
         if isinstance(word_count_threshold, int) and word_count_threshold <= 0:
             raise ValueError(
-                f"Invalid value for word_count_threshold: {word_count_threshold}. " f"word_count_threshold must be > 0."
+                f"Invalid value for word_count_threshold: {word_count_threshold}. word_count_threshold must be > 0."
             )
         if isinstance(top_k, int) and top_k <= 0:
             raise ValueError(f"top_k must be > 0, but got {top_k}")
diff --git a/haystack/core/component/component.py b/haystack/core/component/component.py
index 567faa4871..d77fd77593 100644
--- a/haystack/core/component/component.py
+++ b/haystack/core/component/component.py
@@ -268,9 +268,9 @@ def __call__(cls, *args, **kwargs):
             try:
                 pre_init_hook.in_progress = True
                 named_positional_args = ComponentMeta._positional_to_kwargs(cls, args)
-                assert (
-                    set(named_positional_args.keys()).intersection(kwargs.keys()) == set()
-                ), "positional and keyword arguments overlap"
+                assert set(named_positional_args.keys()).intersection(kwargs.keys()) == set(), (
+                    "positional and keyword arguments overlap"
+                )
                 kwargs.update(named_positional_args)
                 pre_init_hook.callback(cls, kwargs)
                 instance = super().__call__(**kwargs)
@@ -309,8 +309,8 @@ def _component_repr(component: Component) -> str:
     # We're explicitly ignoring the type here because we're sure that the component
     # has the __haystack_input__ and __haystack_output__ attributes at this point
     return (
-        f'{result}\n{getattr(component, "__haystack_input__", "")}'
-        f'\n{getattr(component, "__haystack_output__", "")}'
+        f"{result}\n{getattr(component, '__haystack_input__', '')}"
+        f"\n{getattr(component, '__haystack_output__', '')}"
     )
 
 
diff --git a/haystack/core/pipeline/draw.py b/haystack/core/pipeline/draw.py
index 83df791515..2e24bf9acd 100644
--- a/haystack/core/pipeline/draw.py
+++ b/haystack/core/pipeline/draw.py
@@ -124,7 +124,7 @@ def _to_mermaid_text(graph: networkx.MultiDiGraph) -> str:
     }
 
     states = {
-        comp: f"{comp}[\"<b>{comp}</b><br><small><i>{type(data['instance']).__name__}{optional_inputs[comp]}</i></small>\"]:::component"  # noqa
+        comp: f'{comp}["<b>{comp}</b><br><small><i>{type(data["instance"]).__name__}{optional_inputs[comp]}</i></small>"]:::component'  # noqa
         for comp, data in graph.nodes(data=True)
         if comp not in ["input", "output"]
     }
@@ -139,11 +139,11 @@ def _to_mermaid_text(graph: networkx.MultiDiGraph) -> str:
         connections_list.append(conn_string)
 
     input_connections = [
-        f"i{{*}}--\"{conn_data['label']}<br><small><i>{conn_data['conn_type']}</i></small>\"--> {states[to_comp]}"
+        f'i{{*}}--"{conn_data["label"]}<br><small><i>{conn_data["conn_type"]}</i></small>"--> {states[to_comp]}'
         for _, to_comp, conn_data in graph.out_edges("input", data=True)
     ]
     output_connections = [
-        f"{states[from_comp]}--\"{conn_data['label']}<br><small><i>{conn_data['conn_type']}</i></small>\"--> o{{*}}"
+        f'{states[from_comp]}--"{conn_data["label"]}<br><small><i>{conn_data["conn_type"]}</i></small>"--> o{{*}}'
         for from_comp, _, conn_data in graph.in_edges("output", data=True)
     ]
     connections = "\n".join(connections_list + input_connections + output_connections)
diff --git a/haystack/document_stores/in_memory/document_store.py b/haystack/document_stores/in_memory/document_store.py
index 1aea1e50c3..ad469a002c 100644
--- a/haystack/document_stores/in_memory/document_store.py
+++ b/haystack/document_stores/in_memory/document_store.py
@@ -396,8 +396,7 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
         if filters:
             if "operator" not in filters and "conditions" not in filters:
                 raise ValueError(
-                    "Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering "
-                    "for details."
+                    "Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details."
                 )
             return [doc for doc in self.storage.values() if document_matches_filter(filters=filters, document=doc)]
         return list(self.storage.values())
@@ -506,8 +505,7 @@ def bm25_retrieval(
         if filters:
             if "operator" not in filters:
                 raise ValueError(
-                    "Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering "
-                    "for details."
+                    "Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details."
                 )
             filters = {"operator": "AND", "conditions": [content_type_filter, filters]}
         else:
@@ -574,8 +572,7 @@ def embedding_retrieval(  # pylint: disable=too-many-positional-arguments
             return []
         elif len(documents_with_embeddings) < len(all_documents):
             logger.info(
-                "Skipping some Documents that don't have an embedding. "
-                "To generate embeddings, use a DocumentEmbedder."
+                "Skipping some Documents that don't have an embedding. To generate embeddings, use a DocumentEmbedder."
) scores = self._compute_query_embedding_similarity_scores( diff --git a/haystack/marshal/yaml.py b/haystack/marshal/yaml.py index 615cd1916a..b9c5ffdf41 100644 --- a/haystack/marshal/yaml.py +++ b/haystack/marshal/yaml.py @@ -31,8 +31,7 @@ def marshal(self, dict_: Dict[str, Any]) -> str: return yaml.dump(dict_, Dumper=YamlDumper) except yaml.representer.RepresenterError as e: raise TypeError( - "Error dumping pipeline to YAML - Ensure that all pipeline " - "components only serialize basic Python types" + "Error dumping pipeline to YAML - Ensure that all pipeline components only serialize basic Python types" ) from e def unmarshal(self, data_: Union[str, bytes, bytearray]) -> Dict[str, Any]: diff --git a/haystack/utils/filters.py b/haystack/utils/filters.py index c8a3133e3c..bddb422efe 100644 --- a/haystack/utils/filters.py +++ b/haystack/utils/filters.py @@ -112,8 +112,7 @@ def _less_than_equal(document_value: Any, filter_value: Any) -> bool: def _in(document_value: Any, filter_value: Any) -> bool: if not isinstance(filter_value, list): msg = ( - f"Filter value must be a `list` when using operator 'in' or 'not in', " - f"received type '{type(filter_value)}'" + f"Filter value must be a `list` when using operator 'in' or 'not in', received type '{type(filter_value)}'" ) raise FilterError(msg) return any(_equal(e, document_value) for e in filter_value) diff --git a/haystack/utils/hf.py b/haystack/utils/hf.py index 6a83594ada..7ddca03046 100644 --- a/haystack/utils/hf.py +++ b/haystack/utils/hf.py @@ -205,7 +205,7 @@ def resolve_hf_pipeline_kwargs( # pylint: disable=too-many-positional-arguments task = model_info(huggingface_pipeline_kwargs["model"], token=huggingface_pipeline_kwargs["token"]).pipeline_tag if task not in supported_tasks: - raise ValueError(f"Task '{task}' is not supported. " f"The supported tasks are: {', '.join(supported_tasks)}.") + raise ValueError(f"Task '{task}' is not supported. 
The supported tasks are: {', '.join(supported_tasks)}.") huggingface_pipeline_kwargs["task"] = task return huggingface_pipeline_kwargs diff --git a/test/components/audio/test_whisper_local.py b/test/components/audio/test_whisper_local.py index 28463c4ce6..394a9c4000 100644 --- a/test/components/audio/test_whisper_local.py +++ b/test/components/audio/test_whisper_local.py @@ -190,14 +190,14 @@ def test_whisper_local_transcriber(self, test_files_path): docs = output["documents"] assert len(docs) == 3 - assert all( - word in docs[0].content.strip().lower() for word in {"content", "the", "document"} - ), f"Expected words not found in: {docs[0].content.strip().lower()}" + assert all(word in docs[0].content.strip().lower() for word in {"content", "the", "document"}), ( + f"Expected words not found in: {docs[0].content.strip().lower()}" + ) assert test_files_path / "audio" / "this is the content of the document.wav" == docs[0].meta["audio_file"] - assert all( - word in docs[1].content.strip().lower() for word in {"context", "answer"} - ), f"Expected words not found in: {docs[1].content.strip().lower()}" + assert all(word in docs[1].content.strip().lower() for word in {"context", "answer"}), ( + f"Expected words not found in: {docs[1].content.strip().lower()}" + ) path = test_files_path / "audio" / "the context for this answer is here.wav" assert path.absolute() == docs[1].meta["audio_file"] diff --git a/test/components/converters/test_docx_file_to_document.py b/test/components/converters/test_docx_file_to_document.py index 9b4ee3fe60..c013759938 100644 --- a/test/components/converters/test_docx_file_to_document.py +++ b/test/components/converters/test_docx_file_to_document.py @@ -176,9 +176,9 @@ def test_run_with_table(self, test_files_path): table_index = next(i for i, part in enumerate(content_parts) if "| This | Is | Just a |" in part) # check that natural order of the document is preserved assert any("Donald Trump" in part for part in content_parts[:table_index]), "Text before table not found" - assert any( - "Now we are in Page 2" in part for part in content_parts[table_index + 1 :] - ), "Text after table not found" + assert any("Now we are in Page 2" in part for part in content_parts[table_index + 1 :]), ( + "Text after table not found" + ) def test_run_with_store_full_path_false(self, test_files_path): """ diff --git a/test/components/embedders/test_openai_document_embedder.py b/test/components/embedders/test_openai_document_embedder.py index 87ed6afbb6..7d43bcfa83 100644 --- a/test/components/embedders/test_openai_document_embedder.py +++ b/test/components/embedders/test_openai_document_embedder.py @@ -251,8 +251,8 @@ def test_run(self): assert len(doc.embedding) == 1536 assert all(isinstance(x, float) for x in doc.embedding) - assert ( - "text" in result["meta"]["model"] and "ada" in result["meta"]["model"] - ), "The model name does not contain 'text' and 'ada'" + assert "text" in result["meta"]["model"] and "ada" in result["meta"]["model"], ( + "The model name does not contain 'text' and 'ada'" + ) assert result["meta"]["usage"] == {"prompt_tokens": 15, "total_tokens": 15}, "Usage information does not match" diff --git a/test/components/embedders/test_openai_text_embedder.py b/test/components/embedders/test_openai_text_embedder.py index 31a0360555..695e6351f0 100644 --- a/test/components/embedders/test_openai_text_embedder.py +++ b/test/components/embedders/test_openai_text_embedder.py @@ -130,8 +130,8 @@ def test_run(self): assert len(result["embedding"]) == 1536 assert 
all(isinstance(x, float) for x in result["embedding"]) - assert ( - "text" in result["meta"]["model"] and "ada" in result["meta"]["model"] - ), "The model name does not contain 'text' and 'ada'" + assert "text" in result["meta"]["model"] and "ada" in result["meta"]["model"], ( + "The model name does not contain 'text' and 'ada'" + ) assert result["meta"]["usage"] == {"prompt_tokens": 6, "total_tokens": 6}, "Usage information does not match" diff --git a/test/components/joiners/test_document_joiner.py b/test/components/joiners/test_document_joiner.py index 6cc4f5f9e0..8160fdc48a 100644 --- a/test/components/joiners/test_document_joiner.py +++ b/test/components/joiners/test_document_joiner.py @@ -302,6 +302,6 @@ def test_test_score_norm_with_rrf(self): for i in range(len(join_results["documents"]) - 1) ) - assert ( - is_sorted - ), "Documents are not sorted in descending order by score, there is an issue with rff ranking" + assert is_sorted, ( + "Documents are not sorted in descending order by score, there is an issue with rff ranking" + ) diff --git a/test/components/preprocessors/test_document_cleaner.py b/test/components/preprocessors/test_document_cleaner.py index 5f5633a2c4..0cc929e059 100644 --- a/test/components/preprocessors/test_document_cleaner.py +++ b/test/components/preprocessors/test_document_cleaner.py @@ -71,7 +71,7 @@ def test_remove_whitespaces(self): ) assert len(result["documents"]) == 1 assert result["documents"][0].content == ( - "This is a text with some words. " "" "There is a second sentence. " "" "And there is a third sentence.\f" + "This is a text with some words. There is a second sentence. And there is a third sentence.\f" ) def test_remove_substrings(self): @@ -210,11 +210,7 @@ def test_ascii_only(self): def test_other_document_fields_are_not_lost(self): cleaner = DocumentCleaner(keep_id=True) document = Document( - content="This is a text with some words. \n" - "" - "There is a second sentence. \n" - "" - "And there is a third sentence.\n", + content="This is a text with some words. \nThere is a second sentence. \nAnd there is a third sentence.\n", dataframe=DataFrame({"col1": [1], "col2": [2]}), blob=ByteStream.from_string("some_data"), meta={"data": 1}, diff --git a/test/components/routers/test_conditional_router.py b/test/components/routers/test_conditional_router.py index 66d941b645..478e62d5bf 100644 --- a/test/components/routers/test_conditional_router.py +++ b/test/components/routers/test_conditional_router.py @@ -436,9 +436,9 @@ def test_router_with_optional_parameters(self): # Test pipeline without path parameter result = pipe.run(data={"router": {"question": "What?"}}) - assert result["router"] == { - "fallback": "What?" 
- }, "Default route should work in pipeline when 'path' is not provided" + assert result["router"] == {"fallback": "What?"}, ( + "Default route should work in pipeline when 'path' is not provided" + ) # Test pipeline with path parameter result = pipe.run(data={"router": {"question": "What?", "path": "followup_short"}}) diff --git a/test/core/pipeline/features/test_run.py b/test/core/pipeline/features/test_run.py index 8f07dfec99..652fea7c30 100644 --- a/test/core/pipeline/features/test_run.py +++ b/test/core/pipeline/features/test_run.py @@ -823,7 +823,7 @@ def pipeline_that_has_a_component_with_only_default_inputs(): "answers": [ GeneratedAnswer( data="Paris", - query="What " "is " "the " "capital " "of " "France?", + query="What is the capital of France?", documents=[ Document( id="413dccdf51a54cca75b7ed2eddac04e6e58560bd2f0caf4106a3efc023fe3651", @@ -916,7 +916,7 @@ def fake_generator_run(self, generation_kwargs: Optional[Dict[str, Any]] = None, pipe, [ PipelineRunData( - inputs={"prompt_builder": {"query": "What is the capital of " "Italy?"}}, + inputs={"prompt_builder": {"query": "What is the capital of Italy?"}}, expected_outputs={"router": {"correct_replies": ["Rome"]}}, expected_run_order=["prompt_builder", "generator", "router", "prompt_builder", "generator", "router"], ) From dd9660f90d8cd074ac420139e0f78fa3970b162e Mon Sep 17 00:00:00 2001 From: Julian Risch Date: Thu, 9 Jan 2025 20:12:10 +0100 Subject: [PATCH 02/41] fix: PyPDFToDocument initializes documents with content and meta (#8698) * initialize document with content and meta * update test * add test checking that not only content is used for id generation --- haystack/components/converters/pypdf.py | 10 +++++----- releasenotes/notes/pypdf-docid-293dac08ea5f8491.yaml | 4 ++++ test/components/converters/test_pypdf_to_document.py | 8 ++++++-- 3 files changed, 15 insertions(+), 7 deletions(-) create mode 100644 releasenotes/notes/pypdf-docid-293dac08ea5f8491.yaml diff --git a/haystack/components/converters/pypdf.py b/haystack/components/converters/pypdf.py index 19a4e2e453..334ef097d7 100644 --- a/haystack/components/converters/pypdf.py +++ b/haystack/components/converters/pypdf.py @@ -155,7 +155,7 @@ def from_dict(cls, data): """ return default_from_dict(cls, data) - def _default_convert(self, reader: "PdfReader") -> Document: + def _default_convert(self, reader: "PdfReader") -> str: texts = [] for page in reader.pages: texts.append( @@ -170,7 +170,7 @@ def _default_convert(self, reader: "PdfReader") -> Document: ) ) text = "\f".join(texts) - return Document(content=text) + return text @component.output_types(documents=List[Document]) def run( @@ -205,14 +205,14 @@ def run( continue try: pdf_reader = PdfReader(io.BytesIO(bytestream.data)) - document = self._default_convert(pdf_reader) + text = self._default_convert(pdf_reader) except Exception as e: logger.warning( "Could not read {source} and convert it to Document, skipping. {error}", source=source, error=e ) continue - if document.content is None or document.content.strip() == "": + if text is None or text.strip() == "": logger.warning( "PyPDFToDocument could not extract text from the file {source}. 
Returning an empty document.", source=source, @@ -222,7 +222,7 @@ def run( if not self.store_full_path and (file_path := bytestream.meta.get("file_path")): merged_metadata["file_path"] = os.path.basename(file_path) - document.meta = merged_metadata + document = Document(content=text, meta=merged_metadata) documents.append(document) return {"documents": documents} diff --git a/releasenotes/notes/pypdf-docid-293dac08ea5f8491.yaml b/releasenotes/notes/pypdf-docid-293dac08ea5f8491.yaml new file mode 100644 index 0000000000..f077d8b4ee --- /dev/null +++ b/releasenotes/notes/pypdf-docid-293dac08ea5f8491.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + PyPDFToDocument now creates documents with id based on converted text and meta data. Before it didn't take the meta data into account. diff --git a/test/components/converters/test_pypdf_to_document.py b/test/components/converters/test_pypdf_to_document.py index fa8f295db7..916bb771ee 100644 --- a/test/components/converters/test_pypdf_to_document.py +++ b/test/components/converters/test_pypdf_to_document.py @@ -113,8 +113,8 @@ def test_default_convert(self): layout_mode_font_height_weight=1.5, ) - doc = converter._default_convert(mock_reader) - assert doc.content == "Page 1 content\fPage 2 content" + text = converter._default_convert(mock_reader) + assert text == "Page 1 content\fPage 2 content" expected_params = { "extraction_mode": "layout", @@ -209,3 +209,7 @@ def test_run_empty_document(self, caplog, test_files_path): output = PyPDFToDocument().run(sources=paths) assert "PyPDFToDocument could not extract text from the file" in caplog.text assert output["documents"][0].content == "" + + # Check that meta is used when the returned document is initialized and thus when doc id is generated + assert output["documents"][0].meta["file_path"] == "non_text_searchable.pdf" + assert output["documents"][0].id != Document(content="").id From 08cf09f83fa0d2749e6157c1bc575c8a25133bd8 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Fri, 10 Jan 2025 12:15:15 +0100 Subject: [PATCH 03/41] refactor: `create_tool_from_function` + `tool` decorator (#8697) * create_tool_from_function + decorator * release note * improve usage example * add imports to @tool usage example * clarify docstrings * small docstring addition --- docs/pydoc/config/tools_api.yml | 2 +- haystack/tools/__init__.py | 3 +- haystack/tools/errors.py | 19 ++ haystack/tools/from_function.py | 166 +++++++++++++ haystack/tools/tool.py | 129 +--------- ...e-tool-from-function-b4b318574f3fb3a0.yaml | 6 + test/tools/test_from_function.py | 231 ++++++++++++++++++ test/tools/test_tool.py | 190 +------------- 8 files changed, 427 insertions(+), 319 deletions(-) create mode 100644 haystack/tools/errors.py create mode 100644 haystack/tools/from_function.py create mode 100644 releasenotes/notes/create-tool-from-function-b4b318574f3fb3a0.yaml create mode 100644 test/tools/test_from_function.py diff --git a/docs/pydoc/config/tools_api.yml b/docs/pydoc/config/tools_api.yml index 35aa7aeff8..d3f953087f 100644 --- a/docs/pydoc/config/tools_api.yml +++ b/docs/pydoc/config/tools_api.yml @@ -2,7 +2,7 @@ loaders: - type: haystack_pydoc_tools.loaders.CustomPythonLoader search_path: [../../../haystack/tools] modules: - ["tool"] + ["tool", "from_function"] ignore_when_discovered: ["__init__"] processors: - type: filter diff --git a/haystack/tools/__init__.py b/haystack/tools/__init__.py index 9cd887f4e2..4601ac71c6 100644 --- a/haystack/tools/__init__.py +++ b/haystack/tools/__init__.py @@ -2,6 +2,7 @@ # # 
SPDX-License-Identifier: Apache-2.0 +from haystack.tools.from_function import create_tool_from_function, tool from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace -__all__ = ["Tool", "_check_duplicate_tool_names", "deserialize_tools_inplace"] +__all__ = ["Tool", "_check_duplicate_tool_names", "deserialize_tools_inplace", "create_tool_from_function", "tool"] diff --git a/haystack/tools/errors.py b/haystack/tools/errors.py new file mode 100644 index 0000000000..6080287d64 --- /dev/null +++ b/haystack/tools/errors.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + + +class SchemaGenerationError(Exception): + """ + Exception raised when automatic schema generation fails. + """ + + pass + + +class ToolInvocationError(Exception): + """ + Exception raised when a Tool invocation fails. + """ + + pass diff --git a/haystack/tools/from_function.py b/haystack/tools/from_function.py new file mode 100644 index 0000000000..67fd476207 --- /dev/null +++ b/haystack/tools/from_function.py @@ -0,0 +1,166 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import inspect +from typing import Any, Callable, Dict, Optional + +from pydantic import create_model + +from haystack.tools.errors import SchemaGenerationError +from haystack.tools.tool import Tool + + +def create_tool_from_function( + function: Callable, name: Optional[str] = None, description: Optional[str] = None +) -> "Tool": + """ + Create a Tool instance from a function. + + Allows customizing the Tool name and description. + For simpler use cases, consider using the `@tool` decorator. + + ### Usage example + + ```python + from typing import Annotated, Literal + from haystack.tools import create_tool_from_function + + def get_weather( + city: Annotated[str, "the city for which to get the weather"] = "Munich", + unit: Annotated[Literal["Celsius", "Fahrenheit"], "the unit for the temperature"] = "Celsius"): + '''A simple function to get the current weather for a location.''' + return f"Weather report for {city}: 20 {unit}, sunny" + + tool = create_tool_from_function(get_weather) + + print(tool) + >>> Tool(name='get_weather', description='A simple function to get the current weather for a location.', + >>> parameters={ + >>> 'type': 'object', + >>> 'properties': { + >>> 'city': {'type': 'string', 'description': 'the city for which to get the weather', 'default': 'Munich'}, + >>> 'unit': { + >>> 'type': 'string', + >>> 'enum': ['Celsius', 'Fahrenheit'], + >>> 'description': 'the unit for the temperature', + >>> 'default': 'Celsius', + >>> }, + >>> } + >>> }, + >>> function=) + ``` + + :param function: + The function to be converted into a Tool. + The function must include type hints for all parameters. + The function is expected to have basic python input types (str, int, float, bool, list, dict, tuple). + Other input types may work but are not guaranteed. + If a parameter is annotated using `typing.Annotated`, its metadata will be used as parameter description. + :param name: + The name of the Tool. If not provided, the name of the function will be used. + :param description: + The description of the Tool. If not provided, the docstring of the function will be used. + To intentionally leave the description empty, pass an empty string. + + :returns: + The Tool created from the function. + + :raises ValueError: + If any parameter of the function lacks a type hint. 
+ :raises SchemaGenerationError: + If there is an error generating the JSON schema for the Tool. + """ + + tool_description = description if description is not None else (function.__doc__ or "") + + signature = inspect.signature(function) + + # collect fields (types and defaults) and descriptions from function parameters + fields: Dict[str, Any] = {} + descriptions = {} + + for param_name, param in signature.parameters.items(): + if param.annotation is param.empty: + raise ValueError(f"Function '{function.__name__}': parameter '{param_name}' does not have a type hint.") + + # if the parameter has not a default value, Pydantic requires an Ellipsis (...) + # to explicitly indicate that the parameter is required + default = param.default if param.default is not param.empty else ... + fields[param_name] = (param.annotation, default) + + if hasattr(param.annotation, "__metadata__"): + descriptions[param_name] = param.annotation.__metadata__[0] + + # create Pydantic model and generate JSON schema + try: + model = create_model(function.__name__, **fields) + schema = model.model_json_schema() + except Exception as e: + raise SchemaGenerationError(f"Failed to create JSON schema for function '{function.__name__}'") from e + + # we don't want to include title keywords in the schema, as they contain redundant information + # there is no programmatic way to prevent Pydantic from adding them, so we remove them later + # see https://github.com/pydantic/pydantic/discussions/8504 + _remove_title_from_schema(schema) + + # add parameters descriptions to the schema + for param_name, param_description in descriptions.items(): + if param_name in schema["properties"]: + schema["properties"][param_name]["description"] = param_description + + return Tool(name=name or function.__name__, description=tool_description, parameters=schema, function=function) + + +def tool(function: Callable) -> Tool: + """ + Decorator to convert a function into a Tool. + + Tool name, description, and parameters are inferred from the function. + If you need to customize more the Tool, use `create_tool_from_function` instead. + + ### Usage example + ```python + from typing import Annotated, Literal + from haystack.tools import tool + + @tool + def get_weather( + city: Annotated[str, "the city for which to get the weather"] = "Munich", + unit: Annotated[Literal["Celsius", "Fahrenheit"], "the unit for the temperature"] = "Celsius"): + '''A simple function to get the current weather for a location.''' + return f"Weather report for {city}: 20 {unit}, sunny" + + print(get_weather) + >>> Tool(name='get_weather', description='A simple function to get the current weather for a location.', + >>> parameters={ + >>> 'type': 'object', + >>> 'properties': { + >>> 'city': {'type': 'string', 'description': 'the city for which to get the weather', 'default': 'Munich'}, + >>> 'unit': { + >>> 'type': 'string', + >>> 'enum': ['Celsius', 'Fahrenheit'], + >>> 'description': 'the unit for the temperature', + >>> 'default': 'Celsius', + >>> }, + >>> } + >>> }, + >>> function=) + ``` + """ + return create_tool_from_function(function) + + +def _remove_title_from_schema(schema: Dict[str, Any]): + """ + Remove the 'title' keyword from JSON schema and contained property schemas. + + :param schema: + The JSON schema to remove the 'title' keyword from. 
+ """ + schema.pop("title", None) + + for property_schema in schema["properties"].values(): + for key in list(property_schema.keys()): + if key == "title": + del property_schema[key] diff --git a/haystack/tools/tool.py b/haystack/tools/tool.py index 3b3e541031..bdb8f005b6 100644 --- a/haystack/tools/tool.py +++ b/haystack/tools/tool.py @@ -2,14 +2,12 @@ # # SPDX-License-Identifier: Apache-2.0 -import inspect from dataclasses import asdict, dataclass from typing import Any, Callable, Dict, List, Optional -from pydantic import create_model - from haystack.core.serialization import generate_qualified_class_name, import_class_by_name from haystack.lazy_imports import LazyImport +from haystack.tools.errors import ToolInvocationError from haystack.utils import deserialize_callable, serialize_callable with LazyImport(message="Run 'pip install jsonschema'") as jsonschema_import: @@ -17,22 +15,6 @@ from jsonschema.exceptions import SchemaError -class ToolInvocationError(Exception): - """ - Exception raised when a Tool invocation fails. - """ - - pass - - -class SchemaGenerationError(Exception): - """ - Exception raised when automatic schema generation fails. - """ - - pass - - @dataclass class Tool: """ @@ -108,115 +90,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "Tool": init_parameters["function"] = deserialize_callable(init_parameters["function"]) return cls(**init_parameters) - @classmethod - def from_function(cls, function: Callable, name: Optional[str] = None, description: Optional[str] = None) -> "Tool": - """ - Create a Tool instance from a function. - - ### Usage example - - ```python - from typing import Annotated, Literal - from haystack.dataclasses import Tool - - def get_weather( - city: Annotated[str, "the city for which to get the weather"] = "Munich", - unit: Annotated[Literal["Celsius", "Fahrenheit"], "the unit for the temperature"] = "Celsius"): - '''A simple function to get the current weather for a location.''' - return f"Weather report for {city}: 20 {unit}, sunny" - - tool = Tool.from_function(get_weather) - - print(tool) - >>> Tool(name='get_weather', description='A simple function to get the current weather for a location.', - >>> parameters={ - >>> 'type': 'object', - >>> 'properties': { - >>> 'city': {'type': 'string', 'description': 'the city for which to get the weather', 'default': 'Munich'}, - >>> 'unit': { - >>> 'type': 'string', - >>> 'enum': ['Celsius', 'Fahrenheit'], - >>> 'description': 'the unit for the temperature', - >>> 'default': 'Celsius', - >>> }, - >>> } - >>> }, - >>> function=) - ``` - - :param function: - The function to be converted into a Tool. - The function must include type hints for all parameters. - If a parameter is annotated using `typing.Annotated`, its metadata will be used as parameter description. - :param name: - The name of the Tool. If not provided, the name of the function will be used. - :param description: - The description of the Tool. If not provided, the docstring of the function will be used. - To intentionally leave the description empty, pass an empty string. - - :returns: - The Tool created from the function. - - :raises ValueError: - If any parameter of the function lacks a type hint. - :raises SchemaGenerationError: - If there is an error generating the JSON schema for the Tool. 
- """ - - tool_description = description if description is not None else (function.__doc__ or "") - - signature = inspect.signature(function) - - # collect fields (types and defaults) and descriptions from function parameters - fields: Dict[str, Any] = {} - descriptions = {} - - for param_name, param in signature.parameters.items(): - if param.annotation is param.empty: - raise ValueError(f"Function '{function.__name__}': parameter '{param_name}' does not have a type hint.") - - # if the parameter has not a default value, Pydantic requires an Ellipsis (...) - # to explicitly indicate that the parameter is required - default = param.default if param.default is not param.empty else ... - fields[param_name] = (param.annotation, default) - - if hasattr(param.annotation, "__metadata__"): - descriptions[param_name] = param.annotation.__metadata__[0] - - # create Pydantic model and generate JSON schema - try: - model = create_model(function.__name__, **fields) - schema = model.model_json_schema() - except Exception as e: - raise SchemaGenerationError(f"Failed to create JSON schema for function '{function.__name__}'") from e - - # we don't want to include title keywords in the schema, as they contain redundant information - # there is no programmatic way to prevent Pydantic from adding them, so we remove them later - # see https://github.com/pydantic/pydantic/discussions/8504 - _remove_title_from_schema(schema) - - # add parameters descriptions to the schema - for param_name, param_description in descriptions.items(): - if param_name in schema["properties"]: - schema["properties"][param_name]["description"] = param_description - - return Tool(name=name or function.__name__, description=tool_description, parameters=schema, function=function) - - -def _remove_title_from_schema(schema: Dict[str, Any]): - """ - Remove the 'title' keyword from JSON schema and contained property schemas. - - :param schema: - The JSON schema to remove the 'title' keyword from. - """ - schema.pop("title", None) - - for property_schema in schema["properties"].values(): - for key in list(property_schema.keys()): - if key == "title": - del property_schema[key] - def _check_duplicate_tool_names(tools: Optional[List[Tool]]) -> None: """ diff --git a/releasenotes/notes/create-tool-from-function-b4b318574f3fb3a0.yaml b/releasenotes/notes/create-tool-from-function-b4b318574f3fb3a0.yaml new file mode 100644 index 0000000000..d13595fed2 --- /dev/null +++ b/releasenotes/notes/create-tool-from-function-b4b318574f3fb3a0.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + Added a new `create_tool_from_function` function to create a `Tool` instance from a function, with automatic + generation of name, description and parameters. + Added a `tool` decorator to achieve the same result. 
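Because this patch removes the `Tool.from_function` classmethod in favor of the standalone helper, callers migrate as in the following minimal sketch (not part of the patch; the example function is illustrative):

```python
from haystack.tools import create_tool_from_function


def get_weather(city: str) -> str:
    """Get weather report for a city."""
    return f"Weather report for {city}: 20 Celsius, sunny"


# before this patch: tool = Tool.from_function(get_weather)
# after this patch, the same conversion uses the standalone helper:
tool = create_tool_from_function(get_weather, name="weather", description="Get a weather report.")
```
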
diff --git a/test/tools/test_from_function.py b/test/tools/test_from_function.py new file mode 100644 index 0000000000..4516a78a68 --- /dev/null +++ b/test/tools/test_from_function.py @@ -0,0 +1,231 @@ +import pytest + +from haystack.tools.from_function import create_tool_from_function, _remove_title_from_schema, tool +from haystack.tools.errors import SchemaGenerationError +from typing import Literal, Optional + +try: + from typing import Annotated +except ImportError: + from typing_extensions import Annotated + + +def function_with_docstring(city: str) -> str: + """Get weather report for a city.""" + return f"Weather report for {city}: 20°C, sunny" + + +def test_from_function_description_from_docstring(): + tool = create_tool_from_function(function=function_with_docstring) + + assert tool.name == "function_with_docstring" + assert tool.description == "Get weather report for a city." + assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + assert tool.function == function_with_docstring + + +def test_from_function_with_empty_description(): + tool = create_tool_from_function(function=function_with_docstring, description="") + + assert tool.name == "function_with_docstring" + assert tool.description == "" + assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + assert tool.function == function_with_docstring + + +def test_from_function_with_custom_description(): + tool = create_tool_from_function(function=function_with_docstring, description="custom description") + + assert tool.name == "function_with_docstring" + assert tool.description == "custom description" + assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + assert tool.function == function_with_docstring + + +def test_from_function_with_custom_name(): + tool = create_tool_from_function(function=function_with_docstring, name="custom_name") + + assert tool.name == "custom_name" + assert tool.description == "Get weather report for a city." + assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + assert tool.function == function_with_docstring + + +def test_from_function_annotated(): + def function_with_annotations( + city: Annotated[str, "the city for which to get the weather"] = "Munich", + unit: Annotated[Literal["Celsius", "Fahrenheit"], "the unit for the temperature"] = "Celsius", + nullable_param: Annotated[Optional[str], "a nullable parameter"] = None, + ) -> str: + """A simple function to get the current weather for a location.""" + return f"Weather report for {city}: 20 {unit}, sunny" + + tool = create_tool_from_function(function=function_with_annotations) + + assert tool.name == "function_with_annotations" + assert tool.description == "A simple function to get the current weather for a location." 
+ assert tool.parameters == { + "type": "object", + "properties": { + "city": {"type": "string", "description": "the city for which to get the weather", "default": "Munich"}, + "unit": { + "type": "string", + "enum": ["Celsius", "Fahrenheit"], + "description": "the unit for the temperature", + "default": "Celsius", + }, + "nullable_param": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "description": "a nullable parameter", + "default": None, + }, + }, + } + + +def test_from_function_missing_type_hint(): + def function_missing_type_hint(city) -> str: + return f"Weather report for {city}: 20°C, sunny" + + with pytest.raises(ValueError): + create_tool_from_function(function=function_missing_type_hint) + + +def test_from_function_schema_generation_error(): + def function_with_invalid_type_hint(city: "invalid") -> str: + return f"Weather report for {city}: 20°C, sunny" + + with pytest.raises(SchemaGenerationError): + create_tool_from_function(function=function_with_invalid_type_hint) + + +def test_tool_decorator(): + @tool + def get_weather(city: str) -> str: + """Get weather report for a city.""" + return f"Weather report for {city}: 20°C, sunny" + + assert get_weather.name == "get_weather" + assert get_weather.description == "Get weather report for a city." + assert get_weather.parameters == { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + } + assert callable(get_weather.function) + assert get_weather.function("Berlin") == "Weather report for Berlin: 20°C, sunny" + + +def test_tool_decorator_with_annotated_params(): + @tool + def get_weather( + city: Annotated[str, "The target city"] = "Berlin", + format: Annotated[Literal["short", "long"], "Output format"] = "short", + ) -> str: + """Get weather report for a city.""" + return f"Weather report for {city} ({format} format): 20°C, sunny" + + assert get_weather.name == "get_weather" + assert get_weather.description == "Get weather report for a city." 
+ assert get_weather.parameters == { + "type": "object", + "properties": { + "city": {"type": "string", "description": "The target city", "default": "Berlin"}, + "format": {"type": "string", "enum": ["short", "long"], "description": "Output format", "default": "short"}, + }, + } + assert callable(get_weather.function) + assert get_weather.function("Berlin", "short") == "Weather report for Berlin (short format): 20°C, sunny" + + +def test_remove_title_from_schema(): + complex_schema = { + "properties": { + "parameter1": { + "anyOf": [{"type": "string"}, {"type": "integer"}], + "default": "default_value", + "title": "Parameter1", + }, + "parameter2": { + "default": [1, 2, 3], + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + "title": "Parameter2", + "type": "array", + }, + "parameter3": { + "anyOf": [ + {"type": "string"}, + {"type": "integer"}, + {"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, "type": "array"}, + ], + "default": 42, + "title": "Parameter3", + }, + "parameter4": { + "anyOf": [{"type": "string"}, {"items": {"type": "integer"}, "type": "array"}, {"type": "object"}], + "default": {"key": "value"}, + "title": "Parameter4", + }, + }, + "title": "complex_function", + "type": "object", + } + + _remove_title_from_schema(complex_schema) + + assert complex_schema == { + "properties": { + "parameter1": {"anyOf": [{"type": "string"}, {"type": "integer"}], "default": "default_value"}, + "parameter2": { + "default": [1, 2, 3], + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + "type": "array", + }, + "parameter3": { + "anyOf": [ + {"type": "string"}, + {"type": "integer"}, + {"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, "type": "array"}, + ], + "default": 42, + }, + "parameter4": { + "anyOf": [{"type": "string"}, {"items": {"type": "integer"}, "type": "array"}, {"type": "object"}], + "default": {"key": "value"}, + }, + }, + "type": "object", + } + + +def test_remove_title_from_schema_do_not_remove_title_property(): + """Test that the utility function only removes the 'title' keywords and not the 'title' property (if present).""" + schema = { + "properties": { + "parameter1": {"type": "string", "title": "Parameter1"}, + "title": {"type": "string", "title": "Title"}, + }, + "title": "complex_function", + "type": "object", + } + + _remove_title_from_schema(schema) + + assert schema == {"properties": {"parameter1": {"type": "string"}, "title": {"type": "string"}}, "type": "object"} + + +def test_remove_title_from_schema_handle_no_title_in_top_level(): + schema = { + "properties": { + "parameter1": {"type": "string", "title": "Parameter1"}, + "parameter2": {"type": "integer", "title": "Parameter2"}, + }, + "type": "object", + } + + _remove_title_from_schema(schema) + + assert schema == { + "properties": {"parameter1": {"type": "string"}, "parameter2": {"type": "integer"}}, + "type": "object", + } diff --git a/test/tools/test_tool.py b/test/tools/test_tool.py index f752395fd6..43ed42044a 100644 --- a/test/tools/test_tool.py +++ b/test/tools/test_tool.py @@ -2,22 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Literal, Optional - import pytest -from haystack.tools.tool import ( - SchemaGenerationError, - Tool, - ToolInvocationError, - _remove_title_from_schema, - deserialize_tools_inplace, - _check_duplicate_tool_names, -) -try: - from typing import Annotated -except ImportError: - from typing_extensions import Annotated +from haystack.tools.tool import Tool, ToolInvocationError, deserialize_tools_inplace, 
_check_duplicate_tool_names def get_weather_report(city: str) -> str: @@ -27,11 +14,6 @@ def get_weather_report(city: str) -> str: parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} -def function_with_docstring(city: str) -> str: - """Get weather report for a city.""" - return f"Weather report for {city}: 20°C, sunny" - - class TestTool: def test_init(self): tool = Tool( @@ -104,83 +86,6 @@ def test_from_dict(self): assert tool.parameters == parameters assert tool.function == get_weather_report - def test_from_function_description_from_docstring(self): - tool = Tool.from_function(function=function_with_docstring) - - assert tool.name == "function_with_docstring" - assert tool.description == "Get weather report for a city." - assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} - assert tool.function == function_with_docstring - - def test_from_function_with_empty_description(self): - tool = Tool.from_function(function=function_with_docstring, description="") - - assert tool.name == "function_with_docstring" - assert tool.description == "" - assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} - assert tool.function == function_with_docstring - - def test_from_function_with_custom_description(self): - tool = Tool.from_function(function=function_with_docstring, description="custom description") - - assert tool.name == "function_with_docstring" - assert tool.description == "custom description" - assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} - assert tool.function == function_with_docstring - - def test_from_function_with_custom_name(self): - tool = Tool.from_function(function=function_with_docstring, name="custom_name") - - assert tool.name == "custom_name" - assert tool.description == "Get weather report for a city." - assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} - assert tool.function == function_with_docstring - - def test_from_function_missing_type_hint(self): - def function_missing_type_hint(city) -> str: - return f"Weather report for {city}: 20°C, sunny" - - with pytest.raises(ValueError): - Tool.from_function(function=function_missing_type_hint) - - def test_from_function_schema_generation_error(self): - def function_with_invalid_type_hint(city: "invalid") -> str: - return f"Weather report for {city}: 20°C, sunny" - - with pytest.raises(SchemaGenerationError): - Tool.from_function(function=function_with_invalid_type_hint) - - def test_from_function_annotated(self): - def function_with_annotations( - city: Annotated[str, "the city for which to get the weather"] = "Munich", - unit: Annotated[Literal["Celsius", "Fahrenheit"], "the unit for the temperature"] = "Celsius", - nullable_param: Annotated[Optional[str], "a nullable parameter"] = None, - ) -> str: - """A simple function to get the current weather for a location.""" - return f"Weather report for {city}: 20 {unit}, sunny" - - tool = Tool.from_function(function=function_with_annotations) - - assert tool.name == "function_with_annotations" - assert tool.description == "A simple function to get the current weather for a location." 
- assert tool.parameters == { - "type": "object", - "properties": { - "city": {"type": "string", "description": "the city for which to get the weather", "default": "Munich"}, - "unit": { - "type": "string", - "enum": ["Celsius", "Fahrenheit"], - "description": "the unit for the temperature", - "default": "Celsius", - }, - "nullable_param": { - "anyOf": [{"type": "string"}, {"type": "null"}], - "description": "a nullable parameter", - "default": None, - }, - }, - } - def test_deserialize_tools_inplace(): tool = Tool(name="weather", description="Get weather report", parameters=parameters, function=get_weather_report) @@ -221,99 +126,6 @@ def test_deserialize_tools_inplace_failures(): deserialize_tools_inplace(data) -def test_remove_title_from_schema(): - complex_schema = { - "properties": { - "parameter1": { - "anyOf": [{"type": "string"}, {"type": "integer"}], - "default": "default_value", - "title": "Parameter1", - }, - "parameter2": { - "default": [1, 2, 3], - "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, - "title": "Parameter2", - "type": "array", - }, - "parameter3": { - "anyOf": [ - {"type": "string"}, - {"type": "integer"}, - {"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, "type": "array"}, - ], - "default": 42, - "title": "Parameter3", - }, - "parameter4": { - "anyOf": [{"type": "string"}, {"items": {"type": "integer"}, "type": "array"}, {"type": "object"}], - "default": {"key": "value"}, - "title": "Parameter4", - }, - }, - "title": "complex_function", - "type": "object", - } - - _remove_title_from_schema(complex_schema) - - assert complex_schema == { - "properties": { - "parameter1": {"anyOf": [{"type": "string"}, {"type": "integer"}], "default": "default_value"}, - "parameter2": { - "default": [1, 2, 3], - "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, - "type": "array", - }, - "parameter3": { - "anyOf": [ - {"type": "string"}, - {"type": "integer"}, - {"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, "type": "array"}, - ], - "default": 42, - }, - "parameter4": { - "anyOf": [{"type": "string"}, {"items": {"type": "integer"}, "type": "array"}, {"type": "object"}], - "default": {"key": "value"}, - }, - }, - "type": "object", - } - - -def test_remove_title_from_schema_do_not_remove_title_property(): - """Test that the utility function only removes the 'title' keywords and not the 'title' property (if present).""" - schema = { - "properties": { - "parameter1": {"type": "string", "title": "Parameter1"}, - "title": {"type": "string", "title": "Title"}, - }, - "title": "complex_function", - "type": "object", - } - - _remove_title_from_schema(schema) - - assert schema == {"properties": {"parameter1": {"type": "string"}, "title": {"type": "string"}}, "type": "object"} - - -def test_remove_title_from_schema_handle_no_title_in_top_level(): - schema = { - "properties": { - "parameter1": {"type": "string", "title": "Parameter1"}, - "parameter2": {"type": "integer", "title": "Parameter2"}, - }, - "type": "object", - } - - _remove_title_from_schema(schema) - - assert schema == { - "properties": {"parameter1": {"type": "string"}, "parameter2": {"type": "integer"}}, - "type": "object", - } - - def test_check_duplicate_tool_names(): tools = [ Tool(name="weather", description="Get weather report", parameters=parameters, function=get_weather_report), From 741ce5df5053aa5299876d9466f0a720e2aedbb2 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Fri, 10 Jan 2025 14:46:41 +0100 Subject: [PATCH 04/41] fix: `OpenAIChatGenerator` - do not 
pass tools to the OpenAI client when none are provided (#8702) * do not pass tools to OpenAI client if None * release note * fix release note --- haystack/components/generators/chat/openai.py | 7 ++++--- ...-pass-tools-to-openai-if-none-1fe09e924e7fad7a.yaml | 6 ++++++ test/components/generators/chat/test_openai.py | 10 +++++++++- 3 files changed, 19 insertions(+), 4 deletions(-) create mode 100644 releasenotes/notes/do-not-pass-tools-to-openai-if-none-1fe09e924e7fad7a.yaml diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py index 251b0b741c..0b699e3bc1 100644 --- a/haystack/components/generators/chat/openai.py +++ b/haystack/components/generators/chat/openai.py @@ -286,12 +286,13 @@ def _prepare_api_call( # noqa: PLR0913 tools_strict = tools_strict if tools_strict is not None else self.tools_strict _check_duplicate_tool_names(tools) - openai_tools = None + openai_tools = {} if tools: - openai_tools = [ + tool_definitions = [ {"type": "function", "function": {**t.tool_spec, **({"strict": tools_strict} if tools_strict else {})}} for t in tools ] + openai_tools = {"tools": tool_definitions} is_streaming = streaming_callback is not None num_responses = generation_kwargs.pop("n", 1) @@ -302,8 +303,8 @@ def _prepare_api_call( # noqa: PLR0913 "model": self.model, "messages": openai_formatted_messages, # type: ignore[arg-type] # openai expects list of specific message types "stream": streaming_callback is not None, - "tools": openai_tools, # type: ignore[arg-type] "n": num_responses, + **openai_tools, **generation_kwargs, } diff --git a/releasenotes/notes/do-not-pass-tools-to-openai-if-none-1fe09e924e7fad7a.yaml b/releasenotes/notes/do-not-pass-tools-to-openai-if-none-1fe09e924e7fad7a.yaml new file mode 100644 index 0000000000..aa309098fa --- /dev/null +++ b/releasenotes/notes/do-not-pass-tools-to-openai-if-none-1fe09e924e7fad7a.yaml @@ -0,0 +1,6 @@ +--- +fixes: + - | + OpenAIChatGenerator no longer passes tools to the OpenAI client if none are provided. + Previously, a null value was passed. + This change improves compatibility with OpenAI-compatible APIs that do not support tools. 
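The fix boils down to the dict-spread pattern `_prepare_api_call` now uses: spreading an empty dict adds no key at all. A minimal sketch of that behavior (not part of the patch; the model name is illustrative):

```python
openai_tools = {}  # stays an empty dict when the generator has no tools configured
api_args = {"model": "gpt-4o-mini", "stream": False, "n": 1, **openai_tools}

# spreading the empty dict adds no "tools" key, whereas the old code always
# included "tools": None in the request payload
assert "tools" not in api_args
```
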
diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py index 8333608ea6..eb50d92739 100644 --- a/test/components/generators/chat/test_openai.py +++ b/test/components/generators/chat/test_openai.py @@ -293,6 +293,9 @@ def test_run_with_params(self, chat_messages, openai_mock_chat_completion): assert kwargs["max_tokens"] == 10 assert kwargs["temperature"] == 0.5 + # check that the tools are not passed to the OpenAI API (the generator is initialized without tools) + assert "tools" not in kwargs + # check that the component returns the correct response assert isinstance(response, dict) assert "replies" in response @@ -400,9 +403,14 @@ def test_run_with_tools(self, tools): mock_chat_completion_create.return_value = completion - component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"), tools=tools) + component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"), tools=tools, tools_strict=True) response = component.run([ChatMessage.from_user("What's the weather like in Paris?")]) + # ensure that the tools are passed to the OpenAI API + assert mock_chat_completion_create.call_args[1]["tools"] == [ + {"type": "function", "function": {**tools[0].tool_spec, "strict": True}} + ] + assert len(response["replies"]) == 1 message = response["replies"][0] From 4f73b192f8bcab8077ef7adac7063b23c237f630 Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Fri, 10 Jan 2025 17:28:53 +0100 Subject: [PATCH 05/41] feat: add `RecursiveSplitter` component for `Document` preprocessing (#8605) * initial import * adding initial version + tests * adding more tests * more tests * incorporating SentenceSplitter based on NLTK * adding more tests * adding release notes * adding LICENSE header * removing unused imports * fixing example docstring * addding docstrings * fixing tests and returning a dictionary * updating release notes * attending PR comments * Update haystack/components/preprocessors/recursive_splitter.py Co-authored-by: Sebastian Husch Lee * wip: updating tests for split_idx_start and _split_overlap * adding tests for split_idx and split_start and overlaps * adjusting file for LICENSE checking * adding more tests * adding tests for page numbering * adding tests for min split lenghts and falling back to character-level chunking based on size * fixing linting issue * Update haystack/components/preprocessors/recursive_splitter.py Co-authored-by: Sebastian Husch Lee * Update haystack/components/preprocessors/recursive_splitter.py Co-authored-by: Sebastian Husch Lee * Update haystack/components/preprocessors/recursive_splitter.py Co-authored-by: Sebastian Husch Lee * Update haystack/components/preprocessors/recursive_splitter.py Co-authored-by: Sebastian Husch Lee * Update haystack/components/preprocessors/recursive_splitter.py Co-authored-by: Sebastian Husch Lee * Update haystack/components/preprocessors/recursive_splitter.py Co-authored-by: Sebastian Husch Lee * Update haystack/components/preprocessors/recursive_splitter.py Co-authored-by: Sebastian Husch Lee * Update haystack/components/preprocessors/recursive_splitter.py Co-authored-by: Sebastian Husch Lee * wip * wip * updating tests * wip: fixing all tests after changes * more tests * wip: debugging sentence overlap * wip: debugging page number * wip * wip; fixed bug with sentence tokenizer, needs to keep white spaces * adding tests for counting pages on different split approaches * NLTK checks done on SentenceSplitter * fixing types * adding detecting for full overlap with 
previous chunks * fixing types * improving docstring * improving docstring * adding custom lenght, 'character' use case * customising overlap function for word and adding a few tests * updating docstring * Update haystack/components/preprocessors/recursive_splitter.py Co-authored-by: Sebastian Husch Lee * Update haystack/components/preprocessors/recursive_splitter.py Co-authored-by: Sebastian Husch Lee * Update haystack/components/preprocessors/recursive_splitter.py Co-authored-by: Sebastian Husch Lee * wip: adding more tests for word unit length * fix * feat: `Tool` dataclass - unified abstraction to represent tools (#8652) * draft * del HF token in tests * adaptations * progress * fix type * import sorting * more control on deserialization * release note * improvements * support name field * fix chatpromptbuilder test * port Tool from experimental * release note * docs upd * Update tool.py --------- Co-authored-by: Daria Fokina * fix: fix deserialization issues in multi-threading environments (#8651) * adding 'word' as default length * fixing types * handing both default strategies * wip * \f was not being counted properly * updating tests * fixing the overlap bug * adding more tests * refactoring _apply_overlap * further refactoring * Update haystack/components/preprocessors/recursive_splitter.py Co-authored-by: Sebastian Husch Lee * Update haystack/components/preprocessors/recursive_splitter.py Co-authored-by: Sebastian Husch Lee * Update haystack/components/preprocessors/recursive_splitter.py Co-authored-by: Sebastian Husch Lee * Update haystack/components/preprocessors/recursive_splitter.py Co-authored-by: Sebastian Husch Lee * adding ticks to close code block * fixing comments * applying changes: split with space and force keep_white_spaces=True * fixing some tests and replacing count words approach in more places * keep_white_spaces = True only if not defined * cleaning docs * handling some more edge cases, when split is still too big and all separators ran * fixing fallback whitespaces count to fixed word/char split based on split size * cleaning --------- Co-authored-by: Sebastian Husch Lee Co-authored-by: Stefano Fiorucci Co-authored-by: Daria Fokina Co-authored-by: Tobias Wochinger --- haystack/components/preprocessors/__init__.py | 4 +- .../preprocessors/recursive_splitter.py | 421 +++++++++ .../preprocessors/sentence_tokenizer.py | 1 + ...g-recursive-splitter-1fa716fdd77d4d8c.yaml | 4 + .../preprocessors/test_recursive_splitter.py | 818 ++++++++++++++++++ 5 files changed, 1246 insertions(+), 2 deletions(-) create mode 100644 haystack/components/preprocessors/recursive_splitter.py create mode 100644 releasenotes/notes/adding-recursive-splitter-1fa716fdd77d4d8c.yaml create mode 100644 test/components/preprocessors/test_recursive_splitter.py diff --git a/haystack/components/preprocessors/__init__.py b/haystack/components/preprocessors/__init__.py index 467f16ceeb..33e446e8a6 100644 --- a/haystack/components/preprocessors/__init__.py +++ b/haystack/components/preprocessors/__init__.py @@ -5,7 +5,7 @@ from .document_cleaner import DocumentCleaner from .document_splitter import DocumentSplitter from .nltk_document_splitter import NLTKDocumentSplitter -from .sentence_tokenizer import SentenceSplitter +from .recursive_splitter import RecursiveDocumentSplitter from .text_cleaner import TextCleaner -__all__ = ["DocumentSplitter", "DocumentCleaner", "NLTKDocumentSplitter", "SentenceSplitter", "TextCleaner"] +__all__ = ["DocumentSplitter", "DocumentCleaner", "RecursiveDocumentSplitter", 
"TextCleaner", "NLTKDocumentSplitter"] diff --git a/haystack/components/preprocessors/recursive_splitter.py b/haystack/components/preprocessors/recursive_splitter.py new file mode 100644 index 0000000000..3286a80d72 --- /dev/null +++ b/haystack/components/preprocessors/recursive_splitter.py @@ -0,0 +1,421 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import re +from copy import deepcopy +from typing import Any, Dict, List, Literal, Optional, Tuple + +from haystack import Document, component, logging + +logger = logging.getLogger(__name__) + + +@component +class RecursiveDocumentSplitter: + """ + Recursively chunk text into smaller chunks. + + This component is used to split text into smaller chunks, it does so by recursively applying a list of separators + to the text. + + The separators are applied in the order they are provided, typically this is a list of separators that are + applied in a specific order, being the last separator the most specific one. + + Each separator is applied to the text, it then checks each of the resulting chunks, it keeps the chunks that + are within the chunk_size, for the ones that are larger than the chunk_size, it applies the next separator in the + list to the remaining text. + + This is done until all chunks are smaller than the chunk_size parameter. + + Example: + + ```python + from haystack import Document + from haystack.components.preprocessors import RecursiveDocumentSplitter + + chunker = RecursiveDocumentSplitter(split_length=260, split_overlap=0, separators=["\n\n", "\n", ".", " "]) + text = '''Artificial intelligence (AI) - Introduction + + AI, in its broadest sense, is intelligence exhibited by machines, particularly computer systems. + AI technology is widely used throughout industry, government, and science. 
Some high-profile applications include advanced web search engines; recommendation systems; interacting via human speech; autonomous vehicles; generative and creative tools; and superhuman play and analysis in strategy games.''' + chunker.warm_up() + doc = Document(content=text) + doc_chunks = chunker.run([doc]) + print(doc_chunks["documents"]) + >[ + >Document(id=..., content: 'Artificial intelligence (AI) - Introduction\n\n', meta: {'original_id': '65167a9823dd883de577e828ca4fd529e6f7241f0ff616acfce454d808478951', 'split_id': 0, 'split_idx_start': 0, '_split_overlap': []}) + >Document(id=..., content: 'AI, in its broadest sense, is intelligence exhibited by machines, particularly computer systems.\n', meta: {'original_id': '65167a9823dd883de577e828ca4fd529e6f7241f0ff616acfce454d808478951', 'split_id': 1, 'split_idx_start': 45, '_split_overlap': []}) + >Document(id=..., content: 'AI technology is widely used throughout industry, government, and science.', meta: {'original_id': '65167a9823dd883de577e828ca4fd529e6f7241f0ff616acfce454d808478951', 'split_id': 2, 'split_idx_start': 142, '_split_overlap': []}) + >Document(id=..., content: ' Some high-profile applications include advanced web search engines; recommendation systems; interac...', meta: {'original_id': '65167a9823dd883de577e828ca4fd529e6f7241f0ff616acfce454d808478951', 'split_id': 3, 'split_idx_start': 216, '_split_overlap': []}) + >] + ``` + """ # noqa: E501 + + def __init__( + self, + *, + split_length: int = 200, + split_overlap: int = 0, + split_unit: Literal["word", "char"] = "word", + separators: Optional[List[str]] = None, + sentence_splitter_params: Optional[Dict[str, Any]] = None, + ): + """ + Initializes a RecursiveDocumentSplitter. + + :param split_length: The maximum length of each chunk by default in words, but can be in characters. + See the `split_units` parameter. + :param split_overlap: The number of characters to overlap between consecutive chunks. + :param split_unit: The unit of the split_length parameter. It can be either "word" or "char". + :param separators: An optional list of separator strings to use for splitting the text. The string + separators will be treated as regular expressions unless the separator is "sentence", in that case the + text will be split into sentences using a custom sentence tokenizer based on NLTK. + See: haystack.components.preprocessors.sentence_tokenizer.SentenceSplitter. + If no separators are provided, the default separators ["\n\n", "sentence", "\n", " "] are used. + :param sentence_splitter_params: Optional parameters to pass to the sentence tokenizer. + See: haystack.components.preprocessors.sentence_tokenizer.SentenceSplitter for more information. + + :raises ValueError: If the overlap is greater than or equal to the chunk size or if the overlap is negative, or + if any separator is not a string. + """ + self.split_length = split_length + self.split_overlap = split_overlap + self.split_units = split_unit + self.separators = separators if separators else ["\n\n", "sentence", "\n", " "] # default separators + self._check_params() + self.nltk_tokenizer = None + self.sentence_splitter_params = ( + {"keep_white_spaces": True} if sentence_splitter_params is None else sentence_splitter_params + ) + + def warm_up(self) -> None: + """ + Warm up the sentence tokenizer. 
+ """ + self.nltk_tokenizer = self._get_custom_sentence_tokenizer(self.sentence_splitter_params) + + def _check_params(self) -> None: + if self.split_length < 1: + raise ValueError("Split length must be at least 1 character.") + if self.split_overlap < 0: + raise ValueError("Overlap must be greater than zero.") + if self.split_overlap >= self.split_length: + raise ValueError("Overlap cannot be greater than or equal to the chunk size.") + if not all(isinstance(separator, str) for separator in self.separators): + raise ValueError("All separators must be strings.") + + @staticmethod + def _get_custom_sentence_tokenizer(sentence_splitter_params: Dict[str, Any]): + from haystack.components.preprocessors.sentence_tokenizer import SentenceSplitter + + return SentenceSplitter(**sentence_splitter_params) + + def _split_chunk(self, current_chunk: str) -> Tuple[str, str]: + """ + Splits a chunk based on the split_length and split_units attribute. + + :param current_chunk: The current chunk to be split. + :returns: + A tuple containing the current chunk and the remaining words or characters. + """ + + if self.split_units == "word": + words = current_chunk.split() + current_chunk = " ".join(words[: self.split_length]) + remaining_words = words[self.split_length :] + return current_chunk, " ".join(remaining_words) + + # split by characters + text = current_chunk + current_chunk = text[: self.split_length] + remaining_chars = text[self.split_length :] + return current_chunk, remaining_chars + + def _apply_overlap(self, chunks: List[str]) -> List[str]: + """ + Applies an overlap between consecutive chunks if the chunk_overlap attribute is greater than zero. + + Works for both word- and character-level splitting. It trims the last chunk if it exceeds the split_length and + adds the trimmed content to the next chunk. If the last chunk is still too long after trimming, it splits it + and adds the first chunk to the list. This process continues until the last chunk is within the split_length. + + :param chunks: A list of text chunks. + :returns: + A list of text chunks with the overlap applied. + """ + overlapped_chunks: List[str] = [] + + for idx, chunk in enumerate(chunks): + if idx == 0: + overlapped_chunks.append(chunk) + continue + + # get the overlap between the current and previous chunk + overlap, prev_chunk = self._get_overlap(overlapped_chunks) + if overlap == prev_chunk: + logger.warning( + "Overlap is the same as the previous chunk. " + "Consider increasing the `split_length` parameter or decreasing the `split_overlap` parameter." 
+ ) + + # create a new chunk starting with the overlap + current_chunk = overlap + " " + chunk if self.split_units == "word" else overlap + chunk + + # if this new chunk exceeds 'split_length', trim it and move the remaining text to the next chunk + # if this is the last chunk, another new chunk will contain the trimmed text preceded by the overlap + # of the last chunk + if self._chunk_length(current_chunk) > self.split_length: + current_chunk, remaining_text = self._split_chunk(current_chunk) + if idx < len(chunks) - 1: + chunks[idx + 1] = remaining_text + (" " if self.split_units == "word" else "") + chunks[idx + 1] + elif remaining_text: + # create a new chunk with the trimmed text preceded by the overlap of the last chunk + overlapped_chunks.append(current_chunk) + chunk = remaining_text + overlap, _ = self._get_overlap(overlapped_chunks) + current_chunk = overlap + " " + chunk if self.split_units == "word" else overlap + chunk + + overlapped_chunks.append(current_chunk) + + # it can still be that the new last chunk exceeds the 'split_length' + # continue splitting until the last chunk is within 'split_length' + if idx == len(chunks) - 1 and self._chunk_length(current_chunk) > self.split_length: + last_chunk = overlapped_chunks.pop() + first_chunk, remaining_chunk = self._split_chunk(last_chunk) + overlapped_chunks.append(first_chunk) + + while remaining_chunk: + # combine overlap with remaining chunk + overlap, _ = self._get_overlap(overlapped_chunks) + current = overlap + (" " if self.split_units == "word" else "") + remaining_chunk + + # if it fits within split_length we are done + if self._chunk_length(current) <= self.split_length: + overlapped_chunks.append(current) + break + + # otherwise split it again + first_chunk, remaining_chunk = self._split_chunk(current) + overlapped_chunks.append(first_chunk) + + return overlapped_chunks + + def _get_overlap(self, overlapped_chunks: List[str]) -> Tuple[str, str]: + """Get the previous overlapped chunk instead of the original chunk.""" + prev_chunk = overlapped_chunks[-1] + overlap_start = max(0, self._chunk_length(prev_chunk) - self.split_overlap) + if self.split_units == "word": + word_chunks = prev_chunk.split() + overlap = " ".join(word_chunks[overlap_start:]) + else: + overlap = prev_chunk[overlap_start:] + return overlap, prev_chunk + + def _chunk_length(self, text: str) -> int: + """ + Split the text by whitespace and count non-empty elements. + + :param: The text to be split. + :return: The number of words in the text. + """ + + if self.split_units == "word": + words = [word for word in text.split(" ") if word] + return len(words) + + return len(text) + + def _chunk_text(self, text: str) -> List[str]: + """ + Recursive chunking algorithm that divides text into smaller chunks based on a list of separator characters. + + It starts with a list of separator characters (e.g., ["\n\n", "sentence", "\n", " "]) and attempts to divide + the text using the first separator. If the resulting chunks are still larger than the specified chunk size, + it moves to the next separator in the list. This process continues recursively, progressively applying each + specific separator until the chunks meet the desired size criteria. + + :param text: The text to be split into chunks. + :returns: + A list of text chunks. + """ + if self._chunk_length(text) <= self.split_length: + return [text] + + for curr_separator in self.separators: # type: ignore # the caller already checked that separators is not None + if curr_separator == "sentence": + # re. 
ignore: correct SentenceSplitter initialization is checked at the initialization of the component + sentence_with_spans = self.nltk_tokenizer.split_sentences(text) # type: ignore + splits = [sentence["sentence"] for sentence in sentence_with_spans] + else: + # add escape "\" to the separator and wrapped it in a group so that it's included in the splits as well + escaped_separator = re.escape(curr_separator) + escaped_separator = f"({escaped_separator})" + + # split the text and merge every two consecutive splits, i.e.: the text and the separator after it + splits = re.split(escaped_separator, text) + splits = [ + "".join([splits[i], splits[i + 1]]) if i < len(splits) - 1 else splits[i] + for i in range(0, len(splits), 2) + ] + + # remove last split if it's empty + splits = splits[:-1] if splits[-1] == "" else splits + + if len(splits) == 1: # go to next separator, if current separator not found in the text + continue + + chunks = [] + current_chunk: List[str] = [] + current_length = 0 + + # check splits, if any is too long, recursively chunk it, otherwise add to current chunk + for split in splits: + split_text = split + + # if adding this split exceeds chunk_size, process current_chunk + if current_length + self._chunk_length(split_text) > self.split_length: + # process current_chunk + if current_chunk: # keep the good splits + chunks.append("".join(current_chunk)) + current_chunk = [] + current_length = 0 + + # recursively handle splits that are too large + if self._chunk_length(split_text) > self.split_length: + if curr_separator == self.separators[-1]: + # tried last separator, can't split further, do a fixed-split based on word/character + fall_back_chunks = self._fall_back_to_fixed_chunking(split_text, self.split_units) + chunks.extend(fall_back_chunks) + else: + chunks.extend(self._chunk_text(split_text)) + current_length += self._chunk_length(split_text) + + else: + current_chunk.append(split_text) + current_length += self._chunk_length(split_text) + else: + current_chunk.append(split_text) + current_length += self._chunk_length(split_text) + + if current_chunk: + chunks.append("".join(current_chunk)) + + if self.split_overlap > 0: + chunks = self._apply_overlap(chunks) + + if chunks: + return chunks + + # if no separator worked, fall back to word- or character-level chunking + return self._fall_back_to_fixed_chunking(text, self.split_units) + + def _fall_back_to_fixed_chunking(self, text: str, split_units: Literal["word", "char"]) -> List[str]: + """ + Fall back to a fixed chunking approach if no separator works for the text. + + Splits the text into smaller chunks based on the split_length and split_units attributes, either by words or + characters. It splits into words using whitespace as a separator. + + :param text: The text to be split into chunks. + :param split_units: The unit of the split_length parameter. It can be either "word" or "char". + :returns: + A list of text chunks. 
+ """ + chunks = [] + step = self.split_length - self.split_overlap + + if split_units == "word": + words = re.findall(r"\S+|\s+", text) + current_chunk = [] + current_length = 0 + + for word in words: + if word != " ": + current_chunk.append(word) + current_length += 1 + if current_length == step and current_chunk: + chunks.append("".join(current_chunk)) + current_chunk = [] + current_length = 0 + else: + current_chunk.append(word) + + if current_chunk: + chunks.append("".join(current_chunk)) + + else: + for i in range(0, self._chunk_length(text), step): + chunks.append(text[i : i + self.split_length]) + + return chunks + + def _add_overlap_info(self, curr_pos: int, new_doc: Document, new_docs: List[Document]) -> None: + prev_doc = new_docs[-1] + overlap_length = self._chunk_length(prev_doc.content) - (curr_pos - prev_doc.meta["split_idx_start"]) # type: ignore + if overlap_length > 0: + prev_doc.meta["_split_overlap"].append({"doc_id": new_doc.id, "range": (0, overlap_length)}) + new_doc.meta["_split_overlap"].append( + { + "doc_id": prev_doc.id, + "range": ( + self._chunk_length(prev_doc.content) - overlap_length, # type: ignore + self._chunk_length(prev_doc.content), # type: ignore + ), + } + ) + + def _run_one(self, doc: Document) -> List[Document]: + chunks = self._chunk_text(doc.content) # type: ignore # the caller already check for a non-empty doc.content + chunks = chunks[:-1] if len(chunks[-1]) == 0 else chunks # remove last empty chunk if it exists + current_position = 0 + current_page = 1 + + new_docs: List[Document] = [] + + for split_nr, chunk in enumerate(chunks): + new_doc = Document(content=chunk, meta=deepcopy(doc.meta)) + new_doc.meta["split_id"] = split_nr + new_doc.meta["split_idx_start"] = current_position + new_doc.meta["_split_overlap"] = [] if self.split_overlap > 0 else None + + # add overlap information to the previous and current doc + if split_nr > 0 and self.split_overlap > 0: + self._add_overlap_info(current_position, new_doc, new_docs) + + # count page breaks in the chunk + current_page += chunk.count("\f") + + # if there are consecutive page breaks at the end with no more text, adjust the page number + # e.g: "text\f\f\f" -> 3 page breaks, but current_page should be 1 + consecutive_page_breaks = len(chunk) - len(chunk.rstrip("\f")) + + if consecutive_page_breaks > 0: + new_doc.meta["page_number"] = current_page - consecutive_page_breaks + else: + new_doc.meta["page_number"] = current_page + + # keep the new chunk doc and update the current position + new_docs.append(new_doc) + current_position += len(chunk) - (self.split_overlap if split_nr < len(chunks) - 1 else 0) + + return new_docs + + @component.output_types(documents=List[Document]) + def run(self, documents: List[Document]) -> Dict[str, List[Document]]: + """ + Split a list of documents into documents with smaller chunks of text. + + :param documents: List of Documents to split. + :returns: + A dictionary containing a key "documents" with a List of Documents with smaller chunks of text corresponding + to the input documents. + """ + docs = [] + for doc in documents: + if not doc.content or doc.content == "": + logger.warning("Document ID {doc_id} has an empty content. 
Skipping this document.", doc_id=doc.id)
+                continue
+            docs.extend(self._run_one(doc))
+
+        return {"documents": docs}
diff --git a/haystack/components/preprocessors/sentence_tokenizer.py b/haystack/components/preprocessors/sentence_tokenizer.py
index 5dd6ad97ee..9619b851fc 100644
--- a/haystack/components/preprocessors/sentence_tokenizer.py
+++ b/haystack/components/preprocessors/sentence_tokenizer.py
@@ -135,6 +135,7 @@ def __init__(
         Currently supported languages are: en, de.
         :param keep_white_spaces: If True, the tokenizer will keep white spaces between sentences.
         """
+        nltk_imports.check()
         self.language = language
         self.sentence_tokenizer = load_sentence_tokenizer(language, keep_white_spaces=keep_white_spaces)
         self.use_split_rules = use_split_rules
diff --git a/releasenotes/notes/adding-recursive-splitter-1fa716fdd77d4d8c.yaml b/releasenotes/notes/adding-recursive-splitter-1fa716fdd77d4d8c.yaml
new file mode 100644
index 0000000000..aea4cd6d69
--- /dev/null
+++ b/releasenotes/notes/adding-recursive-splitter-1fa716fdd77d4d8c.yaml
@@ -0,0 +1,4 @@
+---
+features:
+  - |
+    Adding a `RecursiveDocumentSplitter`, which uses a set of separators to split text recursively. It attempts to divide the text using the first separator; if the resulting chunks are still larger than the specified size, it moves to the next separator in the list.
diff --git a/test/components/preprocessors/test_recursive_splitter.py b/test/components/preprocessors/test_recursive_splitter.py
new file mode 100644
index 0000000000..8f55a75b0a
--- /dev/null
+++ b/test/components/preprocessors/test_recursive_splitter.py
@@ -0,0 +1,818 @@
+import re
+
+import pytest
+from pytest import LogCaptureFixture
+
+from haystack import Document, Pipeline
+from haystack.components.preprocessors.recursive_splitter import RecursiveDocumentSplitter
+from haystack.components.preprocessors.sentence_tokenizer import SentenceSplitter
+
+
+def test_get_custom_sentence_tokenizer_success():
+    tokenizer = RecursiveDocumentSplitter._get_custom_sentence_tokenizer({})
+    assert isinstance(tokenizer, SentenceSplitter)
+
+
+def test_init_with_negative_overlap():
+    with pytest.raises(ValueError):
+        _ = RecursiveDocumentSplitter(split_length=20, split_overlap=-1, separators=["."])
+
+
+def test_init_with_overlap_greater_than_chunk_size():
+    with pytest.raises(ValueError):
+        _ = RecursiveDocumentSplitter(split_length=10, split_overlap=15, separators=["."])
+
+
+def test_init_with_invalid_separators():
+    with pytest.raises(ValueError):
+        _ = RecursiveDocumentSplitter(separators=[".", 2])
+
+
+def test_init_with_negative_split_length():
+    with pytest.raises(ValueError):
+        _ = RecursiveDocumentSplitter(split_length=-1, separators=["."])
+
+
+def test_apply_overlap_no_overlap():
+    # Test the case where there is no overlap between chunks
+    splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=0, separators=["."], split_unit="char")
+    chunks = ["chunk1", "chunk2", "chunk3"]
+    result = splitter._apply_overlap(chunks)
+    assert result == ["chunk1", "chunk2", "chunk3"]
+
+
+def test_apply_overlap_with_overlap():
+    # Test the case where there is overlap between chunks
+    splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=4, separators=["."], split_unit="char")
+    chunks = ["chunk1", "chunk2", "chunk3"]
+    result = splitter._apply_overlap(chunks)
+    assert result == ["chunk1", "unk1chunk2", "unk2chunk3"]
+
+
+def test_apply_overlap_with_overlap_capturing_completely_previous_chunk(caplog):
+    splitter = RecursiveDocumentSplitter(split_length=20, 
split_overlap=6, separators=["."], split_unit="char") + chunks = ["chunk1", "chunk2", "chunk3", "chunk4"] + _ = splitter._apply_overlap(chunks) + assert ( + "Overlap is the same as the previous chunk. Consider increasing the `split_length` parameter or decreasing the `split_overlap` parameter." + in caplog.text + ) + + +def test_apply_overlap_single_chunk(): + # Test the case where there is only one chunk + splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=3, separators=["."], split_unit="char") + chunks = ["chunk1"] + result = splitter._apply_overlap(chunks) + assert result == ["chunk1"] + + +def test_chunk_text_smaller_than_chunk_size(): + splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=0, separators=["."]) + text = "small text" + chunks = splitter._chunk_text(text) + assert len(chunks) == 1 + assert chunks[0] == text + + +def test_chunk_text_by_period(): + splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=0, separators=["."], split_unit="char") + text = "This is a test. Another sentence. And one more." + chunks = splitter._chunk_text(text) + assert len(chunks) == 3 + assert chunks[0] == "This is a test." + assert chunks[1] == " Another sentence." + assert chunks[2] == " And one more." + + +def test_run_multiple_new_lines_unit_char(): + splitter = RecursiveDocumentSplitter(split_length=18, separators=["\n\n", "\n"], split_unit="char") + text = "This is a test.\n\n\nAnother test.\n\n\n\nFinal test." + doc = Document(content=text) + chunks = splitter.run([doc])["documents"] + assert chunks[0].content == "This is a test.\n\n" + assert chunks[1].content == "\nAnother test.\n\n\n\n" + assert chunks[2].content == "Final test." + + +def test_run_empty_documents(caplog: LogCaptureFixture): + splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=0, separators=["."]) + empty_doc = Document(content="") + doc_chunks = splitter.run([empty_doc]) + doc_chunks = doc_chunks["documents"] + assert len(doc_chunks) == 0 + assert "has an empty content. Skipping this document." in caplog.text + + +def test_run_using_custom_sentence_tokenizer(): + """ + This test includes abbreviations that are not handled by the simple sentence tokenizer based on "." and requires a + more sophisticated sentence tokenizer like the one provided by NLTK. + """ + splitter = RecursiveDocumentSplitter( + split_length=400, + split_overlap=0, + split_unit="char", + separators=["\n\n", "\n", "sentence", " "], + sentence_splitter_params={"language": "en", "use_split_rules": True, "keep_white_spaces": False}, + ) + splitter.warm_up() + text = """Artificial intelligence (AI) - Introduction + +AI, in its broadest sense, is intelligence exhibited by machines, particularly computer systems. +AI technology is widely used throughout industry, government, and science. 
Some high-profile applications include advanced web search engines (e.g., Google Search); recommendation systems (used by YouTube, Amazon, and Netflix); interacting via human speech (e.g., Google Assistant, Siri, and Alexa); autonomous vehicles (e.g., Waymo); generative and creative tools (e.g., ChatGPT and AI art); and superhuman play and analysis in strategy games (e.g., chess and Go).""" # noqa: E501 + + chunks = splitter.run([Document(content=text)]) + chunks = chunks["documents"] + + assert len(chunks) == 4 + assert chunks[0].content == "Artificial intelligence (AI) - Introduction\n\n" + assert ( + chunks[1].content + == "AI, in its broadest sense, is intelligence exhibited by machines, particularly computer systems.\n" + ) # noqa: E501 + assert chunks[2].content == "AI technology is widely used throughout industry, government, and science." # noqa: E501 + assert ( + chunks[3].content + == "Some high-profile applications include advanced web search engines (e.g., Google Search); recommendation systems (used by YouTube, Amazon, and Netflix); interacting via human speech (e.g., Google Assistant, Siri, and Alexa); autonomous vehicles (e.g., Waymo); generative and creative tools (e.g., ChatGPT and AI art); and superhuman play and analysis in strategy games (e.g., chess and Go)." + ) # noqa: E501 + + +def test_run_split_by_dot_count_page_breaks_split_unit_char() -> None: + document_splitter = RecursiveDocumentSplitter(separators=["."], split_length=30, split_overlap=0, split_unit="char") + + text = ( + "Sentence on page 1. Another on page 1.\fSentence on page 2. Another on page 2.\f" + "Sentence on page 3. Another on page 3.\f\f Sentence on page 5." + ) + + documents = document_splitter.run(documents=[Document(content=text)])["documents"] + + assert len(documents) == 7 + assert documents[0].content == "Sentence on page 1." + assert documents[0].meta["page_number"] == 1 + assert documents[0].meta["split_id"] == 0 + assert documents[0].meta["split_idx_start"] == text.index(documents[0].content) + + assert documents[1].content == " Another on page 1." + assert documents[1].meta["page_number"] == 1 + assert documents[1].meta["split_id"] == 1 + assert documents[1].meta["split_idx_start"] == text.index(documents[1].content) + + assert documents[2].content == "\fSentence on page 2." + assert documents[2].meta["page_number"] == 2 + assert documents[2].meta["split_id"] == 2 + assert documents[2].meta["split_idx_start"] == text.index(documents[2].content) + + assert documents[3].content == " Another on page 2." + assert documents[3].meta["page_number"] == 2 + assert documents[3].meta["split_id"] == 3 + assert documents[3].meta["split_idx_start"] == text.index(documents[3].content) + + assert documents[4].content == "\fSentence on page 3." + assert documents[4].meta["page_number"] == 3 + assert documents[4].meta["split_id"] == 4 + assert documents[4].meta["split_idx_start"] == text.index(documents[4].content) + + assert documents[5].content == " Another on page 3." + assert documents[5].meta["page_number"] == 3 + assert documents[5].meta["split_id"] == 5 + assert documents[5].meta["split_idx_start"] == text.index(documents[5].content) + + assert documents[6].content == "\f\f Sentence on page 5." 
+ assert documents[6].meta["page_number"] == 5 + assert documents[6].meta["split_id"] == 6 + assert documents[6].meta["split_idx_start"] == text.index(documents[6].content) + + +def test_run_split_by_word_count_page_breaks_split_unit_char(): + splitter = RecursiveDocumentSplitter(split_length=19, split_overlap=0, separators=[" "], split_unit="char") + text = "This is some text. \f This text is on another page. \f This is the last pag3." + doc = Document(content=text) + doc_chunks = splitter.run([doc]) + doc_chunks = doc_chunks["documents"] + + assert len(doc_chunks) == 5 + assert doc_chunks[0].content == "This is some text. " + assert doc_chunks[0].meta["page_number"] == 1 + assert doc_chunks[0].meta["split_id"] == 0 + assert doc_chunks[0].meta["split_idx_start"] == text.index(doc_chunks[0].content) + + assert doc_chunks[1].content == "\f This text is on " + assert doc_chunks[1].meta["page_number"] == 2 + assert doc_chunks[1].meta["split_id"] == 1 + assert doc_chunks[1].meta["split_idx_start"] == text.index(doc_chunks[1].content) + + assert doc_chunks[2].content == "another page. \f " + assert doc_chunks[2].meta["page_number"] == 3 + assert doc_chunks[2].meta["split_id"] == 2 + assert doc_chunks[2].meta["split_idx_start"] == text.index(doc_chunks[2].content) + + assert doc_chunks[3].content == "This is the last " + assert doc_chunks[3].meta["page_number"] == 3 + assert doc_chunks[3].meta["split_id"] == 3 + assert doc_chunks[3].meta["split_idx_start"] == text.index(doc_chunks[3].content) + + assert doc_chunks[4].content == "pag3." + assert doc_chunks[4].meta["page_number"] == 3 + assert doc_chunks[4].meta["split_id"] == 4 + assert doc_chunks[4].meta["split_idx_start"] == text.index(doc_chunks[4].content) + + +def test_run_split_by_page_break_count_page_breaks() -> None: + document_splitter = RecursiveDocumentSplitter( + separators=["\f"], split_length=50, split_overlap=0, split_unit="char" + ) + + text = ( + "Sentence on page 1. Another on page 1.\fSentence on page 2. Another on page 2.\f" + "Sentence on page 3. Another on page 3.\f\f Sentence on page 5." + ) + + documents = document_splitter.run(documents=[Document(content=text)]) + chunks_docs = documents["documents"] + assert len(chunks_docs) == 4 + assert chunks_docs[0].content == "Sentence on page 1. Another on page 1.\f" + assert chunks_docs[0].meta["page_number"] == 1 + assert chunks_docs[0].meta["split_id"] == 0 + assert chunks_docs[0].meta["split_idx_start"] == text.index(chunks_docs[0].content) + + assert chunks_docs[1].content == "Sentence on page 2. Another on page 2.\f" + assert chunks_docs[1].meta["page_number"] == 2 + assert chunks_docs[1].meta["split_id"] == 1 + assert chunks_docs[1].meta["split_idx_start"] == text.index(chunks_docs[1].content) + + assert chunks_docs[2].content == "Sentence on page 3. Another on page 3.\f\f" + assert chunks_docs[2].meta["page_number"] == 3 + assert chunks_docs[2].meta["split_id"] == 2 + assert chunks_docs[2].meta["split_idx_start"] == text.index(chunks_docs[2].content) + + assert chunks_docs[3].content == " Sentence on page 5." 
+ assert chunks_docs[3].meta["page_number"] == 5 + assert chunks_docs[3].meta["split_id"] == 3 + assert chunks_docs[3].meta["split_idx_start"] == text.index(chunks_docs[3].content) + + +def test_run_split_by_new_line_count_page_breaks_split_unit_char() -> None: + document_splitter = RecursiveDocumentSplitter( + separators=["\n"], split_length=21, split_overlap=0, split_unit="char" + ) + + text = ( + "Sentence on page 1.\nAnother on page 1.\n\f" + "Sentence on page 2.\nAnother on page 2.\n\f" + "Sentence on page 3.\nAnother on page 3.\n\f\f" + "Sentence on page 5." + ) + + documents = document_splitter.run(documents=[Document(content=text)]) + chunks_docs = documents["documents"] + + assert len(chunks_docs) == 7 + + assert chunks_docs[0].content == "Sentence on page 1.\n" + assert chunks_docs[0].meta["page_number"] == 1 + assert chunks_docs[0].meta["split_id"] == 0 + assert chunks_docs[0].meta["split_idx_start"] == text.index(chunks_docs[0].content) + + assert chunks_docs[1].content == "Another on page 1.\n" + assert chunks_docs[1].meta["page_number"] == 1 + assert chunks_docs[1].meta["split_id"] == 1 + assert chunks_docs[1].meta["split_idx_start"] == text.index(chunks_docs[1].content) + + assert chunks_docs[2].content == "\fSentence on page 2.\n" + assert chunks_docs[2].meta["page_number"] == 2 + assert chunks_docs[2].meta["split_id"] == 2 + assert chunks_docs[2].meta["split_idx_start"] == text.index(chunks_docs[2].content) + + assert chunks_docs[3].content == "Another on page 2.\n" + assert chunks_docs[3].meta["page_number"] == 2 + assert chunks_docs[3].meta["split_id"] == 3 + assert chunks_docs[3].meta["split_idx_start"] == text.index(chunks_docs[3].content) + + assert chunks_docs[4].content == "\fSentence on page 3.\n" + assert chunks_docs[4].meta["page_number"] == 3 + assert chunks_docs[4].meta["split_id"] == 4 + assert chunks_docs[4].meta["split_idx_start"] == text.index(chunks_docs[4].content) + + assert chunks_docs[5].content == "Another on page 3.\n" + assert chunks_docs[5].meta["page_number"] == 3 + assert chunks_docs[5].meta["split_id"] == 5 + assert chunks_docs[5].meta["split_idx_start"] == text.index(chunks_docs[5].content) + + assert chunks_docs[6].content == "\f\fSentence on page 5." + assert chunks_docs[6].meta["page_number"] == 5 + assert chunks_docs[6].meta["split_id"] == 6 + assert chunks_docs[6].meta["split_idx_start"] == text.index(chunks_docs[6].content) + + +def test_run_split_by_sentence_count_page_breaks_split_unit_char() -> None: + document_splitter = RecursiveDocumentSplitter( + separators=["sentence"], split_length=28, split_overlap=0, split_unit="char" + ) + document_splitter.warm_up() + + text = ( + "Sentence on page 1. Another on page 1.\fSentence on page 2. Another on page 2.\f" + "Sentence on page 3. Another on page 3.\f\fSentence on page 5." + ) + + documents = document_splitter.run(documents=[Document(content=text)]) + chunks_docs = documents["documents"] + assert len(chunks_docs) == 7 + + assert chunks_docs[0].content == "Sentence on page 1. " + assert chunks_docs[0].meta["page_number"] == 1 + assert chunks_docs[0].meta["split_id"] == 0 + assert chunks_docs[0].meta["split_idx_start"] == text.index(chunks_docs[0].content) + + assert chunks_docs[1].content == "Another on page 1.\f" + assert chunks_docs[1].meta["page_number"] == 1 + assert chunks_docs[1].meta["split_id"] == 1 + assert chunks_docs[1].meta["split_idx_start"] == text.index(chunks_docs[1].content) + + assert chunks_docs[2].content == "Sentence on page 2. 
" + assert chunks_docs[2].meta["page_number"] == 2 + assert chunks_docs[2].meta["split_id"] == 2 + assert chunks_docs[2].meta["split_idx_start"] == text.index(chunks_docs[2].content) + + assert chunks_docs[3].content == "Another on page 2.\f" + assert chunks_docs[3].meta["page_number"] == 2 + assert chunks_docs[3].meta["split_id"] == 3 + assert chunks_docs[3].meta["split_idx_start"] == text.index(chunks_docs[3].content) + + assert chunks_docs[4].content == "Sentence on page 3. " + assert chunks_docs[4].meta["page_number"] == 3 + assert chunks_docs[4].meta["split_id"] == 4 + assert chunks_docs[4].meta["split_idx_start"] == text.index(chunks_docs[4].content) + + assert chunks_docs[5].content == "Another on page 3.\f\f" + assert chunks_docs[5].meta["page_number"] == 3 + assert chunks_docs[5].meta["split_id"] == 5 + assert chunks_docs[5].meta["split_idx_start"] == text.index(chunks_docs[5].content) + + assert chunks_docs[6].content == "Sentence on page 5." + assert chunks_docs[6].meta["page_number"] == 5 + assert chunks_docs[6].meta["split_id"] == 6 + assert chunks_docs[6].meta["split_idx_start"] == text.index(chunks_docs[6].content) + + +def test_run_split_document_with_overlap_character_unit(): + splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=10, separators=["."], split_unit="char") + text = """A simple sentence1. A bright sentence2. A clever sentence3""" + + doc = Document(content=text) + doc_chunks = splitter.run([doc]) + doc_chunks = doc_chunks["documents"] + + assert len(doc_chunks) == 5 + assert doc_chunks[0].content == "A simple sentence1." + assert doc_chunks[0].meta["split_id"] == 0 + assert doc_chunks[0].meta["split_idx_start"] == text.index(doc_chunks[0].content) + assert doc_chunks[0].meta["_split_overlap"] == [{"doc_id": doc_chunks[1].id, "range": (0, 10)}] + + assert doc_chunks[1].content == "sentence1. A bright " + assert doc_chunks[1].meta["split_id"] == 1 + assert doc_chunks[1].meta["split_idx_start"] == text.index(doc_chunks[1].content) + assert doc_chunks[1].meta["_split_overlap"] == [ + {"doc_id": doc_chunks[0].id, "range": (9, 19)}, + {"doc_id": doc_chunks[2].id, "range": (0, 10)}, + ] + + assert doc_chunks[2].content == " A bright sentence2." + assert doc_chunks[2].meta["split_id"] == 2 + assert doc_chunks[2].meta["split_idx_start"] == text.index(doc_chunks[2].content) + assert doc_chunks[2].meta["_split_overlap"] == [ + {"doc_id": doc_chunks[1].id, "range": (10, 20)}, + {"doc_id": doc_chunks[3].id, "range": (0, 10)}, + ] + + assert doc_chunks[3].content == "sentence2. 
A clever " + assert doc_chunks[3].meta["split_id"] == 3 + assert doc_chunks[3].meta["split_idx_start"] == text.index(doc_chunks[3].content) + assert doc_chunks[3].meta["_split_overlap"] == [ + {"doc_id": doc_chunks[2].id, "range": (10, 20)}, + {"doc_id": doc_chunks[4].id, "range": (0, 10)}, + ] + + assert doc_chunks[4].content == " A clever sentence3" + assert doc_chunks[4].meta["split_id"] == 4 + assert doc_chunks[4].meta["split_idx_start"] == text.index(doc_chunks[4].content) + assert doc_chunks[4].meta["_split_overlap"] == [{"doc_id": doc_chunks[3].id, "range": (10, 20)}] + + +def test_run_separator_exists_but_split_length_too_small_fall_back_to_character_chunking(): + splitter = RecursiveDocumentSplitter(separators=[" "], split_length=2, split_unit="char") + doc = Document(content="This is some text") + result = splitter.run(documents=[doc]) + assert len(result["documents"]) == 10 + for doc in result["documents"]: + if re.escape(doc.content) not in ["\ "]: + assert len(doc.content) == 2 + + +def test_run_fallback_to_character_chunking_by_default_length_too_short(): + text = "abczdefzghizjkl" + separators = ["\n\n", "\n", "z"] + splitter = RecursiveDocumentSplitter(split_length=2, separators=separators, split_unit="char") + doc = Document(content=text) + chunks = splitter.run([doc])["documents"] + for chunk in chunks: + assert len(chunk.content) <= 2 + + +def test_run_fallback_to_word_chunking_by_default_length_too_short(): + text = "This is some text. This is some more text, and even more text." + separators = ["\n\n", "\n", "."] + splitter = RecursiveDocumentSplitter(split_length=2, separators=separators, split_unit="word") + doc = Document(content=text) + chunks = splitter.run([doc])["documents"] + for chunk in chunks: + assert splitter._chunk_length(chunk.content) <= 2 + + +def test_run_custom_sentence_tokenizer_document_and_overlap_char_unit(): + """Test that RecursiveDocumentSplitter works correctly with custom sentence tokenizer and overlap""" + splitter = RecursiveDocumentSplitter(split_length=25, split_overlap=10, separators=["sentence"], split_unit="char") + text = "This is sentence one. This is sentence two. This is sentence three." + + splitter.warm_up() + doc = Document(content=text) + doc_chunks = splitter.run([doc])["documents"] + + assert len(doc_chunks) == 4 + assert doc_chunks[0].content == "This is sentence one. " + assert doc_chunks[0].meta["split_id"] == 0 + assert doc_chunks[0].meta["split_idx_start"] == text.index(doc_chunks[0].content) + assert doc_chunks[0].meta["_split_overlap"] == [{"doc_id": doc_chunks[1].id, "range": (0, 10)}] + + assert doc_chunks[1].content == "ence one. This is sentenc" + assert doc_chunks[1].meta["split_id"] == 1 + assert doc_chunks[1].meta["split_idx_start"] == text.index(doc_chunks[1].content) + assert doc_chunks[1].meta["_split_overlap"] == [ + {"doc_id": doc_chunks[0].id, "range": (12, 22)}, + {"doc_id": doc_chunks[2].id, "range": (0, 10)}, + ] + + assert doc_chunks[2].content == "is sentence two. This is " + assert doc_chunks[2].meta["split_id"] == 2 + assert doc_chunks[2].meta["split_idx_start"] == text.index(doc_chunks[2].content) + assert doc_chunks[2].meta["_split_overlap"] == [ + {"doc_id": doc_chunks[1].id, "range": (15, 25)}, + {"doc_id": doc_chunks[3].id, "range": (0, 10)}, + ] + + assert doc_chunks[3].content == ". This is sentence three." 
+ assert doc_chunks[3].meta["split_id"] == 3 + assert doc_chunks[3].meta["split_idx_start"] == text.index(doc_chunks[3].content) + assert doc_chunks[3].meta["_split_overlap"] == [{"doc_id": doc_chunks[2].id, "range": (15, 25)}] + + +def test_run_split_by_dot_count_page_breaks_word_unit() -> None: + document_splitter = RecursiveDocumentSplitter(separators=["."], split_length=4, split_overlap=0, split_unit="word") + + text = ( + "Sentence on page 1. Another on page 1.\fSentence on page 2. Another on page 2.\f" + "Sentence on page 3. Another on page 3.\f\f Sentence on page 5." + ) + + documents = document_splitter.run(documents=[Document(content=text)])["documents"] + + assert len(documents) == 8 + assert documents[0].content == "Sentence on page 1." + assert documents[0].meta["page_number"] == 1 + assert documents[0].meta["split_id"] == 0 + assert documents[0].meta["split_idx_start"] == text.index(documents[0].content) + + assert documents[1].content == " Another on page 1." + assert documents[1].meta["page_number"] == 1 + assert documents[1].meta["split_id"] == 1 + assert documents[1].meta["split_idx_start"] == text.index(documents[1].content) + + assert documents[2].content == "\fSentence on page 2." + assert documents[2].meta["page_number"] == 2 + assert documents[2].meta["split_id"] == 2 + assert documents[2].meta["split_idx_start"] == text.index(documents[2].content) + + assert documents[3].content == " Another on page 2." + assert documents[3].meta["page_number"] == 2 + assert documents[3].meta["split_id"] == 3 + assert documents[3].meta["split_idx_start"] == text.index(documents[3].content) + + assert documents[4].content == "\fSentence on page 3." + assert documents[4].meta["page_number"] == 3 + assert documents[4].meta["split_id"] == 4 + assert documents[4].meta["split_idx_start"] == text.index(documents[4].content) + + assert documents[5].content == " Another on page 3." + assert documents[5].meta["page_number"] == 3 + assert documents[5].meta["split_id"] == 5 + assert documents[5].meta["split_idx_start"] == text.index(documents[5].content) + + assert documents[6].content == "\f\f Sentence on page" + assert documents[6].meta["page_number"] == 5 + assert documents[6].meta["split_id"] == 6 + assert documents[6].meta["split_idx_start"] == text.index(documents[6].content) + + assert documents[7].content == " 5." + assert documents[7].meta["page_number"] == 5 + assert documents[7].meta["split_id"] == 7 + assert documents[7].meta["split_idx_start"] == text.index(documents[7].content) + + +def test_run_split_by_word_count_page_breaks_word_unit(): + splitter = RecursiveDocumentSplitter(split_length=4, split_overlap=0, separators=[" "], split_unit="word") + text = "This is some text. \f This text is on another page. \f This is the last pag3." + doc = Document(content=text) + doc_chunks = splitter.run([doc]) + doc_chunks = doc_chunks["documents"] + + assert len(doc_chunks) == 5 + assert doc_chunks[0].content == "This is some text. " + assert doc_chunks[0].meta["page_number"] == 1 + assert doc_chunks[0].meta["split_id"] == 0 + assert doc_chunks[0].meta["split_idx_start"] == text.index(doc_chunks[0].content) + + assert doc_chunks[1].content == "\f This text is " + assert doc_chunks[1].meta["page_number"] == 2 + assert doc_chunks[1].meta["split_id"] == 1 + assert doc_chunks[1].meta["split_idx_start"] == text.index(doc_chunks[1].content) + + assert doc_chunks[2].content == "on another page. 
\f " + assert doc_chunks[2].meta["page_number"] == 3 + assert doc_chunks[2].meta["split_id"] == 2 + assert doc_chunks[2].meta["split_idx_start"] == text.index(doc_chunks[2].content) + + assert doc_chunks[3].content == "This is the last " + assert doc_chunks[3].meta["page_number"] == 3 + assert doc_chunks[3].meta["split_id"] == 3 + assert doc_chunks[3].meta["split_idx_start"] == text.index(doc_chunks[3].content) + + assert doc_chunks[4].content == "pag3." + assert doc_chunks[4].meta["page_number"] == 3 + assert doc_chunks[4].meta["split_id"] == 4 + assert doc_chunks[4].meta["split_idx_start"] == text.index(doc_chunks[4].content) + + +def test_run_split_by_page_break_count_page_breaks_word_unit() -> None: + document_splitter = RecursiveDocumentSplitter(separators=["\f"], split_length=8, split_overlap=0, split_unit="word") + + text = ( + "Sentence on page 1. Another on page 1.\fSentence on page 2. Another on page 2.\f" + "Sentence on page 3. Another on page 3.\f\f Sentence on page 5." + ) + + documents = document_splitter.run(documents=[Document(content=text)]) + chunks_docs = documents["documents"] + + assert len(chunks_docs) == 4 + assert chunks_docs[0].content == "Sentence on page 1. Another on page 1.\f" + assert chunks_docs[0].meta["page_number"] == 1 + assert chunks_docs[0].meta["split_id"] == 0 + assert chunks_docs[0].meta["split_idx_start"] == text.index(chunks_docs[0].content) + + assert chunks_docs[1].content == "Sentence on page 2. Another on page 2.\f" + assert chunks_docs[1].meta["page_number"] == 2 + assert chunks_docs[1].meta["split_id"] == 1 + assert chunks_docs[1].meta["split_idx_start"] == text.index(chunks_docs[1].content) + + assert chunks_docs[2].content == "Sentence on page 3. Another on page 3.\f" + assert chunks_docs[2].meta["page_number"] == 3 + assert chunks_docs[2].meta["split_id"] == 2 + assert chunks_docs[2].meta["split_idx_start"] == text.index(chunks_docs[2].content) + + assert chunks_docs[3].content == "\f Sentence on page 5." + assert chunks_docs[3].meta["page_number"] == 5 + assert chunks_docs[3].meta["split_id"] == 3 + assert chunks_docs[3].meta["split_idx_start"] == text.index(chunks_docs[3].content) + + +def test_run_split_by_new_line_count_page_breaks_word_unit() -> None: + document_splitter = RecursiveDocumentSplitter(separators=["\n"], split_length=4, split_overlap=0, split_unit="word") + + text = ( + "Sentence on page 1.\nAnother on page 1.\n\f" + "Sentence on page 2.\nAnother on page 2.\n\f" + "Sentence on page 3.\nAnother on page 3.\n\f\f" + "Sentence on page 5." 
+ ) + + documents = document_splitter.run(documents=[Document(content=text)]) + chunks_docs = documents["documents"] + + assert len(chunks_docs) == 7 + + assert chunks_docs[0].content == "Sentence on page 1.\n" + assert chunks_docs[0].meta["page_number"] == 1 + assert chunks_docs[0].meta["split_id"] == 0 + assert chunks_docs[0].meta["split_idx_start"] == text.index(chunks_docs[0].content) + + assert chunks_docs[1].content == "Another on page 1.\n" + assert chunks_docs[1].meta["page_number"] == 1 + assert chunks_docs[1].meta["split_id"] == 1 + assert chunks_docs[1].meta["split_idx_start"] == text.index(chunks_docs[1].content) + + assert chunks_docs[2].content == "\fSentence on page 2.\n" + assert chunks_docs[2].meta["page_number"] == 2 + assert chunks_docs[2].meta["split_id"] == 2 + assert chunks_docs[2].meta["split_idx_start"] == text.index(chunks_docs[2].content) + + assert chunks_docs[3].content == "Another on page 2.\n" + assert chunks_docs[3].meta["page_number"] == 2 + assert chunks_docs[3].meta["split_id"] == 3 + assert chunks_docs[3].meta["split_idx_start"] == text.index(chunks_docs[3].content) + + assert chunks_docs[4].content == "\fSentence on page 3.\n" + assert chunks_docs[4].meta["page_number"] == 3 + assert chunks_docs[4].meta["split_id"] == 4 + assert chunks_docs[4].meta["split_idx_start"] == text.index(chunks_docs[4].content) + + assert chunks_docs[5].content == "Another on page 3.\n" + assert chunks_docs[5].meta["page_number"] == 3 + assert chunks_docs[5].meta["split_id"] == 5 + assert chunks_docs[5].meta["split_idx_start"] == text.index(chunks_docs[5].content) + + assert chunks_docs[6].content == "\f\fSentence on page 5." + assert chunks_docs[6].meta["page_number"] == 5 + assert chunks_docs[6].meta["split_id"] == 6 + assert chunks_docs[6].meta["split_idx_start"] == text.index(chunks_docs[6].content) + + +def test_run_split_by_sentence_count_page_breaks_word_unit() -> None: + document_splitter = RecursiveDocumentSplitter( + separators=["sentence"], split_length=7, split_overlap=0, split_unit="word" + ) + document_splitter.warm_up() + + text = ( + "Sentence on page 1. Another on page 1.\fSentence on page 2. Another on page 2.\f" + "Sentence on page 3. Another on page 3.\f\fSentence on page 5." + ) + + documents = document_splitter.run(documents=[Document(content=text)]) + chunks_docs = documents["documents"] + assert len(chunks_docs) == 7 + + assert chunks_docs[0].content == "Sentence on page 1. " + assert chunks_docs[0].meta["page_number"] == 1 + assert chunks_docs[0].meta["split_id"] == 0 + assert chunks_docs[0].meta["split_idx_start"] == text.index(chunks_docs[0].content) + + assert chunks_docs[1].content == "Another on page 1.\f" + assert chunks_docs[1].meta["page_number"] == 1 + assert chunks_docs[1].meta["split_id"] == 1 + assert chunks_docs[1].meta["split_idx_start"] == text.index(chunks_docs[1].content) + + assert chunks_docs[2].content == "Sentence on page 2. " + assert chunks_docs[2].meta["page_number"] == 2 + assert chunks_docs[2].meta["split_id"] == 2 + assert chunks_docs[2].meta["split_idx_start"] == text.index(chunks_docs[2].content) + + assert chunks_docs[3].content == "Another on page 2.\f" + assert chunks_docs[3].meta["page_number"] == 2 + assert chunks_docs[3].meta["split_id"] == 3 + assert chunks_docs[3].meta["split_idx_start"] == text.index(chunks_docs[3].content) + + assert chunks_docs[4].content == "Sentence on page 3. 
" + assert chunks_docs[4].meta["page_number"] == 3 + assert chunks_docs[4].meta["split_id"] == 4 + assert chunks_docs[4].meta["split_idx_start"] == text.index(chunks_docs[4].content) + + assert chunks_docs[5].content == "Another on page 3.\f\f" + assert chunks_docs[5].meta["page_number"] == 3 + assert chunks_docs[5].meta["split_id"] == 5 + assert chunks_docs[5].meta["split_idx_start"] == text.index(chunks_docs[5].content) + + assert chunks_docs[6].content == "Sentence on page 5." + assert chunks_docs[6].meta["page_number"] == 5 + assert chunks_docs[6].meta["split_id"] == 6 + assert chunks_docs[6].meta["split_idx_start"] == text.index(chunks_docs[6].content) + + +def test_run_split_by_sentence_tokenizer_document_and_overlap_word_unit_no_overlap(): + splitter = RecursiveDocumentSplitter(split_length=4, split_overlap=0, separators=["."], split_unit="word") + text = "This is sentence one. This is sentence two. This is sentence three." + chunks = splitter.run([Document(content=text)])["documents"] + assert len(chunks) == 3 + assert chunks[0].content == "This is sentence one." + assert chunks[1].content == " This is sentence two." + assert chunks[2].content == " This is sentence three." + + +def test_run_split_by_dot_and_overlap_1_word_unit(): + splitter = RecursiveDocumentSplitter(split_length=4, split_overlap=1, separators=["."], split_unit="word") + text = "This is sentence one. This is sentence two. This is sentence three. This is sentence four." + chunks = splitter.run([Document(content=text)])["documents"] + assert len(chunks) == 5 + assert chunks[0].content == "This is sentence one." + assert chunks[1].content == "one. This is sentence" + assert chunks[2].content == "sentence two. This is" + assert chunks[3].content == "is sentence three. This" + assert chunks[4].content == "This is sentence four." + + +def test_run_trigger_dealing_with_remaining_word_larger_than_split_length(): + splitter = RecursiveDocumentSplitter(split_length=3, split_overlap=2, separators=["."], split_unit="word") + text = """A simple sentence1. A bright sentence2. A clever sentence3""" + doc = Document(content=text) + chunks = splitter.run([doc])["documents"] + assert len(chunks) == 7 + assert chunks[0].content == "A simple sentence1." + assert chunks[1].content == "simple sentence1. A" + assert chunks[2].content == "sentence1. A bright" + assert chunks[3].content == "A bright sentence2." + assert chunks[4].content == "bright sentence2. A" + assert chunks[5].content == "sentence2. A clever" + assert chunks[6].content == "A clever sentence3" + + +def test_run_trigger_dealing_with_remaining_char_larger_than_split_length(): + splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=15, separators=["."], split_unit="char") + text = """A simple sentence1. A bright sentence2. A clever sentence3""" + doc = Document(content=text) + chunks = splitter.run([doc])["documents"] + + assert len(chunks) == 9 + + assert chunks[0].content == "A simple sentence1." + assert chunks[0].meta["split_id"] == 0 + assert chunks[0].meta["split_idx_start"] == text.index(chunks[0].content) + assert chunks[0].meta["_split_overlap"] == [{"doc_id": chunks[1].id, "range": (0, 15)}] + + assert chunks[1].content == "mple sentence1. A br" + assert chunks[1].meta["split_id"] == 1 + assert chunks[1].meta["split_idx_start"] == text.index(chunks[1].content) + assert chunks[1].meta["_split_overlap"] == [ + {"doc_id": chunks[0].id, "range": (4, 19)}, + {"doc_id": chunks[2].id, "range": (0, 15)}, + ] + + assert chunks[2].content == "sentence1. 
A bright " + assert chunks[2].meta["split_id"] == 2 + assert chunks[2].meta["split_idx_start"] == text.index(chunks[2].content) + assert chunks[2].meta["_split_overlap"] == [ + {"doc_id": chunks[1].id, "range": (5, 20)}, + {"doc_id": chunks[3].id, "range": (0, 15)}, + ] + + assert chunks[3].content == "nce1. A bright sente" + assert chunks[3].meta["split_id"] == 3 + assert chunks[3].meta["split_idx_start"] == text.index(chunks[3].content) + assert chunks[3].meta["_split_overlap"] == [ + {"doc_id": chunks[2].id, "range": (5, 20)}, + {"doc_id": chunks[4].id, "range": (0, 15)}, + ] + + assert chunks[4].content == " A bright sentence2." + assert chunks[4].meta["split_id"] == 4 + assert chunks[4].meta["split_idx_start"] == text.index(chunks[4].content) + assert chunks[4].meta["_split_overlap"] == [ + {"doc_id": chunks[3].id, "range": (5, 20)}, + {"doc_id": chunks[5].id, "range": (0, 15)}, + ] + + assert chunks[5].content == "ight sentence2. A cl" + assert chunks[5].meta["split_id"] == 5 + assert chunks[5].meta["split_idx_start"] == text.index(chunks[5].content) + assert chunks[5].meta["_split_overlap"] == [ + {"doc_id": chunks[4].id, "range": (5, 20)}, + {"doc_id": chunks[6].id, "range": (0, 15)}, + ] + + assert chunks[6].content == "sentence2. A clever " + assert chunks[6].meta["split_id"] == 6 + assert chunks[6].meta["split_idx_start"] == text.index(chunks[6].content) + assert chunks[6].meta["_split_overlap"] == [ + {"doc_id": chunks[5].id, "range": (5, 20)}, + {"doc_id": chunks[7].id, "range": (0, 15)}, + ] + + assert chunks[7].content == "nce2. A clever sente" + assert chunks[7].meta["split_id"] == 7 + assert chunks[7].meta["split_idx_start"] == text.index(chunks[7].content) + assert chunks[7].meta["_split_overlap"] == [ + {"doc_id": chunks[6].id, "range": (5, 20)}, + {"doc_id": chunks[8].id, "range": (0, 15)}, + ] + + assert chunks[8].content == " A clever sentence3" + assert chunks[8].meta["split_id"] == 8 + assert chunks[8].meta["split_idx_start"] == text.index(chunks[8].content) + assert chunks[8].meta["_split_overlap"] == [{"doc_id": chunks[7].id, "range": (5, 20)}] + + +def test_run_custom_split_by_dot_and_overlap_3_char_unit(): + document_splitter = RecursiveDocumentSplitter(separators=["."], split_length=4, split_overlap=0, split_unit="word") + text = "\x0c\x0c Sentence on page 5." + chunks = document_splitter._fall_back_to_fixed_chunking(text, split_units="word") + assert len(chunks) == 2 + assert chunks[0] == "\x0c\x0c Sentence on page" + assert chunks[1] == " 5." 
+ + +def test_run_serialization_in_pipeline(): + pipeline = Pipeline() + pipeline.add_component("chunker", RecursiveDocumentSplitter(split_length=20, split_overlap=5, separators=["."])) + pipeline_dict = pipeline.dumps() + new_pipeline = Pipeline.loads(pipeline_dict) + assert pipeline_dict == new_pipeline.dumps() From db76ae28472da64ff05eba295c2fca72e5d7d3a0 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Sun, 12 Jan 2025 17:41:38 +0100 Subject: [PATCH 06/41] feat: add `default_headers` for Azure embedders (#8699) * Add default_headers param to azure embedders --- .../embedders/azure_document_embedder.py | 6 ++++ .../embedders/azure_text_embedder.py | 6 ++++ ...ders-azure-embedders-6ffd24ec1502c5e4.yaml | 4 +++ .../embedders/test_azure_document_embedder.py | 35 +++++++++++++++++++ .../embedders/test_azure_text_embedder.py | 31 ++++++++++++++++ 5 files changed, 82 insertions(+) create mode 100644 releasenotes/notes/add-default-headers-azure-embedders-6ffd24ec1502c5e4.yaml diff --git a/haystack/components/embedders/azure_document_embedder.py b/haystack/components/embedders/azure_document_embedder.py index b28fc1fda1..2223c82bf1 100644 --- a/haystack/components/embedders/azure_document_embedder.py +++ b/haystack/components/embedders/azure_document_embedder.py @@ -51,6 +51,8 @@ def __init__( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-p embedding_separator: str = "\n", timeout: Optional[float] = None, max_retries: Optional[int] = None, + *, + default_headers: Optional[Dict[str, str]] = None, ): """ Creates an AzureOpenAIDocumentEmbedder component. @@ -95,6 +97,7 @@ def __init__( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-p `OPENAI_TIMEOUT` environment variable, or 30 seconds. :param max_retries: Maximum number of retries to contact AzureOpenAI after an internal error. If not set, defaults to either the `OPENAI_MAX_RETRIES` environment variable or to 5 retries. + :param default_headers: Default headers to send to the AzureOpenAI client. 
""" # if not provided as a parameter, azure_endpoint is read from the env var AZURE_OPENAI_ENDPOINT azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT") @@ -119,6 +122,7 @@ def __init__( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-p self.embedding_separator = embedding_separator self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0)) self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5)) + self.default_headers = default_headers or {} self._client = AzureOpenAI( api_version=api_version, @@ -129,6 +133,7 @@ def __init__( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-p organization=organization, timeout=self.timeout, max_retries=self.max_retries, + default_headers=self.default_headers, ) def _get_telemetry_data(self) -> Dict[str, Any]: @@ -161,6 +166,7 @@ def to_dict(self) -> Dict[str, Any]: azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None, timeout=self.timeout, max_retries=self.max_retries, + default_headers=self.default_headers, ) @classmethod diff --git a/haystack/components/embedders/azure_text_embedder.py b/haystack/components/embedders/azure_text_embedder.py index bef34d6c3f..c37bf2338d 100644 --- a/haystack/components/embedders/azure_text_embedder.py +++ b/haystack/components/embedders/azure_text_embedder.py @@ -46,6 +46,8 @@ def __init__( # pylint: disable=too-many-positional-arguments max_retries: Optional[int] = None, prefix: str = "", suffix: str = "", + *, + default_headers: Optional[Dict[str, str]] = None, ): """ Creates an AzureOpenAITextEmbedder component. @@ -82,6 +84,7 @@ def __init__( # pylint: disable=too-many-positional-arguments A string to add at the beginning of each text. :param suffix: A string to add at the end of each text. + :param default_headers: Default headers to send to the AzureOpenAI client. """ # Why is this here? # AzureOpenAI init is forcing us to use an init method that takes either base_url or azure_endpoint as not @@ -105,6 +108,7 @@ def __init__( # pylint: disable=too-many-positional-arguments self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5)) self.prefix = prefix self.suffix = suffix + self.default_headers = default_headers or {} self._client = AzureOpenAI( api_version=api_version, @@ -115,6 +119,7 @@ def __init__( # pylint: disable=too-many-positional-arguments organization=organization, timeout=self.timeout, max_retries=self.max_retries, + default_headers=self.default_headers, ) def _get_telemetry_data(self) -> Dict[str, Any]: @@ -143,6 +148,7 @@ def to_dict(self) -> Dict[str, Any]: azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None, timeout=self.timeout, max_retries=self.max_retries, + default_headers=self.default_headers, ) @classmethod diff --git a/releasenotes/notes/add-default-headers-azure-embedders-6ffd24ec1502c5e4.yaml b/releasenotes/notes/add-default-headers-azure-embedders-6ffd24ec1502c5e4.yaml new file mode 100644 index 0000000000..f8a7401c39 --- /dev/null +++ b/releasenotes/notes/add-default-headers-azure-embedders-6ffd24ec1502c5e4.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Added `default_headers` parameter to `AzureOpenAIDocumentEmbedder` and `AzureOpenAITextEmbedder`. 
diff --git a/test/components/embedders/test_azure_document_embedder.py b/test/components/embedders/test_azure_document_embedder.py index 354f35a0fc..033ed36122 100644 --- a/test/components/embedders/test_azure_document_embedder.py +++ b/test/components/embedders/test_azure_document_embedder.py @@ -22,6 +22,7 @@ def test_init_default(self, monkeypatch): assert embedder.progress_bar is True assert embedder.meta_fields_to_embed == [] assert embedder.embedding_separator == "\n" + assert embedder.default_headers == {} def test_to_dict(self, monkeypatch): monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key") @@ -45,9 +46,43 @@ def test_to_dict(self, monkeypatch): "embedding_separator": "\n", "max_retries": 5, "timeout": 30.0, + "default_headers": {}, }, } + def test_from_dict(self, monkeypatch): + monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key") + data = { + "type": "haystack.components.embedders.azure_document_embedder.AzureOpenAIDocumentEmbedder", + "init_parameters": { + "api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"}, + "azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"}, + "api_version": "2023-05-15", + "azure_deployment": "text-embedding-ada-002", + "dimensions": None, + "azure_endpoint": "https://example-resource.azure.openai.com/", + "organization": None, + "prefix": "", + "suffix": "", + "batch_size": 32, + "progress_bar": True, + "meta_fields_to_embed": [], + "embedding_separator": "\n", + "max_retries": 5, + "timeout": 30.0, + "default_headers": {}, + }, + } + component = AzureOpenAIDocumentEmbedder.from_dict(data) + assert component.azure_deployment == "text-embedding-ada-002" + assert component.azure_endpoint == "https://example-resource.azure.openai.com/" + assert component.api_version == "2023-05-15" + assert component.max_retries == 5 + assert component.timeout == 30.0 + assert component.prefix == "" + assert component.suffix == "" + assert component.default_headers == {} + @pytest.mark.integration @pytest.mark.skipif( not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None), diff --git a/test/components/embedders/test_azure_text_embedder.py b/test/components/embedders/test_azure_text_embedder.py index 5f1f82e3d8..f4aab7edfe 100644 --- a/test/components/embedders/test_azure_text_embedder.py +++ b/test/components/embedders/test_azure_text_embedder.py @@ -19,6 +19,7 @@ def test_init_default(self, monkeypatch): assert embedder.organization is None assert embedder.prefix == "" assert embedder.suffix == "" + assert embedder.default_headers == {} def test_to_dict(self, monkeypatch): monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key") @@ -38,9 +39,39 @@ def test_to_dict(self, monkeypatch): "timeout": 30.0, "prefix": "", "suffix": "", + "default_headers": {}, }, } + def test_from_dict(self, monkeypatch): + monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key") + data = { + "type": "haystack.components.embedders.azure_text_embedder.AzureOpenAITextEmbedder", + "init_parameters": { + "api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"}, + "azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"}, + "azure_deployment": "text-embedding-ada-002", + "dimensions": None, + "organization": None, + "azure_endpoint": "https://example-resource.azure.openai.com/", + "api_version": "2023-05-15", + "max_retries": 5, + "timeout": 30.0, + "prefix": "", + "suffix": "", + "default_headers": 
{}, + }, + } + component = AzureOpenAITextEmbedder.from_dict(data) + assert component.azure_deployment == "text-embedding-ada-002" + assert component.azure_endpoint == "https://example-resource.azure.openai.com/" + assert component.api_version == "2023-05-15" + assert component.max_retries == 5 + assert component.timeout == 30.0 + assert component.prefix == "" + assert component.suffix == "" + assert component.default_headers == {} + @pytest.mark.integration @pytest.mark.skipif( not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None), From 642fa60cdf5679b3e7e0f5cc2908569d1004d3d2 Mon Sep 17 00:00:00 2001 From: Julian Risch Date: Mon, 13 Jan 2025 11:12:06 +0100 Subject: [PATCH 07/41] fix: PDFMinerToDocument initializes documents with content and meta (#8708) * fix: PDFMinerToDocument initializes documents with content and meta * add release note * Apply suggestions from code review Co-authored-by: David S. Batista --------- Co-authored-by: David S. Batista --- haystack/components/converters/pdfminer.py | 16 ++++++++-------- .../notes/pdfminer-docid-b9f1b2f1b936b228.yaml | 4 ++++ .../converters/test_pdfminer_to_document.py | 5 +++++ 3 files changed, 17 insertions(+), 8 deletions(-) create mode 100644 releasenotes/notes/pdfminer-docid-b9f1b2f1b936b228.yaml diff --git a/haystack/components/converters/pdfminer.py b/haystack/components/converters/pdfminer.py index 8642447816..c5f2415685 100644 --- a/haystack/components/converters/pdfminer.py +++ b/haystack/components/converters/pdfminer.py @@ -98,15 +98,15 @@ def __init__( # pylint: disable=too-many-positional-arguments ) self.store_full_path = store_full_path - def _converter(self, extractor) -> Document: + def _converter(self, extractor) -> str: """ - Extracts text from PDF pages then convert the text into Documents + Extracts text from PDF pages then converts the text into a single str :param extractor: Python generator that yields PDF pages. :returns: - PDF text converted to Haystack Document + PDF text converted to single str """ pages = [] for page in extractor: @@ -118,9 +118,9 @@ def _converter(self, extractor) -> Document: pages.append(text) # Add a page delimiter - concat = "\f".join(pages) + delimited_pages = "\f".join(pages) - return Document(content=concat) + return delimited_pages @component.output_types(documents=List[Document]) def run( @@ -157,14 +157,14 @@ def run( continue try: pdf_reader = extract_pages(io.BytesIO(bytestream.data), laparams=self.layout_params) - document = self._converter(pdf_reader) + text = self._converter(pdf_reader) except Exception as e: logger.warning( "Could not read {source} and convert it to Document, skipping. {error}", source=source, error=e ) continue - if document.content is None or document.content.strip() == "": + if text is None or text.strip() == "": logger.warning( "PDFMinerToDocument could not extract text from the file {source}. 
Returning an empty document.", source=source, @@ -174,7 +174,7 @@ def run( if not self.store_full_path and (file_path := bytestream.meta.get("file_path")): merged_metadata["file_path"] = os.path.basename(file_path) - document.meta = merged_metadata + document = Document(content=text, meta=merged_metadata) documents.append(document) return {"documents": documents} diff --git a/releasenotes/notes/pdfminer-docid-b9f1b2f1b936b228.yaml b/releasenotes/notes/pdfminer-docid-b9f1b2f1b936b228.yaml new file mode 100644 index 0000000000..2497c6048e --- /dev/null +++ b/releasenotes/notes/pdfminer-docid-b9f1b2f1b936b228.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + PDFMinerToDocument now creates documents with id based on converted text and meta data. Before, PDFMinerToDocument did not consider the document's meta field when generating the document's id. diff --git a/test/components/converters/test_pdfminer_to_document.py b/test/components/converters/test_pdfminer_to_document.py index 4b30f2819a..92aeb2dcd1 100644 --- a/test/components/converters/test_pdfminer_to_document.py +++ b/test/components/converters/test_pdfminer_to_document.py @@ -5,6 +5,7 @@ import pytest +from haystack import Document from haystack.dataclasses import ByteStream from haystack.components.converters.pdfminer import PDFMinerToDocument @@ -150,3 +151,7 @@ def test_run_empty_document(self, caplog, test_files_path): results = converter.run(sources=sources) assert "PDFMinerToDocument could not extract text from the file" in caplog.text assert results["documents"][0].content == "" + + # Check that not only content is used when the returned document is initialized and doc id is generated + assert results["documents"][0].meta["file_path"] == "non_text_searchable.pdf" + assert results["documents"][0].id != Document(content="").id From d147c7658f85a23677a7a25714a954c5b3bffc98 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 13 Jan 2025 11:15:33 +0100 Subject: [PATCH 08/41] feat: Add `ComponentTool` to Haystack tools (#8693) * Initial ComponentTool --------- Co-authored-by: Daria Fokina Co-authored-by: Julian Risch --- docs/pydoc/config/tools_api.yml | 2 +- haystack/tools/__init__.py | 12 +- haystack/tools/component_tool.py | 330 ++++++++++ pyproject.toml | 3 + .../add-component-tool-ffe9f9911ea055a6.yaml | 38 ++ test/tools/test_component_tool.py | 569 ++++++++++++++++++ 6 files changed, 952 insertions(+), 2 deletions(-) create mode 100644 haystack/tools/component_tool.py create mode 100644 releasenotes/notes/add-component-tool-ffe9f9911ea055a6.yaml create mode 100644 test/tools/test_component_tool.py diff --git a/docs/pydoc/config/tools_api.yml b/docs/pydoc/config/tools_api.yml index d3f953087f..3050e6c587 100644 --- a/docs/pydoc/config/tools_api.yml +++ b/docs/pydoc/config/tools_api.yml @@ -2,7 +2,7 @@ loaders: - type: haystack_pydoc_tools.loaders.CustomPythonLoader search_path: [../../../haystack/tools] modules: - ["tool", "from_function"] + ["tool", "from_function", "component_tool"] ignore_when_discovered: ["__init__"] processors: - type: filter diff --git a/haystack/tools/__init__.py b/haystack/tools/__init__.py index 4601ac71c6..ccb274d49a 100644 --- a/haystack/tools/__init__.py +++ b/haystack/tools/__init__.py @@ -2,7 +2,17 @@ # # SPDX-License-Identifier: Apache-2.0 +# ruff: noqa: I001 (ignore import order as we need to import Tool before ComponentTool) from haystack.tools.from_function import create_tool_from_function, tool from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace +from 
haystack.tools.component_tool import ComponentTool -__all__ = ["Tool", "_check_duplicate_tool_names", "deserialize_tools_inplace", "create_tool_from_function", "tool"] + +__all__ = [ + "Tool", + "_check_duplicate_tool_names", + "deserialize_tools_inplace", + "create_tool_from_function", + "tool", + "ComponentTool", +] diff --git a/haystack/tools/component_tool.py b/haystack/tools/component_tool.py new file mode 100644 index 0000000000..cc77ceca01 --- /dev/null +++ b/haystack/tools/component_tool.py @@ -0,0 +1,330 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import fields, is_dataclass +from inspect import getdoc +from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin + +from pydantic import TypeAdapter + +from haystack import logging +from haystack.core.component import Component +from haystack.core.serialization import ( + component_from_dict, + component_to_dict, + generate_qualified_class_name, + import_class_by_name, +) +from haystack.lazy_imports import LazyImport +from haystack.tools import Tool +from haystack.tools.errors import SchemaGenerationError + +with LazyImport(message="Run 'pip install docstring-parser'") as docstring_parser_import: + from docstring_parser import parse + + +logger = logging.getLogger(__name__) + + +class ComponentTool(Tool): + """ + A Tool that wraps Haystack components, allowing them to be used as tools by LLMs. + + ComponentTool automatically generates LLM-compatible tool schemas from component input sockets, + which are derived from the component's `run` method signature and type hints. + + + Key features: + - Automatic LLM tool calling schema generation from component input sockets + - Type conversion and validation for component inputs + - Support for types: + - Dataclasses + - Lists of dataclasses + - Basic types (str, int, float, bool, dict) + - Lists of basic types + - Automatic name generation from component class name + - Description extraction from component docstrings + + To use ComponentTool, you first need a Haystack component - either an existing one or a new one you create. + You can create a ComponentTool from the component by passing the component to the ComponentTool constructor. + Below is an example of creating a ComponentTool from an existing SerperDevWebSearch component. 
+ + ```python + from haystack import component, Pipeline + from haystack.tools import ComponentTool + from haystack.components.websearch import SerperDevWebSearch + from haystack.utils import Secret + from haystack.components.tools.tool_invoker import ToolInvoker + from haystack.components.generators.chat import OpenAIChatGenerator + from haystack.dataclasses import ChatMessage + + # Create a SerperDev search component + search = SerperDevWebSearch(api_key=Secret.from_env_var("SERPERDEV_API_KEY"), top_k=3) + + # Create a tool from the component + tool = ComponentTool( + component=search, + name="web_search", # Optional: defaults to "serper_dev_web_search" + description="Search the web for current information on any topic" # Optional: defaults to component docstring + ) + + # Create pipeline with OpenAIChatGenerator and ToolInvoker + pipeline = Pipeline() + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) + + # Connect components + pipeline.connect("llm.replies", "tool_invoker.messages") + + message = ChatMessage.from_user("Use the web search tool to find information about Nikola Tesla") + + # Run pipeline + result = pipeline.run({"llm": {"messages": [message]}}) + + print(result) + ``` + + """ + + def __init__(self, component: Component, name: Optional[str] = None, description: Optional[str] = None): + """ + Create a Tool instance from a Haystack component. + + :param component: The Haystack component to wrap as a tool. + :param name: Optional name for the tool (defaults to snake_case of component class name). + :param description: Optional description (defaults to component's docstring). + :raises ValueError: If the component is invalid or schema generation fails. + """ + if not isinstance(component, Component): + message = ( + f"Object {component!r} is not a Haystack component. " + "Use ComponentTool only with Haystack component instances." + ) + raise ValueError(message) + + if getattr(component, "__haystack_added_to_pipeline__", None): + msg = ( + "Component has been added to a pipeline and can't be used to create a ComponentTool. " + "Create ComponentTool from a non-pipeline component instead." + ) + raise ValueError(msg) + + # Create the tools schema from the component run method parameters + tool_schema = self._create_tool_parameters_schema(component) + + def component_invoker(**kwargs): + """ + Invokes the component using keyword arguments provided by the LLM function calling/tool-generated response. + + :param kwargs: The keyword arguments to invoke the component with. + :returns: The result of the component invocation. 
+ """ + converted_kwargs = {} + input_sockets = component.__haystack_input__._sockets_dict + for param_name, param_value in kwargs.items(): + param_type = input_sockets[param_name].type + + # Check if the type (or list element type) has from_dict + target_type = get_args(param_type)[0] if get_origin(param_type) is list else param_type + if hasattr(target_type, "from_dict"): + if isinstance(param_value, list): + param_value = [target_type.from_dict(item) for item in param_value if isinstance(item, dict)] + elif isinstance(param_value, dict): + param_value = target_type.from_dict(param_value) + else: + # Let TypeAdapter handle both single values and lists + type_adapter = TypeAdapter(param_type) + param_value = type_adapter.validate_python(param_value) + + converted_kwargs[param_name] = param_value + logger.debug(f"Invoking component {type(component)} with kwargs: {converted_kwargs}") + return component.run(**converted_kwargs) + + # Generate a name for the tool if not provided + if not name: + class_name = component.__class__.__name__ + # Convert camelCase/PascalCase to snake_case + name = "".join( + [ + "_" + c.lower() if c.isupper() and i > 0 and not class_name[i - 1].isupper() else c.lower() + for i, c in enumerate(class_name) + ] + ).lstrip("_") + + # Generate a description for the tool if not provided and truncate to 512 characters + # as most LLMs have a limit for the description length + description = (description or component.__doc__ or name)[:512] + + # Create the Tool instance with the component invoker as the function to be called and the schema + super().__init__(name, description, tool_schema, component_invoker) + self._component = component + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the ComponentTool to a dictionary. + """ + # we do not serialize the function in this case: it can be recreated from the component at deserialization time + serialized = {"name": self.name, "description": self.description, "parameters": self.parameters} + serialized["component"] = component_to_dict(obj=self._component, name=self.name) + return {"type": generate_qualified_class_name(type(self)), "data": serialized} + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Tool": + """ + Deserializes the ComponentTool from a dictionary. + """ + inner_data = data["data"] + component_class = import_class_by_name(inner_data["component"]["type"]) + component = component_from_dict(cls=component_class, data=inner_data["component"], name=inner_data["name"]) + return cls(component=component, name=inner_data["name"], description=inner_data["description"]) + + def _create_tool_parameters_schema(self, component: Component) -> Dict[str, Any]: + """ + Creates an OpenAI tools schema from a component's run method parameters. + + :param component: The component to create the schema from. + :raises SchemaGenerationError: If schema generation fails + :returns: OpenAI tools schema for the component's run method parameters. + """ + properties = {} + required = [] + + param_descriptions = self._get_param_descriptions(component.run) + + for input_name, socket in component.__haystack_input__._sockets_dict.items(): # type: ignore[attr-defined] + input_type = socket.type + description = param_descriptions.get(input_name, f"Input '{input_name}' for the component.") + + try: + property_schema = self._create_property_schema(input_type, description) + except Exception as e: + raise SchemaGenerationError( + f"Error processing input '{input_name}': {e}. 
" + f"Schema generation supports basic types (str, int, float, bool, dict), dataclasses, " + f"and lists of these types as input types for component's run method." + ) from e + + properties[input_name] = property_schema + + # Use socket.is_mandatory to check if the input is required + if socket.is_mandatory: + required.append(input_name) + + parameters_schema = {"type": "object", "properties": properties} + + if required: + parameters_schema["required"] = required + + return parameters_schema + + @staticmethod + def _get_param_descriptions(method: Callable) -> Dict[str, str]: + """ + Extracts parameter descriptions from the method's docstring using docstring_parser. + + :param method: The method to extract parameter descriptions from. + :returns: A dictionary mapping parameter names to their descriptions. + """ + docstring = getdoc(method) + if not docstring: + return {} + + docstring_parser_import.check() + parsed_doc = parse(docstring) + param_descriptions = {} + for param in parsed_doc.params: + if not param.description: + logger.warning( + "Missing description for parameter '%s'. Please add a description in the component's " + "run() method docstring using the format ':param %%s: '. " + "This description helps the LLM understand how to use this parameter." % param.arg_name + ) + param_descriptions[param.arg_name] = param.description.strip() if param.description else "" + return param_descriptions + + @staticmethod + def _is_nullable_type(python_type: Any) -> bool: + """ + Checks if the type is a Union with NoneType (i.e., Optional). + + :param python_type: The Python type to check. + :returns: True if the type is a Union with NoneType, False otherwise. + """ + origin = get_origin(python_type) + if origin is Union: + return type(None) in get_args(python_type) + return False + + def _create_list_schema(self, item_type: Any, description: str) -> Dict[str, Any]: + """ + Creates a schema for a list type. + + :param item_type: The type of items in the list. + :param description: The description of the list. + :returns: A dictionary representing the list schema. + """ + items_schema = self._create_property_schema(item_type, "") + items_schema.pop("description", None) + return {"type": "array", "description": description, "items": items_schema} + + def _create_dataclass_schema(self, python_type: Any, description: str) -> Dict[str, Any]: + """ + Creates a schema for a dataclass. + + :param python_type: The dataclass type. + :param description: The description of the dataclass. + :returns: A dictionary representing the dataclass schema. + """ + schema = {"type": "object", "description": description, "properties": {}} + cls = python_type if isinstance(python_type, type) else python_type.__class__ + for field in fields(cls): + field_description = f"Field '{field.name}' of '{cls.__name__}'." + if isinstance(schema["properties"], dict): + schema["properties"][field.name] = self._create_property_schema(field.type, field_description) + return schema + + @staticmethod + def _create_basic_type_schema(python_type: Any, description: str) -> Dict[str, Any]: + """ + Creates a schema for a basic Python type. + + :param python_type: The Python type. + :param description: The description of the type. + :returns: A dictionary representing the basic type schema. 
+        """
+        type_mapping = {str: "string", int: "integer", float: "number", bool: "boolean", dict: "object"}
+        return {"type": type_mapping.get(python_type, "string"), "description": description}
+
+    def _create_property_schema(self, python_type: Any, description: str, default: Any = None) -> Dict[str, Any]:
+        """
+        Creates a property schema for a given Python type, recursively if necessary.
+
+        :param python_type: The Python type to create a property schema for.
+        :param description: The description of the property.
+        :param default: The default value of the property.
+        :returns: A dictionary representing the property schema.
+        :raises SchemaGenerationError: If schema generation fails, e.g., for unsupported types like Pydantic v2 models
+        """
+        nullable = self._is_nullable_type(python_type)
+        if nullable:
+            non_none_types = [t for t in get_args(python_type) if t is not type(None)]
+            python_type = non_none_types[0] if non_none_types else str
+
+        origin = get_origin(python_type)
+        if origin is list:
+            schema = self._create_list_schema(get_args(python_type)[0] if get_args(python_type) else Any, description)
+        elif is_dataclass(python_type):
+            schema = self._create_dataclass_schema(python_type, description)
+        elif hasattr(python_type, "model_validate"):
+            raise SchemaGenerationError(
+                f"Pydantic models (e.g. {python_type.__name__}) are not supported as input types for "
+                f"component's run method."
+            )
+        else:
+            schema = self._create_basic_type_schema(python_type, description)
+
+        if default is not None:
+            schema["default"] = default
+
+        return schema
diff --git a/pyproject.toml b/pyproject.toml
index 73031b8130..258e4e2710 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -126,6 +126,9 @@ extra-dependencies = [
   # Structured logging
   "structlog",
 
+  # ComponentTool
+  "docstring-parser",
+
   # Test
   "pytest",
   "pytest-bdd",
diff --git a/releasenotes/notes/add-component-tool-ffe9f9911ea055a6.yaml b/releasenotes/notes/add-component-tool-ffe9f9911ea055a6.yaml
new file mode 100644
index 0000000000..c9db99438a
--- /dev/null
+++ b/releasenotes/notes/add-component-tool-ffe9f9911ea055a6.yaml
@@ -0,0 +1,38 @@
+---
+highlights: |
+  Introduced ComponentTool, a powerful addition to the Haystack tooling architecture that enables any Haystack component to be used as a tool by LLMs.
+  ComponentTool bridges the gap between Haystack's component ecosystem and LLM tool/function calling capabilities, allowing LLMs to
+  directly interact with components like web search, document processing, or any custom user component. ComponentTool handles
+  all the complexity of schema generation and type conversion, making it easy to expose component functionality to LLMs.
+
+features:
+  - |
+    Introduced the ComponentTool, a new tool that wraps Haystack components, allowing them to be used as tools by LLMs (via the various ChatGenerators).
+    ComponentTool supports automatic tool schema generation and input type conversion for components whose run methods accept the following input types (a sample of the generated schema follows the list):
+    - Basic types (str, int, float, bool, dict)
+    - Dataclasses (both simple and nested structures)
+    - Lists of basic types (e.g., List[str])
+    - Lists of dataclasses (e.g., List[Document])
+    - Parameters with mixed types (e.g., List[Document], str etc.)
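+
+    As a rough illustration grounded in the unit tests, a component whose documented
+    run method takes a single `text: str` parameter yields a schema like:
+
+    ```python
+    {
+        "type": "object",
+        "properties": {"text": {"type": "string", "description": "user's name"}},
+        "required": ["text"],
+    }
+    ```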
+ + Example usage: + ```python + from haystack.components.websearch import SerperDevWebSearch + from haystack.tools import ComponentTool + from haystack.utils import Secret + + # Create a SerperDev search component + search = SerperDevWebSearch( + api_key=Secret.from_token("your-api-key"), + top_k=3 + ) + + # Create a tool from the component + tool = ComponentTool( + component=search, + name="web_search", # Optional: defaults to "serper_dev_web_search" + description="Search the web for current information" # Optional: defaults to component docstring + ) + + # You can now use the tool now in a pipeline, see docs for more examples + ``` diff --git a/test/tools/test_component_tool.py b/test/tools/test_component_tool.py new file mode 100644 index 0000000000..38f4e20464 --- /dev/null +++ b/test/tools/test_component_tool.py @@ -0,0 +1,569 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import json +import os +from dataclasses import dataclass +from typing import Dict, List + +import pytest + +from haystack import Pipeline, component +from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.components.tools.tool_invoker import ToolInvoker +from haystack.components.websearch.serper_dev import SerperDevWebSearch +from haystack.dataclasses import ChatMessage, ChatRole, Document +from haystack.tools import ComponentTool +from haystack.utils.auth import Secret + + +### Component and Model Definitions + + +@component +class SimpleComponent: + """A simple component that generates text.""" + + @component.output_types(reply=str) + def run(self, text: str) -> Dict[str, str]: + """ + A simple component that generates text. + + :param text: user's name + :return: A dictionary with the generated text. + """ + return {"reply": f"Hello, {text}!"} + + +@dataclass +class User: + """A simple user dataclass.""" + + name: str = "Anonymous" + age: int = 0 + + +@component +class UserGreeter: + """A simple component that processes a User.""" + + @component.output_types(message=str) + def run(self, user: User) -> Dict[str, str]: + """ + A simple component that processes a User. + + :param user: The User object to process. + :return: A dictionary with a message about the user. + """ + return {"message": f"User {user.name} is {user.age} years old"} + + +@component +class ListProcessor: + """A component that processes a list of strings.""" + + @component.output_types(concatenated=str) + def run(self, texts: List[str]) -> Dict[str, str]: + """ + Concatenates a list of strings into a single string. + + :param texts: The list of strings to concatenate. + :return: A dictionary with the concatenated string. + """ + return {"concatenated": " ".join(texts)} + + +@dataclass +class Address: + """A dataclass representing a physical address.""" + + street: str + city: str + + +@dataclass +class Person: + """A person with an address.""" + + name: str + address: Address + + +@component +class PersonProcessor: + """A component that processes a Person with nested Address.""" + + @component.output_types(info=str) + def run(self, person: Person) -> Dict[str, str]: + """ + Creates information about the person. + + :param person: The Person to process. + :return: A dictionary with the person's information. 
+ """ + return {"info": f"{person.name} lives at {person.address.street}, {person.address.city}."} + + +@component +class DocumentProcessor: + """A component that processes a list of Documents.""" + + @component.output_types(concatenated=str) + def run(self, documents: List[Document], top_k: int = 5) -> Dict[str, str]: + """ + Concatenates the content of multiple documents with newlines. + + :param documents: List of Documents whose content will be concatenated + :param top_k: The number of top documents to concatenate + :returns: Dictionary containing the concatenated document contents + """ + return {"concatenated": "\n".join(doc.content for doc in documents[:top_k])} + + +## Unit tests +class TestToolComponent: + def test_from_component_basic(self): + component = SimpleComponent() + + tool = ComponentTool(component=component) + + assert tool.name == "simple_component" + assert tool.description == "A simple component that generates text." + assert tool.parameters == { + "type": "object", + "properties": {"text": {"type": "string", "description": "user's name"}}, + "required": ["text"], + } + + # Test tool invocation + result = tool.invoke(text="world") + assert isinstance(result, dict) + assert "reply" in result + assert result["reply"] == "Hello, world!" + + def test_from_component_with_dataclass(self): + component = UserGreeter() + + tool = ComponentTool(component=component) + assert tool.parameters == { + "type": "object", + "properties": { + "user": { + "type": "object", + "description": "The User object to process.", + "properties": { + "name": {"type": "string", "description": "Field 'name' of 'User'."}, + "age": {"type": "integer", "description": "Field 'age' of 'User'."}, + }, + } + }, + "required": ["user"], + } + + assert tool.name == "user_greeter" + assert tool.description == "A simple component that processes a User." 
+ + # Test tool invocation + result = tool.invoke(user={"name": "Alice", "age": 30}) + assert isinstance(result, dict) + assert "message" in result + assert result["message"] == "User Alice is 30 years old" + + def test_from_component_with_list_input(self): + component = ListProcessor() + + tool = ComponentTool( + component=component, name="list_processing_tool", description="A tool that concatenates strings" + ) + + assert tool.parameters == { + "type": "object", + "properties": { + "texts": { + "type": "array", + "description": "The list of strings to concatenate.", + "items": {"type": "string"}, + } + }, + "required": ["texts"], + } + + # Test tool invocation + result = tool.invoke(texts=["hello", "world"]) + assert isinstance(result, dict) + assert "concatenated" in result + assert result["concatenated"] == "hello world" + + def test_from_component_with_nested_dataclass(self): + component = PersonProcessor() + + tool = ComponentTool(component=component, name="person_tool", description="A tool that processes people") + + assert tool.parameters == { + "type": "object", + "properties": { + "person": { + "type": "object", + "description": "The Person to process.", + "properties": { + "name": {"type": "string", "description": "Field 'name' of 'Person'."}, + "address": { + "type": "object", + "description": "Field 'address' of 'Person'.", + "properties": { + "street": {"type": "string", "description": "Field 'street' of 'Address'."}, + "city": {"type": "string", "description": "Field 'city' of 'Address'."}, + }, + }, + }, + } + }, + "required": ["person"], + } + + # Test tool invocation + result = tool.invoke(person={"name": "Diana", "address": {"street": "123 Elm Street", "city": "Metropolis"}}) + assert isinstance(result, dict) + assert "info" in result + assert result["info"] == "Diana lives at 123 Elm Street, Metropolis." 
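+
+    def test_from_component_default_name_is_snake_case(self):
+        # A companion sketch, not part of the original patch: when no name is passed,
+        # ComponentTool derives one by converting the component class name to snake_case.
+        tool = ComponentTool(component=PersonProcessor())
+        assert tool.name == "person_processor"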
+ + def test_from_component_with_document_list(self): + component = DocumentProcessor() + + tool = ComponentTool( + component=component, name="document_processor", description="A tool that concatenates document contents" + ) + + assert tool.parameters == { + "type": "object", + "properties": { + "documents": { + "type": "array", + "description": "List of Documents whose content will be concatenated", + "items": { + "type": "object", + "properties": { + "id": {"type": "string", "description": "Field 'id' of 'Document'."}, + "content": {"type": "string", "description": "Field 'content' of 'Document'."}, + "dataframe": {"type": "string", "description": "Field 'dataframe' of 'Document'."}, + "blob": { + "type": "object", + "description": "Field 'blob' of 'Document'.", + "properties": { + "data": {"type": "string", "description": "Field 'data' of 'ByteStream'."}, + "meta": {"type": "string", "description": "Field 'meta' of 'ByteStream'."}, + "mime_type": { + "type": "string", + "description": "Field 'mime_type' of 'ByteStream'.", + }, + }, + }, + "meta": {"type": "string", "description": "Field 'meta' of 'Document'."}, + "score": {"type": "number", "description": "Field 'score' of 'Document'."}, + "embedding": { + "type": "array", + "description": "Field 'embedding' of 'Document'.", + "items": {"type": "number"}, + }, + "sparse_embedding": { + "type": "object", + "description": "Field 'sparse_embedding' of 'Document'.", + "properties": { + "indices": { + "type": "array", + "description": "Field 'indices' of 'SparseEmbedding'.", + "items": {"type": "integer"}, + }, + "values": { + "type": "array", + "description": "Field 'values' of 'SparseEmbedding'.", + "items": {"type": "number"}, + }, + }, + }, + }, + }, + }, + "top_k": {"description": "The number of top documents to concatenate", "type": "integer"}, + }, + "required": ["documents"], + } + + # Test tool invocation + result = tool.invoke(documents=[{"content": "First document"}, {"content": "Second document"}]) + assert isinstance(result, dict) + assert "concatenated" in result + assert result["concatenated"] == "First document\nSecond document" + + def test_from_component_with_non_component(self): + class NotAComponent: + def foo(self, text: str): + return {"reply": f"Hello, {text}!"} + + not_a_component = NotAComponent() + + with pytest.raises(ValueError): + ComponentTool(component=not_a_component, name="invalid_tool", description="This should fail") + + +## Integration tests +class TestToolComponentInPipelineWithOpenAI: + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_component_tool_in_pipeline(self): + # Create component and convert it to tool + component = SimpleComponent() + tool = ComponentTool( + component=component, name="hello_tool", description="A tool that generates a greeting message for the user" + ) + + # Create pipeline with OpenAIChatGenerator and ToolInvoker + pipeline = Pipeline() + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) + + # Connect components + pipeline.connect("llm.replies", "tool_invoker.messages") + + message = ChatMessage.from_user(text="Vladimir") + + # Run pipeline + result = pipeline.run({"llm": {"messages": [message]}}) + + # Check results + tool_messages = result["tool_invoker"]["tool_messages"] + assert len(tool_messages) == 1 + + tool_message = tool_messages[0] + assert tool_message.is_from(ChatRole.TOOL) + assert 
"Vladimir" in tool_message.tool_call_result.result + assert not tool_message.tool_call_result.error + + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_user_greeter_in_pipeline(self): + component = UserGreeter() + tool = ComponentTool( + component=component, name="user_greeter", description="A tool that greets users with their name and age" + ) + + pipeline = Pipeline() + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) + pipeline.connect("llm.replies", "tool_invoker.messages") + + message = ChatMessage.from_user(text="I am Alice and I'm 30 years old") + + result = pipeline.run({"llm": {"messages": [message]}}) + tool_messages = result["tool_invoker"]["tool_messages"] + assert len(tool_messages) == 1 + + tool_message = tool_messages[0] + assert tool_message.is_from(ChatRole.TOOL) + assert tool_message.tool_call_result.result == str({"message": "User Alice is 30 years old"}) + assert not tool_message.tool_call_result.error + + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_list_processor_in_pipeline(self): + component = ListProcessor() + tool = ComponentTool( + component=component, name="list_processor", description="A tool that concatenates a list of strings" + ) + + pipeline = Pipeline() + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) + pipeline.connect("llm.replies", "tool_invoker.messages") + + message = ChatMessage.from_user(text="Can you join these words: hello, beautiful, world") + + result = pipeline.run({"llm": {"messages": [message]}}) + tool_messages = result["tool_invoker"]["tool_messages"] + assert len(tool_messages) == 1 + + tool_message = tool_messages[0] + assert tool_message.is_from(ChatRole.TOOL) + assert tool_message.tool_call_result.result == str({"concatenated": "hello beautiful world"}) + assert not tool_message.tool_call_result.error + + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_person_processor_in_pipeline(self): + component = PersonProcessor() + tool = ComponentTool( + component=component, + name="person_processor", + description="A tool that processes information about a person and their address", + ) + + pipeline = Pipeline() + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) + pipeline.connect("llm.replies", "tool_invoker.messages") + + message = ChatMessage.from_user(text="Diana lives at 123 Elm Street in Metropolis") + + result = pipeline.run({"llm": {"messages": [message]}}) + tool_messages = result["tool_invoker"]["tool_messages"] + assert len(tool_messages) == 1 + + tool_message = tool_messages[0] + assert tool_message.is_from(ChatRole.TOOL) + assert "Diana" in tool_message.tool_call_result.result and "Metropolis" in tool_message.tool_call_result.result + assert not tool_message.tool_call_result.error + + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_document_processor_in_pipeline(self): + component = DocumentProcessor() + tool = ComponentTool( + component=component, + name="document_processor", + description="A tool 
that concatenates the content of multiple documents", + ) + + pipeline = Pipeline() + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool], convert_result_to_json_string=True)) + pipeline.connect("llm.replies", "tool_invoker.messages") + + message = ChatMessage.from_user( + text="Concatenate these documents: First one says 'Hello world' and second one says 'Goodbye world' and third one says 'Hello again', but use top_k=2. Set only content field of the document only. Do not set id, meta, score, embedding, sparse_embedding, dataframe, blob fields." + ) + + result = pipeline.run({"llm": {"messages": [message]}}) + + tool_messages = result["tool_invoker"]["tool_messages"] + assert len(tool_messages) == 1 + + tool_message = tool_messages[0] + assert tool_message.is_from(ChatRole.TOOL) + result = json.loads(tool_message.tool_call_result.result) + assert "concatenated" in result + assert "Hello world" in result["concatenated"] + assert "Goodbye world" in result["concatenated"] + assert not tool_message.tool_call_result.error + + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_lost_in_middle_ranker_in_pipeline(self): + from haystack.components.rankers import LostInTheMiddleRanker + + component = LostInTheMiddleRanker() + tool = ComponentTool( + component=component, + name="lost_in_middle_ranker", + description="A tool that ranks documents using the Lost in the Middle algorithm and returns top k results", + ) + + pipeline = Pipeline() + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) + pipeline.connect("llm.replies", "tool_invoker.messages") + + message = ChatMessage.from_user( + text="I have three documents with content: 'First doc', 'Middle doc', and 'Last doc'. Rank them top_k=2. Set only content field of the document only. Do not set id, meta, score, embedding, sparse_embedding, dataframe, blob fields." 
+ ) + + result = pipeline.run({"llm": {"messages": [message]}}) + + tool_messages = result["tool_invoker"]["tool_messages"] + assert len(tool_messages) == 1 + tool_message = tool_messages[0] + assert tool_message.is_from(ChatRole.TOOL) + + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") + @pytest.mark.skipif(not os.environ.get("SERPERDEV_API_KEY"), reason="SERPERDEV_API_KEY not set") + @pytest.mark.integration + def test_serper_dev_web_search_in_pipeline(self): + component = SerperDevWebSearch(api_key=Secret.from_env_var("SERPERDEV_API_KEY"), top_k=3) + tool = ComponentTool( + component=component, name="web_search", description="Search the web for current information on any topic" + ) + + pipeline = Pipeline() + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) + pipeline.connect("llm.replies", "tool_invoker.messages") + + result = pipeline.run( + { + "llm": { + "messages": [ + ChatMessage.from_user(text="Use the web search tool to find information about Nikola Tesla") + ] + } + } + ) + + assert len(result["tool_invoker"]["tool_messages"]) == 1 + tool_message = result["tool_invoker"]["tool_messages"][0] + assert tool_message.is_from(ChatRole.TOOL) + assert "Nikola Tesla" in tool_message.tool_call_result.result + assert not tool_message.tool_call_result.error + + def test_serde_in_pipeline(self, monkeypatch): + monkeypatch.setenv("SERPERDEV_API_KEY", "test-key") + monkeypatch.setenv("OPENAI_API_KEY", "test-key") + + # Create the search component and tool + search = SerperDevWebSearch(top_k=3) + tool = ComponentTool(component=search, name="web_search", description="Search the web for current information") + + # Create and configure the pipeline + pipeline = Pipeline() + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool])) + pipeline.connect("tool_invoker.tool_messages", "llm.messages") + + # Serialize to dict and verify structure + pipeline_dict = pipeline.to_dict() + assert ( + pipeline_dict["components"]["tool_invoker"]["type"] == "haystack.components.tools.tool_invoker.ToolInvoker" + ) + assert len(pipeline_dict["components"]["tool_invoker"]["init_parameters"]["tools"]) == 1 + + tool_dict = pipeline_dict["components"]["tool_invoker"]["init_parameters"]["tools"][0] + assert tool_dict["type"] == "haystack.tools.component_tool.ComponentTool" + assert tool_dict["data"]["name"] == "web_search" + assert tool_dict["data"]["component"]["type"] == "haystack.components.websearch.serper_dev.SerperDevWebSearch" + assert tool_dict["data"]["component"]["init_parameters"]["top_k"] == 3 + assert tool_dict["data"]["component"]["init_parameters"]["api_key"]["type"] == "env_var" + + # Test round-trip serialization + pipeline_yaml = pipeline.dumps() + new_pipeline = Pipeline.loads(pipeline_yaml) + assert new_pipeline == pipeline + + def test_component_tool_serde(self): + component = SimpleComponent() + + tool = ComponentTool(component=component, name="simple_tool", description="A simple tool") + + # Test serialization + tool_dict = tool.to_dict() + assert tool_dict["type"] == "haystack.tools.component_tool.ComponentTool" + assert tool_dict["data"]["name"] == "simple_tool" + assert tool_dict["data"]["description"] == "A simple tool" + assert "component" in tool_dict["data"] + + # Test deserialization + new_tool = ComponentTool.from_dict(tool_dict) + assert 
new_tool.name == tool.name + assert new_tool.description == tool.description + assert new_tool.parameters == tool.parameters + assert isinstance(new_tool._component, SimpleComponent) + + def test_pipeline_component_fails(self): + component = SimpleComponent() + + # Create a pipeline and add the component to it + pipeline = Pipeline() + pipeline.add_component("simple", component) + + # Try to create a tool from the component and it should fail because the component has been added to a pipeline and + # thus can't be used as tool + with pytest.raises(ValueError, match="Component has been added to a pipeline"): + ComponentTool(component=component) From ec8666545df69db7709fc9bc9e3d1e0aa122fca0 Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Mon, 13 Jan 2025 11:46:34 +0100 Subject: [PATCH 09/41] docs: adding RecursiveSplitter to pydoc --- docs/pydoc/config/preprocessors_api.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/pydoc/config/preprocessors_api.yml b/docs/pydoc/config/preprocessors_api.yml index c27e01be34..fb89ddda7b 100644 --- a/docs/pydoc/config/preprocessors_api.yml +++ b/docs/pydoc/config/preprocessors_api.yml @@ -1,7 +1,7 @@ loaders: - type: haystack_pydoc_tools.loaders.CustomPythonLoader search_path: [../../../haystack/components/preprocessors] - modules: ["document_cleaner", "document_splitter", "text_cleaner", "nltk_document_splitter"] + modules: ["document_cleaner", "document_splitter", "nltk_document_splitter", "recursive_splitter", "text_cleaner"] ignore_when_discovered: ["__init__"] processors: - type: filter From ed40d9f001ed78b070bacd700509b495aa09a8de Mon Sep 17 00:00:00 2001 From: Haystack Bot <73523382+HaystackBot@users.noreply.github.com> Date: Mon, 13 Jan 2025 15:23:33 +0100 Subject: [PATCH 10/41] Update unstable version to 2.10.0-rc0 (#8713) Co-authored-by: github-actions[bot] --- VERSION.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION.txt b/VERSION.txt index da852fef72..d95663102a 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -2.9.0-rc0 +2.10.0-rc0 From 34bd31ef3265b58703aff994120deff24c24a2bd Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Tue, 14 Jan 2025 12:27:31 +0100 Subject: [PATCH 11/41] docs: fixing RecursiveSplitter pydoc markdown rendering --- .../preprocessors/recursive_splitter.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/haystack/components/preprocessors/recursive_splitter.py b/haystack/components/preprocessors/recursive_splitter.py index 3286a80d72..343bab75e9 100644 --- a/haystack/components/preprocessors/recursive_splitter.py +++ b/haystack/components/preprocessors/recursive_splitter.py @@ -34,20 +34,20 @@ class RecursiveDocumentSplitter: from haystack import Document from haystack.components.preprocessors import RecursiveDocumentSplitter - chunker = RecursiveDocumentSplitter(split_length=260, split_overlap=0, separators=["\n\n", "\n", ".", " "]) - text = '''Artificial intelligence (AI) - Introduction + chunker = RecursiveDocumentSplitter(split_length=260, split_overlap=0, separators=["\\n\\n", "\\n", ".", " "]) + text = ('''Artificial intelligence (AI) - Introduction AI, in its broadest sense, is intelligence exhibited by machines, particularly computer systems. - AI technology is widely used throughout industry, government, and science. 
Some high-profile applications include advanced web search engines; recommendation systems; interacting via human speech; autonomous vehicles; generative and creative tools; and superhuman play and analysis in strategy games.''' + AI technology is widely used throughout industry, government, and science. Some high-profile applications include advanced web search engines; recommendation systems; interacting via human speech; autonomous vehicles; generative and creative tools; and superhuman play and analysis in strategy games.''') chunker.warm_up() doc = Document(content=text) doc_chunks = chunker.run([doc]) print(doc_chunks["documents"]) >[ - >Document(id=..., content: 'Artificial intelligence (AI) - Introduction\n\n', meta: {'original_id': '65167a9823dd883de577e828ca4fd529e6f7241f0ff616acfce454d808478951', 'split_id': 0, 'split_idx_start': 0, '_split_overlap': []}) - >Document(id=..., content: 'AI, in its broadest sense, is intelligence exhibited by machines, particularly computer systems.\n', meta: {'original_id': '65167a9823dd883de577e828ca4fd529e6f7241f0ff616acfce454d808478951', 'split_id': 1, 'split_idx_start': 45, '_split_overlap': []}) - >Document(id=..., content: 'AI technology is widely used throughout industry, government, and science.', meta: {'original_id': '65167a9823dd883de577e828ca4fd529e6f7241f0ff616acfce454d808478951', 'split_id': 2, 'split_idx_start': 142, '_split_overlap': []}) - >Document(id=..., content: ' Some high-profile applications include advanced web search engines; recommendation systems; interac...', meta: {'original_id': '65167a9823dd883de577e828ca4fd529e6f7241f0ff616acfce454d808478951', 'split_id': 3, 'split_idx_start': 216, '_split_overlap': []}) + >Document(id=..., content: 'Artificial intelligence (AI) - Introduction\\n\\n', meta: {'original_id': '...', 'split_id': 0, 'split_idx_start': 0, '_split_overlap': []}) + >Document(id=..., content: 'AI, in its broadest sense, is intelligence exhibited by machines, particularly computer systems.\\n', meta: {'original_id': '...', 'split_id': 1, 'split_idx_start': 45, '_split_overlap': []}) + >Document(id=..., content: 'AI technology is widely used throughout industry, government, and science.', meta: {'original_id': '...', 'split_id': 2, 'split_idx_start': 142, '_split_overlap': []}) + >Document(id=..., content: ' Some high-profile applications include advanced web search engines; recommendation systems; interac...', meta: {'original_id': '...', 'split_id': 3, 'split_idx_start': 216, '_split_overlap': []}) >] ``` """ # noqa: E501 @@ -72,7 +72,7 @@ def __init__( separators will be treated as regular expressions unless the separator is "sentence", in that case the text will be split into sentences using a custom sentence tokenizer based on NLTK. See: haystack.components.preprocessors.sentence_tokenizer.SentenceSplitter. - If no separators are provided, the default separators ["\n\n", "sentence", "\n", " "] are used. + If no separators are provided, the default separators ["\\n\\n", "sentence", "\\n", " "] are used. :param sentence_splitter_params: Optional parameters to pass to the sentence tokenizer. See: haystack.components.preprocessors.sentence_tokenizer.SentenceSplitter for more information. From 425ce9b98f9f277c8fc14d8a468b0962f968a5e8 Mon Sep 17 00:00:00 2001 From: "David S. 
Batista" Date: Tue, 14 Jan 2025 16:47:29 +0100 Subject: [PATCH 12/41] test: updating HuggingFaceAPIChatGenerator tests --- test/components/generators/chat/test_hugging_face_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/components/generators/chat/test_hugging_face_api.py b/test/components/generators/chat/test_hugging_face_api.py index e2158ad6e9..fa83b98db7 100644 --- a/test/components/generators/chat/test_hugging_face_api.py +++ b/test/components/generators/chat/test_hugging_face_api.py @@ -540,7 +540,7 @@ def test_live_run_with_tools(self, tools): assert "Paris" in tool_call.arguments["city"] assert message.meta["finish_reason"] == "stop" - new_messages = chat_messages + [message, ChatMessage.from_tool(tool_result="22° C", origin=tool_call)] + new_messages = chat_messages + [message, ChatMessage.from_tool(tool_result="22°", origin=tool_call)] # the model tends to make tool calls if provided with tools, so we don't pass them here results = generator.run(new_messages, generation_kwargs={"max_tokens": 50}) From 167ede1f4ca8996aa3fece336ba13fde85724e95 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Wed, 15 Jan 2025 09:51:52 +0100 Subject: [PATCH 13/41] remove deprecation warning from SentenceWindowRetriever (#8720) --- .../components/retrievers/sentence_window_retriever.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/haystack/components/retrievers/sentence_window_retriever.py b/haystack/components/retrievers/sentence_window_retriever.py index be1f9df100..370638e643 100644 --- a/haystack/components/retrievers/sentence_window_retriever.py +++ b/haystack/components/retrievers/sentence_window_retriever.py @@ -2,7 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -import warnings from typing import Any, Dict, List, Optional from haystack import Document, component, default_from_dict, default_to_dict @@ -93,13 +92,6 @@ def __init__(self, document_store: DocumentStore, window_size: int = 3): self.window_size = window_size self.document_store = document_store - warnings.warn( - "The output of `context_documents` will change in the next release. Instead of a " - "List[List[Document]], the output will be a List[Document], where the documents are ordered by " - "`split_idx_start`.", - DeprecationWarning, - ) - @staticmethod def merge_documents_text(documents: List[Document]) -> str: """ From 26b80778f52246214efe72b777d548f85d810198 Mon Sep 17 00:00:00 2001 From: "David S. 
Batista" Date: Wed, 15 Jan 2025 17:11:51 +0100 Subject: [PATCH 14/41] chore: removing NLTKDocumentSplitter (#8724) * removing NLTKDocumentSplitter * adding release notes * removing pydocs reference --- docs/pydoc/config/preprocessors_api.yml | 2 +- haystack/components/preprocessors/__init__.py | 3 +- .../preprocessors/nltk_document_splitter.py | 301 ------------- ...NLTKDocumentSplitter-b495d4e276698083.yaml | 4 + .../test_nltk_document_splitter.py | 413 ------------------ 5 files changed, 6 insertions(+), 717 deletions(-) delete mode 100644 haystack/components/preprocessors/nltk_document_splitter.py create mode 100644 releasenotes/notes/removing-NLTKDocumentSplitter-b495d4e276698083.yaml delete mode 100644 test/components/preprocessors/test_nltk_document_splitter.py diff --git a/docs/pydoc/config/preprocessors_api.yml b/docs/pydoc/config/preprocessors_api.yml index fb89ddda7b..d5a0df24c6 100644 --- a/docs/pydoc/config/preprocessors_api.yml +++ b/docs/pydoc/config/preprocessors_api.yml @@ -1,7 +1,7 @@ loaders: - type: haystack_pydoc_tools.loaders.CustomPythonLoader search_path: [../../../haystack/components/preprocessors] - modules: ["document_cleaner", "document_splitter", "nltk_document_splitter", "recursive_splitter", "text_cleaner"] + modules: ["document_cleaner", "document_splitter", "recursive_splitter", "text_cleaner"] ignore_when_discovered: ["__init__"] processors: - type: filter diff --git a/haystack/components/preprocessors/__init__.py b/haystack/components/preprocessors/__init__.py index 33e446e8a6..26d30c1520 100644 --- a/haystack/components/preprocessors/__init__.py +++ b/haystack/components/preprocessors/__init__.py @@ -4,8 +4,7 @@ from .document_cleaner import DocumentCleaner from .document_splitter import DocumentSplitter -from .nltk_document_splitter import NLTKDocumentSplitter from .recursive_splitter import RecursiveDocumentSplitter from .text_cleaner import TextCleaner -__all__ = ["DocumentSplitter", "DocumentCleaner", "RecursiveDocumentSplitter", "TextCleaner", "NLTKDocumentSplitter"] +__all__ = ["DocumentSplitter", "DocumentCleaner", "RecursiveDocumentSplitter", "TextCleaner"] diff --git a/haystack/components/preprocessors/nltk_document_splitter.py b/haystack/components/preprocessors/nltk_document_splitter.py deleted file mode 100644 index ab787d599d..0000000000 --- a/haystack/components/preprocessors/nltk_document_splitter.py +++ /dev/null @@ -1,301 +0,0 @@ -# SPDX-FileCopyrightText: 2022-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 - -import warnings -from copy import deepcopy -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple - -from haystack import Document, component, logging -from haystack.components.preprocessors.document_splitter import DocumentSplitter -from haystack.components.preprocessors.sentence_tokenizer import Language, SentenceSplitter, nltk_imports -from haystack.core.serialization import default_to_dict -from haystack.utils import serialize_callable - -logger = logging.getLogger(__name__) - - -@component -class NLTKDocumentSplitter(DocumentSplitter): - def __init__( # pylint: disable=too-many-positional-arguments - self, - split_by: Literal["word", "sentence", "page", "passage", "function"] = "word", - split_length: int = 200, - split_overlap: int = 0, - split_threshold: int = 0, - respect_sentence_boundary: bool = False, - language: Language = "en", - use_split_rules: bool = True, - extend_abbreviations: bool = True, - splitting_function: Optional[Callable[[str], List[str]]] = None, - ): - """ - Splits your 
documents using NLTK to respect sentence boundaries. - - Initialize the NLTKDocumentSplitter. - - :param split_by: Select the unit for splitting your documents. Choose from `word` for splitting by spaces (" "), - `sentence` for splitting by NLTK sentence tokenizer, `page` for splitting by the form feed ("\\f") or - `passage` for splitting by double line breaks ("\\n\\n"). - :param split_length: The maximum number of units in each split. - :param split_overlap: The number of overlapping units for each split. - :param split_threshold: The minimum number of units per split. If a split has fewer units - than the threshold, it's attached to the previous split. - :param respect_sentence_boundary: Choose whether to respect sentence boundaries when splitting by "word". - If True, uses NLTK to detect sentence boundaries, ensuring splits occur only between sentences. - :param language: Choose the language for the NLTK tokenizer. The default is English ("en"). - :param use_split_rules: Choose whether to use additional split rules when splitting by `sentence`. - :param extend_abbreviations: Choose whether to extend NLTK's PunktTokenizer abbreviations with a list - of curated abbreviations, if available. - This is currently supported for English ("en") and German ("de"). - :param splitting_function: Necessary when `split_by` is set to "function". - This is a function which must accept a single `str` as input and return a `list` of `str` as output, - representing the chunks after splitting. - """ - - warnings.warn( - "The NLTKDocumentSplitter is deprecated and will be removed in the next release. " - "See DocumentSplitter which now supports the functionalities of the NLTKDocumentSplitter, i.e.: " - "using NLTK to detect sentence boundaries.", - DeprecationWarning, - ) - - super(NLTKDocumentSplitter, self).__init__( - split_by=split_by, - split_length=split_length, - split_overlap=split_overlap, - split_threshold=split_threshold, - splitting_function=splitting_function, - ) - nltk_imports.check() - if respect_sentence_boundary and split_by != "word": - logger.warning( - "The 'respect_sentence_boundary' option is only supported for `split_by='word'`. " - "The option `respect_sentence_boundary` will be set to `False`." - ) - respect_sentence_boundary = False - self.respect_sentence_boundary = respect_sentence_boundary - self.use_split_rules = use_split_rules - self.extend_abbreviations = extend_abbreviations - self.sentence_splitter = None - self.language = language - - def warm_up(self): - """ - Warm up the NLTKDocumentSplitter by loading the sentence tokenizer. - """ - if self.sentence_splitter is None: - self.sentence_splitter = SentenceSplitter( - language=self.language, - use_split_rules=self.use_split_rules, - extend_abbreviations=self.extend_abbreviations, - keep_white_spaces=True, - ) - - def _split_into_units( - self, text: str, split_by: Literal["function", "page", "passage", "period", "sentence", "word", "line"] - ) -> List[str]: - """ - Splits the text into units based on the specified split_by parameter. - - :param text: The text to split. - :param split_by: The unit to split the text by. Choose from "word", "sentence", "passage", or "page". - :returns: A list of units. 
- """ - - if split_by == "page": - self.split_at = "\f" - units = text.split(self.split_at) - elif split_by == "passage": - self.split_at = "\n\n" - units = text.split(self.split_at) - elif split_by == "sentence": - # whitespace is preserved while splitting text into sentences when using keep_white_spaces=True - # so split_at is set to an empty string - self.split_at = "" - assert self.sentence_splitter is not None - result = self.sentence_splitter.split_sentences(text) - units = [sentence["sentence"] for sentence in result] - elif split_by == "word": - self.split_at = " " - units = text.split(self.split_at) - elif split_by == "function" and self.splitting_function is not None: - return self.splitting_function(text) - else: - raise NotImplementedError( - "DocumentSplitter only supports 'function', 'page', 'passage', 'sentence' or 'word' split_by options." - ) - - # Add the delimiter back to all units except the last one - for i in range(len(units) - 1): - units[i] += self.split_at - return units - - @component.output_types(documents=List[Document]) - def run(self, documents: List[Document]) -> Dict[str, List[Document]]: - """ - Split documents into smaller parts. - - Splits documents by the unit expressed in `split_by`, with a length of `split_length` - and an overlap of `split_overlap`. - - :param documents: The documents to split. - - :returns: A dictionary with the following key: - - `documents`: List of documents with the split texts. Each document includes: - - A metadata field source_id to track the original document. - - A metadata field page_number to track the original page number. - - All other metadata copied from the original document. - - :raises TypeError: if the input is not a list of Documents. - :raises ValueError: if the content of a document is None. - """ - if self.sentence_splitter is None: - raise RuntimeError( - "The component NLTKDocumentSplitter wasn't warmed up. Run 'warm_up()' before calling 'run()'." - ) - - if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): - raise TypeError("DocumentSplitter expects a List of Documents as input.") - - split_docs = [] - for doc in documents: - if doc.content is None: - raise ValueError( - f"DocumentSplitter only works with text documents but content for document ID {doc.id} is None." - ) - if doc.content == "": - logger.warning("Document ID {doc_id} has an empty content. Skipping this document.", doc_id=doc.id) - continue - - if self.respect_sentence_boundary: - units = self._split_into_units(doc.content, "sentence") - text_splits, splits_pages, splits_start_idxs = self._concatenate_sentences_based_on_word_amount( - sentences=units, split_length=self.split_length, split_overlap=self.split_overlap - ) - else: - units = self._split_into_units(doc.content, self.split_by) - text_splits, splits_pages, splits_start_idxs = self._concatenate_units( - elements=units, - split_length=self.split_length, - split_overlap=self.split_overlap, - split_threshold=self.split_threshold, - ) - metadata = deepcopy(doc.meta) - metadata["source_id"] = doc.id - split_docs += self._create_docs_from_splits( - text_splits=text_splits, splits_pages=splits_pages, splits_start_idxs=splits_start_idxs, meta=metadata - ) - return {"documents": split_docs} - - def to_dict(self) -> Dict[str, Any]: - """ - Serializes the component to a dictionary. 
- """ - serialized = default_to_dict( - self, - split_by=self.split_by, - split_length=self.split_length, - split_overlap=self.split_overlap, - split_threshold=self.split_threshold, - respect_sentence_boundary=self.respect_sentence_boundary, - language=self.language, - use_split_rules=self.use_split_rules, - extend_abbreviations=self.extend_abbreviations, - ) - if self.splitting_function: - serialized["init_parameters"]["splitting_function"] = serialize_callable(self.splitting_function) - return serialized - - @staticmethod - def _number_of_sentences_to_keep(sentences: List[str], split_length: int, split_overlap: int) -> int: - """ - Returns the number of sentences to keep in the next chunk based on the `split_overlap` and `split_length`. - - :param sentences: The list of sentences to split. - :param split_length: The maximum number of words in each split. - :param split_overlap: The number of overlapping words in each split. - :returns: The number of sentences to keep in the next chunk. - """ - # If the split_overlap is 0, we don't need to keep any sentences - if split_overlap == 0: - return 0 - - num_sentences_to_keep = 0 - num_words = 0 - # Next overlapping Document should not start exactly the same as the previous one, so we skip the first sentence - for sent in reversed(sentences[1:]): - num_words += len(sent.split()) - # If the number of words is larger than the split_length then don't add any more sentences - if num_words > split_length: - break - num_sentences_to_keep += 1 - if num_words > split_overlap: - break - return num_sentences_to_keep - - @staticmethod - def _concatenate_sentences_based_on_word_amount( - sentences: List[str], split_length: int, split_overlap: int - ) -> Tuple[List[str], List[int], List[int]]: - """ - Groups the sentences into chunks of `split_length` words while respecting sentence boundaries. - - :param sentences: The list of sentences to split. - :param split_length: The maximum number of words in each split. - :param split_overlap: The number of overlapping words in each split. - :returns: A tuple containing the concatenated sentences, the start page numbers, and the start indices. 
- """ - # Chunk information - chunk_word_count = 0 - chunk_starting_page_number = 1 - chunk_start_idx = 0 - current_chunk: List[str] = [] - # Output lists - split_start_page_numbers = [] - list_of_splits: List[List[str]] = [] - split_start_indices = [] - - for sentence_idx, sentence in enumerate(sentences): - current_chunk.append(sentence) - chunk_word_count += len(sentence.split()) - next_sentence_word_count = ( - len(sentences[sentence_idx + 1].split()) if sentence_idx < len(sentences) - 1 else 0 - ) - - # Number of words in the current chunk plus the next sentence is larger than the split_length - # or we reached the last sentence - if (chunk_word_count + next_sentence_word_count) > split_length or sentence_idx == len(sentences) - 1: - # Save current chunk and start a new one - list_of_splits.append(current_chunk) - split_start_page_numbers.append(chunk_starting_page_number) - split_start_indices.append(chunk_start_idx) - - # Get the number of sentences that overlap with the next chunk - num_sentences_to_keep = NLTKDocumentSplitter._number_of_sentences_to_keep( - sentences=current_chunk, split_length=split_length, split_overlap=split_overlap - ) - # Set up information for the new chunk - if num_sentences_to_keep > 0: - # Processed sentences are the ones that are not overlapping with the next chunk - processed_sentences = current_chunk[:-num_sentences_to_keep] - chunk_starting_page_number += sum(sent.count("\f") for sent in processed_sentences) - chunk_start_idx += len("".join(processed_sentences)) - # Next chunk starts with the sentences that were overlapping with the previous chunk - current_chunk = current_chunk[-num_sentences_to_keep:] - chunk_word_count = sum(len(s.split()) for s in current_chunk) - else: - # Here processed_sentences is the same as current_chunk since there is no overlap - chunk_starting_page_number += sum(sent.count("\f") for sent in current_chunk) - chunk_start_idx += len("".join(current_chunk)) - current_chunk = [] - chunk_word_count = 0 - - # Concatenate the sentences together within each split - text_splits = [] - for split in list_of_splits: - text = "".join(split) - if len(text) > 0: - text_splits.append(text) - - return text_splits, split_start_page_numbers, split_start_indices diff --git a/releasenotes/notes/removing-NLTKDocumentSplitter-b495d4e276698083.yaml b/releasenotes/notes/removing-NLTKDocumentSplitter-b495d4e276698083.yaml new file mode 100644 index 0000000000..670ec75ae4 --- /dev/null +++ b/releasenotes/notes/removing-NLTKDocumentSplitter-b495d4e276698083.yaml @@ -0,0 +1,4 @@ +--- +upgrade: + - | + Removed the deprecated `NLTKDocumentSplitter`, it's functionalities are now supported by the `DocumentSplitter`. 
diff --git a/test/components/preprocessors/test_nltk_document_splitter.py b/test/components/preprocessors/test_nltk_document_splitter.py deleted file mode 100644 index fe80848c74..0000000000 --- a/test/components/preprocessors/test_nltk_document_splitter.py +++ /dev/null @@ -1,413 +0,0 @@ -from typing import List - -import pytest -from haystack import Document -from pytest import LogCaptureFixture - -from haystack.components.preprocessors.nltk_document_splitter import NLTKDocumentSplitter, SentenceSplitter -from haystack.utils import deserialize_callable - - -def test_init_warning_message(caplog: LogCaptureFixture) -> None: - _ = NLTKDocumentSplitter(split_by="page", respect_sentence_boundary=True) - assert "The 'respect_sentence_boundary' option is only supported for" in caplog.text - - -def custom_split(text): - return text.split(".") - - -class TestNLTKDocumentSplitterSplitIntoUnits: - def test_document_splitter_split_into_units_word(self) -> None: - document_splitter = NLTKDocumentSplitter( - split_by="word", split_length=3, split_overlap=0, split_threshold=0, language="en" - ) - - text = "Moonlight shimmered softly, wolves howled nearby, night enveloped everything." - units = document_splitter._split_into_units(text=text, split_by="word") - - assert units == [ - "Moonlight ", - "shimmered ", - "softly, ", - "wolves ", - "howled ", - "nearby, ", - "night ", - "enveloped ", - "everything.", - ] - - def test_document_splitter_split_into_units_sentence(self) -> None: - document_splitter = NLTKDocumentSplitter( - split_by="sentence", split_length=2, split_overlap=0, split_threshold=0, language="en" - ) - document_splitter.warm_up() - - text = "Moonlight shimmered softly, wolves howled nearby, night enveloped everything. It was a dark night." - units = document_splitter._split_into_units(text=text, split_by="sentence") - - assert units == [ - "Moonlight shimmered softly, wolves howled nearby, night enveloped everything. ", - "It was a dark night.", - ] - - def test_document_splitter_split_into_units_passage(self) -> None: - document_splitter = NLTKDocumentSplitter( - split_by="passage", split_length=2, split_overlap=0, split_threshold=0, language="en" - ) - - text = "Moonlight shimmered softly, wolves howled nearby, night enveloped everything.\n\nIt was a dark night." - units = document_splitter._split_into_units(text=text, split_by="passage") - - assert units == [ - "Moonlight shimmered softly, wolves howled nearby, night enveloped everything.\n\n", - "It was a dark night.", - ] - - def test_document_splitter_split_into_units_page(self) -> None: - document_splitter = NLTKDocumentSplitter( - split_by="page", split_length=2, split_overlap=0, split_threshold=0, language="en" - ) - - text = "Moonlight shimmered softly, wolves howled nearby, night enveloped everything.\fIt was a dark night." - units = document_splitter._split_into_units(text=text, split_by="page") - - assert units == [ - "Moonlight shimmered softly, wolves howled nearby, night enveloped everything.\f", - "It was a dark night.", - ] - - def test_document_splitter_split_into_units_raise_error(self) -> None: - document_splitter = NLTKDocumentSplitter( - split_by="word", split_length=3, split_overlap=0, split_threshold=0, language="en" - ) - - text = "Moonlight shimmered softly, wolves howled nearby, night enveloped everything." 
- - with pytest.raises(NotImplementedError): - document_splitter._split_into_units(text=text, split_by="invalid") # type: ignore - - -class TestNLTKDocumentSplitterNumberOfSentencesToKeep: - @pytest.mark.parametrize( - "sentences, expected_num_sentences", - [ - (["The sun set.", "Moonlight shimmered softly, wolves howled nearby, night enveloped everything."], 0), - (["The sun set.", "It was a dark night ..."], 0), - (["The sun set.", " The moon was full."], 1), - (["The sun.", " The moon."], 1), # Ignores the first sentence - (["Sun", "Moon"], 1), # Ignores the first sentence even if its inclusion would be < split_overlap - ], - ) - def test_number_of_sentences_to_keep(self, sentences: List[str], expected_num_sentences: int) -> None: - num_sentences = NLTKDocumentSplitter._number_of_sentences_to_keep( - sentences=sentences, split_length=5, split_overlap=2 - ) - assert num_sentences == expected_num_sentences - - def test_number_of_sentences_to_keep_split_overlap_zero(self) -> None: - sentences = [ - "Moonlight shimmered softly, wolves howled nearby, night enveloped everything.", - " It was a dark night ...", - " The moon was full.", - ] - num_sentences = NLTKDocumentSplitter._number_of_sentences_to_keep( - sentences=sentences, split_length=5, split_overlap=0 - ) - assert num_sentences == 0 - - -class TestNLTKDocumentSplitterRun: - def test_run_type_error(self) -> None: - document_splitter = NLTKDocumentSplitter() - with pytest.raises(TypeError): - document_splitter.warm_up() - document_splitter.run(documents=Document(content="Moonlight shimmered softly.")) # type: ignore - - def test_run_value_error(self) -> None: - document_splitter = NLTKDocumentSplitter() - with pytest.raises(ValueError): - document_splitter.warm_up() - document_splitter.run(documents=[Document(content=None)]) - - def test_run_split_by_sentence_1(self) -> None: - document_splitter = NLTKDocumentSplitter( - split_by="sentence", - split_length=2, - split_overlap=0, - split_threshold=0, - language="en", - use_split_rules=True, - extend_abbreviations=True, - ) - document_splitter.warm_up() - - text = ( - "Moonlight shimmered softly, wolves howled nearby, night enveloped everything. It was a dark night ... " - "The moon was full." - ) - documents = document_splitter.run(documents=[Document(content=text)])["documents"] - - assert len(documents) == 2 - assert ( - documents[0].content == "Moonlight shimmered softly, wolves howled nearby, night enveloped " - "everything. It was a dark night ... " - ) - assert documents[1].content == "The moon was full." - - def test_run_split_by_sentence_2(self) -> None: - document_splitter = NLTKDocumentSplitter( - split_by="sentence", - split_length=1, - split_overlap=0, - split_threshold=0, - language="en", - use_split_rules=False, - extend_abbreviations=True, - ) - - text = ( - "This is a test sentence with many many words that exceeds the split length and should not be repeated. " - "This is another test sentence. (This is a third test sentence.) " - "This is the last test sentence." - ) - document_splitter.warm_up() - documents = document_splitter.run(documents=[Document(content=text)])["documents"] - - assert len(documents) == 4 - assert ( - documents[0].content - == "This is a test sentence with many many words that exceeds the split length and should not be repeated. 
" - ) - assert documents[0].meta["page_number"] == 1 - assert documents[0].meta["split_id"] == 0 - assert documents[0].meta["split_idx_start"] == text.index(documents[0].content) - assert documents[1].content == "This is another test sentence. " - assert documents[1].meta["page_number"] == 1 - assert documents[1].meta["split_id"] == 1 - assert documents[1].meta["split_idx_start"] == text.index(documents[1].content) - assert documents[2].content == "(This is a third test sentence.) " - assert documents[2].meta["page_number"] == 1 - assert documents[2].meta["split_id"] == 2 - assert documents[2].meta["split_idx_start"] == text.index(documents[2].content) - assert documents[3].content == "This is the last test sentence." - assert documents[3].meta["page_number"] == 1 - assert documents[3].meta["split_id"] == 3 - assert documents[3].meta["split_idx_start"] == text.index(documents[3].content) - - def test_run_split_by_sentence_3(self) -> None: - document_splitter = NLTKDocumentSplitter( - split_by="sentence", - split_length=1, - split_overlap=0, - split_threshold=0, - language="en", - use_split_rules=True, - extend_abbreviations=True, - ) - document_splitter.warm_up() - - text = "Sentence on page 1.\fSentence on page 2. \fSentence on page 3. \f\f Sentence on page 5." - documents = document_splitter.run(documents=[Document(content=text)])["documents"] - - assert len(documents) == 4 - assert documents[0].content == "Sentence on page 1.\f" - assert documents[0].meta["page_number"] == 1 - assert documents[0].meta["split_id"] == 0 - assert documents[0].meta["split_idx_start"] == text.index(documents[0].content) - assert documents[1].content == "Sentence on page 2. \f" - assert documents[1].meta["page_number"] == 2 - assert documents[1].meta["split_id"] == 1 - assert documents[1].meta["split_idx_start"] == text.index(documents[1].content) - assert documents[2].content == "Sentence on page 3. \f\f " - assert documents[2].meta["page_number"] == 3 - assert documents[2].meta["split_id"] == 2 - assert documents[2].meta["split_idx_start"] == text.index(documents[2].content) - assert documents[3].content == "Sentence on page 5." - assert documents[3].meta["page_number"] == 5 - assert documents[3].meta["split_id"] == 3 - assert documents[3].meta["split_idx_start"] == text.index(documents[3].content) - - def test_run_split_by_sentence_4(self) -> None: - document_splitter = NLTKDocumentSplitter( - split_by="sentence", - split_length=2, - split_overlap=1, - split_threshold=0, - language="en", - use_split_rules=True, - extend_abbreviations=True, - ) - document_splitter.warm_up() - - text = "Sentence on page 1.\fSentence on page 2. \fSentence on page 3. \f\f Sentence on page 5." - documents = document_splitter.run(documents=[Document(content=text)])["documents"] - - assert len(documents) == 3 - assert documents[0].content == "Sentence on page 1.\fSentence on page 2. \f" - assert documents[0].meta["page_number"] == 1 - assert documents[0].meta["split_id"] == 0 - assert documents[0].meta["split_idx_start"] == text.index(documents[0].content) - assert documents[1].content == "Sentence on page 2. \fSentence on page 3. \f\f " - assert documents[1].meta["page_number"] == 2 - assert documents[1].meta["split_id"] == 1 - assert documents[1].meta["split_idx_start"] == text.index(documents[1].content) - assert documents[2].content == "Sentence on page 3. \f\f Sentence on page 5." 
- assert documents[2].meta["page_number"] == 3 - assert documents[2].meta["split_id"] == 2 - assert documents[2].meta["split_idx_start"] == text.index(documents[2].content) - - -class TestNLTKDocumentSplitterRespectSentenceBoundary: - def test_run_split_by_word_respect_sentence_boundary(self) -> None: - document_splitter = NLTKDocumentSplitter( - split_by="word", - split_length=3, - split_overlap=0, - split_threshold=0, - language="en", - respect_sentence_boundary=True, - ) - document_splitter.warm_up() - - text = ( - "Moonlight shimmered softly, wolves howled nearby, night enveloped everything. It was a dark night.\f" - "The moon was full." - ) - documents = document_splitter.run(documents=[Document(content=text)])["documents"] - - assert len(documents) == 3 - assert documents[0].content == "Moonlight shimmered softly, wolves howled nearby, night enveloped everything. " - assert documents[0].meta["page_number"] == 1 - assert documents[0].meta["split_id"] == 0 - assert documents[0].meta["split_idx_start"] == text.index(documents[0].content) - assert documents[1].content == "It was a dark night.\f" - assert documents[1].meta["page_number"] == 1 - assert documents[1].meta["split_id"] == 1 - assert documents[1].meta["split_idx_start"] == text.index(documents[1].content) - assert documents[2].content == "The moon was full." - assert documents[2].meta["page_number"] == 2 - assert documents[2].meta["split_id"] == 2 - assert documents[2].meta["split_idx_start"] == text.index(documents[2].content) - - def test_run_split_by_word_respect_sentence_boundary_no_repeats(self) -> None: - document_splitter = NLTKDocumentSplitter( - split_by="word", - split_length=13, - split_overlap=3, - split_threshold=0, - language="en", - respect_sentence_boundary=True, - use_split_rules=False, - extend_abbreviations=False, - ) - document_splitter.warm_up() - text = ( - "This is a test sentence with many many words that exceeds the split length and should not be repeated. " - "This is another test sentence. (This is a third test sentence.) " - "This is the last test sentence." - ) - documents = document_splitter.run([Document(content=text)])["documents"] - assert len(documents) == 3 - assert ( - documents[0].content - == "This is a test sentence with many many words that exceeds the split length and should not be repeated. " - ) - assert "This is a test sentence with many many words" not in documents[1].content - assert "This is a test sentence with many many words" not in documents[2].content - - def test_run_split_by_word_respect_sentence_boundary_with_split_overlap_and_page_breaks(self) -> None: - document_splitter = NLTKDocumentSplitter( - split_by="word", - split_length=8, - split_overlap=1, - split_threshold=0, - language="en", - use_split_rules=True, - extend_abbreviations=True, - respect_sentence_boundary=True, - ) - document_splitter.warm_up() - - text = ( - "Sentence on page 1. Another on page 1.\fSentence on page 2. Another on page 2.\f" - "Sentence on page 3. Another on page 3.\f\f Sentence on page 5." - ) - documents = document_splitter.run(documents=[Document(content=text)])["documents"] - - assert len(documents) == 6 - assert documents[0].content == "Sentence on page 1. Another on page 1.\f" - assert documents[0].meta["page_number"] == 1 - assert documents[0].meta["split_id"] == 0 - assert documents[0].meta["split_idx_start"] == text.index(documents[0].content) - assert documents[1].content == "Another on page 1.\fSentence on page 2. 
" - assert documents[1].meta["page_number"] == 1 - assert documents[1].meta["split_id"] == 1 - assert documents[1].meta["split_idx_start"] == text.index(documents[1].content) - assert documents[2].content == "Sentence on page 2. Another on page 2.\f" - assert documents[2].meta["page_number"] == 2 - assert documents[2].meta["split_id"] == 2 - assert documents[2].meta["split_idx_start"] == text.index(documents[2].content) - assert documents[3].content == "Another on page 2.\fSentence on page 3. " - assert documents[3].meta["page_number"] == 2 - assert documents[3].meta["split_id"] == 3 - assert documents[3].meta["split_idx_start"] == text.index(documents[3].content) - assert documents[4].content == "Sentence on page 3. Another on page 3.\f\f " - assert documents[4].meta["page_number"] == 3 - assert documents[4].meta["split_id"] == 4 - assert documents[4].meta["split_idx_start"] == text.index(documents[4].content) - assert documents[5].content == "Another on page 3.\f\f Sentence on page 5." - assert documents[5].meta["page_number"] == 3 - assert documents[5].meta["split_id"] == 5 - assert documents[5].meta["split_idx_start"] == text.index(documents[5].content) - - def test_to_dict(self): - splitter = NLTKDocumentSplitter(split_by="word", split_length=10, split_overlap=2, split_threshold=5) - serialized = splitter.to_dict() - - assert serialized["type"] == "haystack.components.preprocessors.nltk_document_splitter.NLTKDocumentSplitter" - assert serialized["init_parameters"]["split_by"] == "word" - assert serialized["init_parameters"]["split_length"] == 10 - assert serialized["init_parameters"]["split_overlap"] == 2 - assert serialized["init_parameters"]["split_threshold"] == 5 - assert serialized["init_parameters"]["language"] == "en" - assert serialized["init_parameters"]["use_split_rules"] is True - assert serialized["init_parameters"]["extend_abbreviations"] is True - assert "splitting_function" not in serialized["init_parameters"] - - def test_to_dict_with_splitting_function(self): - splitter = NLTKDocumentSplitter(split_by="function", splitting_function=custom_split) - serialized = splitter.to_dict() - - assert serialized["type"] == "haystack.components.preprocessors.nltk_document_splitter.NLTKDocumentSplitter" - assert serialized["init_parameters"]["split_by"] == "function" - assert "splitting_function" in serialized["init_parameters"] - assert callable(deserialize_callable(serialized["init_parameters"]["splitting_function"])) - - -class TestSentenceSplitter: - def test_apply_split_rules_second_while_loop(self) -> None: - text = "This is a test. (With a parenthetical statement.) And another sentence." - spans = [(0, 15), (16, 50), (51, 74)] - result = SentenceSplitter._apply_split_rules(text, spans) - assert len(result) == 2 - assert result == [(0, 50), (51, 74)] - - def test_apply_split_rules_no_join(self) -> None: - text = "This is a test. This is another test. And a third test." - spans = [(0, 15), (16, 36), (37, 54)] - result = SentenceSplitter._apply_split_rules(text, spans) - assert len(result) == 3 - assert result == [(0, 15), (16, 36), (37, 54)] - - @pytest.mark.parametrize( - "text,span,next_span,quote_spans,expected", - [ - # triggers sentence boundary is inside a quote - ('He said, "Hello World." 
Then left.', (0, 15), (16, 23), [(9, 23)], True) - ], - ) - def test_needs_join_cases(self, text, span, next_span, quote_spans, expected): - result = SentenceSplitter._needs_join(text, span, next_span, quote_spans) - assert result == expected, f"Expected {expected} for input: {text}, {span}, {next_span}, {quote_spans}" From 62ac27c947fb9f46cac6e87999eedc3d4472e5e7 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Wed, 15 Jan 2025 18:55:22 +0100 Subject: [PATCH 15/41] chore: remove deprecated `function` `ChatRole` and `from_function` class method in `ChatMessage` (#8725) * rm deprecated function role and from_function class method in chatmessage * release note --- haystack/dataclasses/chat_message.py | 29 ------------------- ...m-function-chat-role-ab401a1fb19713a7.yaml | 5 ++++ test/dataclasses/test_chat_message.py | 22 +++++--------- 3 files changed, 12 insertions(+), 44 deletions(-) create mode 100644 releasenotes/notes/rm-function-chat-role-ab401a1fb19713a7.yaml diff --git a/haystack/dataclasses/chat_message.py b/haystack/dataclasses/chat_message.py index a0016ac222..925259359f 100644 --- a/haystack/dataclasses/chat_message.py +++ b/haystack/dataclasses/chat_message.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 import json -import warnings from dataclasses import asdict, dataclass, field from enum import Enum from typing import Any, Dict, List, Optional, Sequence, Union @@ -28,9 +27,6 @@ class ChatRole(str, Enum): #: The tool role. A message from a tool contains the result of a Tool invocation. TOOL = "tool" - #: The function role. Deprecated in favor of `TOOL`. - FUNCTION = "function" - @staticmethod def from_str(string: str) -> "ChatRole": """ @@ -128,11 +124,6 @@ def __new__(cls, *args, **kwargs): return super(ChatMessage, cls).__new__(cls) - def __post_init__(self): - if self._role == ChatRole.FUNCTION: - msg = "The `FUNCTION` role has been deprecated in favor of `TOOL` and will be removed in 2.10.0. " - warnings.warn(msg, DeprecationWarning) - def __getattribute__(self, name): """ This method is reimplemented to make the `content` attribute removal more visible. @@ -299,26 +290,6 @@ def from_tool( _meta=meta or {}, ) - @classmethod - def from_function(cls, content: str, name: str) -> "ChatMessage": - """ - Create a message from a function call. Deprecated in favor of `from_tool`. - - :param content: The text content of the message. - :param name: The name of the function being called. - :returns: A new ChatMessage instance. - """ - msg = ( - "The `from_function` method is deprecated and will be removed in version 2.10.0. " - "Its behavior has changed: it now attempts to convert legacy function messages to tool messages. " - "This conversion is not guaranteed to succeed in all scenarios. " - "Please migrate to `ChatMessage.from_tool` and carefully verify the results if you " - "continue to use this method." - ) - warnings.warn(msg) - - return cls.from_tool(content, ToolCall(id=None, tool_name=name, arguments={}), error=False) - def to_dict(self) -> Dict[str, Any]: """ Converts ChatMessage into a dictionary. diff --git a/releasenotes/notes/rm-function-chat-role-ab401a1fb19713a7.yaml b/releasenotes/notes/rm-function-chat-role-ab401a1fb19713a7.yaml new file mode 100644 index 0000000000..3636728653 --- /dev/null +++ b/releasenotes/notes/rm-function-chat-role-ab401a1fb19713a7.yaml @@ -0,0 +1,5 @@ +--- +upgrade: + - | + The deprecated `FUNCTION` role has been removed from the `ChatRole` enum. Use `TOOL` instead. 
+ The deprecated class method `ChatMessage.from_function` has been removed. Use `ChatMessage.from_tool` instead. diff --git a/test/dataclasses/test_chat_message.py b/test/dataclasses/test_chat_message.py index 23a214ca29..00f8e56066 100644 --- a/test/dataclasses/test_chat_message.py +++ b/test/dataclasses/test_chat_message.py @@ -146,17 +146,14 @@ def test_mixed_content(): assert message.tool_call == content[1] -def test_from_function(): - # check warning is raised - with pytest.warns(): - message = ChatMessage.from_function("Result of function invocation", "my_function") +def test_function_role_removed(): + with pytest.raises(ValueError): + ChatRole.from_str("function") - assert message.role == ChatRole.TOOL - assert message.tool_call_result == ToolCallResult( - result="Result of function invocation", - origin=ToolCall(id=None, tool_name="my_function", arguments={}), - error=False, - ) + +def test_from_function_class_method_removed(): + with pytest.raises(AttributeError): + ChatMessage.from_function("Result of function invocation", "my_function") def test_serde(): @@ -234,11 +231,6 @@ def test_chat_message_init_content_parameter_type(): ChatMessage(ChatRole.USER, "This is a message") -def test_chat_message_function_role_deprecated(): - with pytest.warns(DeprecationWarning): - ChatMessage(ChatRole.FUNCTION, TextContent("This is a message")) - - def test_to_openai_dict_format(): message = ChatMessage.from_system("You are good assistant") assert message.to_openai_dict_format() == {"role": "system", "content": "You are good assistant"} From 21dd03d3e79774de89d772d2ce37dec11e2cd5e4 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 17 Jan 2025 09:58:45 +0100 Subject: [PATCH 16/41] feat: Add completion start time timestamp to relevant generators (#8728) * OpenAIChatGenerator - add completion_start_time * HuggingFaceAPIChatGenerator - add completion_start_time * Add tests * Add reno note * Relax condition for cached responses * Add completion_start_time timestamping to non-chat generators * Update haystack/components/generators/chat/hugging_face_api.py Co-authored-by: Stefano Fiorucci * PR feedback --------- Co-authored-by: Stefano Fiorucci --- .../generators/chat/hugging_face_api.py | 6 +++++ haystack/components/generators/chat/openai.py | 3 +++ .../components/generators/hugging_face_api.py | 9 ++++++++ haystack/components/generators/openai.py | 14 +++++++---- ...completion-timestamp-c0ad3b8698a2d575.yaml | 4 ++++ .../generators/chat/test_hugging_face_api.py | 11 ++++++--- .../components/generators/chat/test_openai.py | 4 ++++ .../generators/test_hugging_face_api.py | 23 +++++++++++++++++++ test/components/generators/test_openai.py | 4 ++++ 9 files changed, 71 insertions(+), 7 deletions(-) create mode 100644 releasenotes/notes/add-streaming-completion-timestamp-c0ad3b8698a2d575.yaml diff --git a/haystack/components/generators/chat/hugging_face_api.py b/haystack/components/generators/chat/hugging_face_api.py index 50a730a01f..1264272fca 100644 --- a/haystack/components/generators/chat/hugging_face_api.py +++ b/haystack/components/generators/chat/hugging_face_api.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +from datetime import datetime from typing import Any, Callable, Dict, Iterable, List, Optional, Union from haystack import component, default_from_dict, default_to_dict, logging @@ -259,6 +260,7 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict ) generated_text = "" + first_chunk_time = None for chunk in api_output: # n is unused, so 
the API always returns only one choice @@ -276,6 +278,9 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict if finish_reason: meta["finish_reason"] = finish_reason + if first_chunk_time is None: + first_chunk_time = datetime.now().isoformat() + stream_chunk = StreamingChunk(text, meta) self.streaming_callback(stream_chunk) # type: ignore # streaming_callback is not None (verified in the run method) @@ -285,6 +290,7 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict "finish_reason": finish_reason, "index": 0, "usage": {"prompt_tokens": 0, "completion_tokens": 0}, # not available in streaming + "completion_start_time": first_chunk_time, } ) diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py index 0b699e3bc1..b30de1b43d 100644 --- a/haystack/components/generators/chat/openai.py +++ b/haystack/components/generators/chat/openai.py @@ -4,6 +4,7 @@ import json import os +from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Union from openai import OpenAI, Stream @@ -381,6 +382,7 @@ def _convert_streaming_chunks_to_chat_message(self, chunk: Any, chunks: List[Str "model": chunk.model, "index": 0, "finish_reason": chunk.choices[0].finish_reason, + "completion_start_time": chunks[0].meta.get("received_at"), # first chunk received "usage": {}, # we don't have usage data for streaming responses } @@ -444,6 +446,7 @@ def _convert_chat_completion_chunk_to_streaming_chunk(self, chunk: ChatCompletio "index": choice.index, "tool_calls": choice.delta.tool_calls, "finish_reason": choice.finish_reason, + "received_at": datetime.now().isoformat(), } ) return chunk_message diff --git a/haystack/components/generators/hugging_face_api.py b/haystack/components/generators/hugging_face_api.py index a44ad94575..0a977f1603 100644 --- a/haystack/components/generators/hugging_face_api.py +++ b/haystack/components/generators/hugging_face_api.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import asdict +from datetime import datetime from typing import Any, Callable, Dict, Iterable, List, Optional, Union from haystack import component, default_from_dict, default_to_dict, logging @@ -217,18 +218,26 @@ def _stream_and_build_response( self, hf_output: Iterable["TextGenerationStreamOutput"], streaming_callback: Callable[[StreamingChunk], None] ): chunks: List[StreamingChunk] = [] + first_chunk_time = None + for chunk in hf_output: token: TextGenerationOutputToken = chunk.token if token.special: continue + chunk_metadata = {**asdict(token), **(asdict(chunk.details) if chunk.details else {})} + if first_chunk_time is None: + first_chunk_time = datetime.now().isoformat() + stream_chunk = StreamingChunk(token.text, chunk_metadata) chunks.append(stream_chunk) streaming_callback(stream_chunk) + metadata = { "finish_reason": chunks[-1].meta.get("finish_reason", None), "model": self._client.model, "usage": {"completion_tokens": chunks[-1].meta.get("generated_tokens", 0)}, + "completion_start_time": first_chunk_time, } return {"replies": ["".join([chunk.content for chunk in chunks])], "meta": [metadata]} diff --git a/haystack/components/generators/openai.py b/haystack/components/generators/openai.py index d2f07f9d85..3a87b8c068 100644 --- a/haystack/components/generators/openai.py +++ b/haystack/components/generators/openai.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import os +from datetime import datetime from typing import Any, Callable, Dict, 
List, Optional, Union from openai import OpenAI, Stream @@ -255,7 +256,7 @@ def _create_message_from_chunks( "model": completion_chunk.model, "index": 0, "finish_reason": finish_reason, - # Usage is available when streaming only if the user explicitly requests it + "completion_start_time": streamed_chunks[0].meta.get("received_at"), # first chunk received "usage": dict(completion_chunk.usage or {}), } ) @@ -296,12 +297,17 @@ def _build_chunk(chunk: Any) -> StreamingChunk: :returns: The StreamingChunk. """ - # function or tools calls are not going to happen in non-chat generation - # as users can not send ChatMessage with function or tools calls choice = chunk.choices[0] content = choice.delta.content or "" chunk_message = StreamingChunk(content) - chunk_message.meta.update({"model": chunk.model, "index": choice.index, "finish_reason": choice.finish_reason}) + chunk_message.meta.update( + { + "model": chunk.model, + "index": choice.index, + "finish_reason": choice.finish_reason, + "received_at": datetime.now().isoformat(), + } + ) return chunk_message @staticmethod diff --git a/releasenotes/notes/add-streaming-completion-timestamp-c0ad3b8698a2d575.yaml b/releasenotes/notes/add-streaming-completion-timestamp-c0ad3b8698a2d575.yaml new file mode 100644 index 0000000000..2718c6fdcd --- /dev/null +++ b/releasenotes/notes/add-streaming-completion-timestamp-c0ad3b8698a2d575.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Added completion_start_time metadata to track time-to-first-token (TTFT) in streaming responses from Hugging Face API and OpenAI (Azure). diff --git a/test/components/generators/chat/test_hugging_face_api.py b/test/components/generators/chat/test_hugging_face_api.py index fa83b98db7..f9e306c46e 100644 --- a/test/components/generators/chat/test_hugging_face_api.py +++ b/test/components/generators/chat/test_hugging_face_api.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +from datetime import datetime import os from unittest.mock import MagicMock, Mock, patch @@ -503,9 +504,13 @@ def test_live_run_serverless_streaming(self): assert isinstance(response["replies"], list) assert len(response["replies"]) > 0 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - assert "usage" in response["replies"][0].meta - assert "prompt_tokens" in response["replies"][0].meta["usage"] - assert "completion_tokens" in response["replies"][0].meta["usage"] + + response_meta = response["replies"][0].meta + assert "completion_start_time" in response_meta + assert datetime.fromisoformat(response_meta["completion_start_time"]) <= datetime.now() + assert "usage" in response_meta + assert "prompt_tokens" in response_meta["usage"] + assert "completion_tokens" in response_meta["usage"] @pytest.mark.integration @pytest.mark.skipif( diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py index eb50d92739..63a920a8ec 100644 --- a/test/components/generators/chat/test_openai.py +++ b/test/components/generators/chat/test_openai.py @@ -546,6 +546,10 @@ def __call__(self, chunk: StreamingChunk) -> None: assert callback.counter > 1 assert "Paris" in callback.responses + # check that the completion_start_time is set and valid ISO format + assert "completion_start_time" in message.meta + assert datetime.fromisoformat(message.meta["completion_start_time"]) < datetime.now() + @pytest.mark.skipif( not os.environ.get("OPENAI_API_KEY", None), reason="Export an env var called OPENAI_API_KEY 
containing the OpenAI API key to run this test.", diff --git a/test/components/generators/test_hugging_face_api.py b/test/components/generators/test_hugging_face_api.py index 0f4be2f9cb..965cf0cf81 100644 --- a/test/components/generators/test_hugging_face_api.py +++ b/test/components/generators/test_hugging_face_api.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import os from unittest.mock import MagicMock, Mock, patch +from datetime import datetime import pytest from huggingface_hub import ( @@ -312,3 +313,25 @@ def test_run_serverless(self): assert isinstance(response["meta"], list) assert len(response["meta"]) > 0 assert [isinstance(meta, dict) for meta in response["meta"]] + + @pytest.mark.integration + @pytest.mark.skipif( + not os.environ.get("HF_API_TOKEN", None), + reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.", + ) + def test_live_run_streaming_check_completion_start_time(self): + generator = HuggingFaceAPIGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, + streaming_callback=streaming_callback_handler, + ) + + results = generator.run("What is the capital of France?") + + assert len(results["replies"]) == 1 + assert "Paris" in results["replies"][0] + + # Verify completion start time in final metadata + assert "completion_start_time" in results["meta"][0] + completion_start = datetime.fromisoformat(results["meta"][0]["completion_start_time"]) + assert completion_start <= datetime.now() diff --git a/test/components/generators/test_openai.py b/test/components/generators/test_openai.py index e1d865c95f..816412ee97 100644 --- a/test/components/generators/test_openai.py +++ b/test/components/generators/test_openai.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +from datetime import datetime import logging import os from typing import List @@ -286,6 +287,9 @@ def __call__(self, chunk: StreamingChunk) -> None: assert "gpt-4o-mini" in metadata["model"] assert metadata["finish_reason"] == "stop" + assert "completion_start_time" in metadata + assert datetime.fromisoformat(metadata["completion_start_time"]) <= datetime.now() + # unfortunately, the usage is not available for streaming calls # we keep the key in the metadata for compatibility assert "usage" in metadata and len(metadata["usage"]) == 0 From 2c84266d8fde2a0e408eb0e42916a9f60e61a941 Mon Sep 17 00:00:00 2001 From: "David S. 
Batista" Date: Fri, 17 Jan 2025 10:56:16 +0100 Subject: [PATCH 17/41] test: adding test for PyPDF to extract passages so that they are detect by DocumentSplitter (#8739) --- haystack/components/converters/pypdf.py | 19 ++++++++-------- .../converters/test_pypdf_to_document.py | 22 +++++++++++++++++++ 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/haystack/components/converters/pypdf.py b/haystack/components/converters/pypdf.py index 334ef097d7..15bbcc1fec 100644 --- a/haystack/components/converters/pypdf.py +++ b/haystack/components/converters/pypdf.py @@ -158,17 +158,16 @@ def from_dict(cls, data): def _default_convert(self, reader: "PdfReader") -> str: texts = [] for page in reader.pages: - texts.append( - page.extract_text( - orientations=self.plain_mode_orientations, - extraction_mode=self.extraction_mode.value, - space_width=self.plain_mode_space_width, - layout_mode_space_vertically=self.layout_mode_space_vertically, - layout_mode_scale_weight=self.layout_mode_scale_weight, - layout_mode_strip_rotated=self.layout_mode_strip_rotated, - layout_mode_font_height_weight=self.layout_mode_font_height_weight, - ) + extracted_text = page.extract_text( + orientations=self.plain_mode_orientations, + extraction_mode=self.extraction_mode.value, + space_width=self.plain_mode_space_width, + layout_mode_space_vertically=self.layout_mode_space_vertically, + layout_mode_scale_weight=self.layout_mode_scale_weight, + layout_mode_strip_rotated=self.layout_mode_strip_rotated, + layout_mode_font_height_weight=self.layout_mode_font_height_weight, ) + texts.append(extracted_text) text = "\f".join(texts) return text diff --git a/test/components/converters/test_pypdf_to_document.py b/test/components/converters/test_pypdf_to_document.py index 916bb771ee..6306f0659e 100644 --- a/test/components/converters/test_pypdf_to_document.py +++ b/test/components/converters/test_pypdf_to_document.py @@ -8,6 +8,7 @@ from haystack import Document, default_from_dict, default_to_dict from haystack.components.converters.pypdf import PyPDFToDocument, PyPDFExtractionMode +from haystack.components.preprocessors import DocumentSplitter from haystack.dataclasses import ByteStream @@ -213,3 +214,24 @@ def test_run_empty_document(self, caplog, test_files_path): # Check that meta is used when the returned document is initialized and thus when doc id is generated assert output["documents"][0].meta["file_path"] == "non_text_searchable.pdf" assert output["documents"][0].id != Document(content="").id + + def test_run_detect_paragraphs_to_be_used_in_split_passage(self, test_files_path): + converter = PyPDFToDocument(extraction_mode=PyPDFExtractionMode.LAYOUT) + sources = [test_files_path / "pdf" / "sample_pdf_2.pdf"] + pdf_doc = converter.run(sources=sources) + splitter = DocumentSplitter(split_length=1, split_by="passage") + docs = splitter.run(pdf_doc["documents"]) + + assert len(docs["documents"]) == 51 + + expected = ( + "A wiki (/ˈwɪki/ (About this soundlisten) WIK-ee) is a hypertext publication collaboratively\n" + "edited and managed by its own audience directly using a web browser. A typical wiki\ncontains " + "multiple pages for the subjects or scope of the project and may be either open\nto the public or " + "limited to use within an organization for maintaining its internal knowledge\nbase. Wikis are " + "enabled by wiki software, otherwise known as wiki engines. 
A wiki engine,\nbeing a form of a " + "content management system, differs from other web-based systems\nsuch as blog software, in that " + "the content is created without any defined owner or leader,\nand wikis have little inherent " + "structure, allowing structure to emerge according to the\nneeds of the users.[1]\n\n" + ) + assert docs["documents"][2].content == expected From 424bce2783977346ea408f54607f02dd3d8cedbe Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Fri, 17 Jan 2025 13:36:07 +0100 Subject: [PATCH 18/41] test: fix HF API flaky live test with tools (#8744) * test: fix HF API flaky live test with tools * rm print --- .../generators/chat/test_hugging_face_api.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/test/components/generators/chat/test_hugging_face_api.py b/test/components/generators/chat/test_hugging_face_api.py index f9e306c46e..6e46e5041b 100644 --- a/test/components/generators/chat/test_hugging_face_api.py +++ b/test/components/generators/chat/test_hugging_face_api.py @@ -466,6 +466,7 @@ def test_run_with_tools(self, mock_check_valid_model, tools): not os.environ.get("HF_API_TOKEN", None), reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.", ) + @pytest.mark.flaky(reruns=3, reruns_delay=10) def test_live_run_serverless(self): generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, @@ -489,6 +490,7 @@ def test_live_run_serverless(self): not os.environ.get("HF_API_TOKEN", None), reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.", ) + @pytest.mark.flaky(reruns=3, reruns_delay=10) def test_live_run_serverless_streaming(self): generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, @@ -517,19 +519,18 @@ def test_live_run_serverless_streaming(self): not os.environ.get("HF_API_TOKEN", None), reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.", ) - @pytest.mark.integration + @pytest.mark.flaky(reruns=3, reruns_delay=10) def test_live_run_with_tools(self, tools): """ We test the round trip: generate tool call, pass tool message, generate response. - The model used here (zephyr-7b-beta) is always available and not gated. - Even if it does not officially support tools, TGI+HF API make it work. + The model used here (Hermes-3-Llama-3.1-8B) is not gated and kept in a warm state. """ - chat_messages = [ChatMessage.from_user("What's the weather like in Paris and Munich?")] + chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")] generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, - api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, + api_params={"model": "NousResearch/Hermes-3-Llama-3.1-8B"}, generation_kwargs={"temperature": 0.5}, ) @@ -545,7 +546,7 @@ def test_live_run_with_tools(self, tools): assert "Paris" in tool_call.arguments["city"] assert message.meta["finish_reason"] == "stop" - new_messages = chat_messages + [message, ChatMessage.from_tool(tool_result="22°", origin=tool_call)] + new_messages = chat_messages + [message, ChatMessage.from_tool(tool_result="22° C", origin=tool_call)] # the model tends to make tool calls if provided with tools, so we don't pass them here results = generator.run(new_messages, generation_kwargs={"max_tokens": 50}) From 5af2888e23dab5be9960e05594752b530c110048 Mon Sep 17 00:00:00 2001 From: "David S. 
Batista" Date: Fri, 17 Jan 2025 14:01:16 +0100 Subject: [PATCH 19/41] fix: `PDFMinerToDocument` convert function - adding double new lines between each `container_text` so that passages can be detected. (#8729) * initial import * adding double new lines between container_texts so that passages can be detected * reducing type specification to avoid import error * adding release notes * renaming variable --- haystack/components/converters/pdfminer.py | 18 ++++++----- ...or-passage-detection-62cf5c3e9758bcf9.yaml | 4 +++ .../converters/test_pdfminer_to_document.py | 30 +++++++++++++++++++ 3 files changed, 45 insertions(+), 7 deletions(-) create mode 100644 releasenotes/notes/fixing-PDFMiner-for-passage-detection-62cf5c3e9758bcf9.yaml diff --git a/haystack/components/converters/pdfminer.py b/haystack/components/converters/pdfminer.py index c5f2415685..6c8fc6cdc7 100644 --- a/haystack/components/converters/pdfminer.py +++ b/haystack/components/converters/pdfminer.py @@ -5,7 +5,7 @@ import io import os from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Iterator, List, Optional, Union from haystack import Document, component, logging from haystack.components.converters.utils import get_bytestream_from_source, normalize_metadata @@ -98,23 +98,27 @@ def __init__( # pylint: disable=too-many-positional-arguments ) self.store_full_path = store_full_path - def _converter(self, extractor) -> str: + @staticmethod + def _converter(lt_page_objs: Iterator) -> str: """ Extracts text from PDF pages then converts the text into a single str - :param extractor: + :param lt_page_objs: Python generator that yields PDF pages. :returns: PDF text converted to single str """ pages = [] - for page in extractor: + for page in lt_page_objs: text = "" for container in page: # Keep text only if isinstance(container, LTTextContainer): - text += container.get_text() + container_text = container.get_text() + if container_text: + text += "\n\n" + text += container_text pages.append(text) # Add a page delimiter @@ -156,8 +160,8 @@ def run( logger.warning("Could not read {source}. Skipping it. Error: {error}", source=source, error=e) continue try: - pdf_reader = extract_pages(io.BytesIO(bytestream.data), laparams=self.layout_params) - text = self._converter(pdf_reader) + pages = extract_pages(io.BytesIO(bytestream.data), laparams=self.layout_params) + text = self._converter(pages) except Exception as e: logger.warning( "Could not read {source} and convert it to Document, skipping. {error}", source=source, error=e diff --git a/releasenotes/notes/fixing-PDFMiner-for-passage-detection-62cf5c3e9758bcf9.yaml b/releasenotes/notes/fixing-PDFMiner-for-passage-detection-62cf5c3e9758bcf9.yaml new file mode 100644 index 0000000000..5b791e9749 --- /dev/null +++ b/releasenotes/notes/fixing-PDFMiner-for-passage-detection-62cf5c3e9758bcf9.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Updated `PDFMinerToDocument` convert function to to double new lines between container_text so that passages can later by `DocumentSplitter`. 
diff --git a/test/components/converters/test_pdfminer_to_document.py b/test/components/converters/test_pdfminer_to_document.py index 92aeb2dcd1..4691a2a1a2 100644 --- a/test/components/converters/test_pdfminer_to_document.py +++ b/test/components/converters/test_pdfminer_to_document.py @@ -6,6 +6,7 @@ import pytest from haystack import Document +from haystack.components.preprocessors import DocumentSplitter from haystack.dataclasses import ByteStream from haystack.components.converters.pdfminer import PDFMinerToDocument @@ -155,3 +156,32 @@ def test_run_empty_document(self, caplog, test_files_path): # Check that not only content is used when the returned document is initialized and doc id is generated assert results["documents"][0].meta["file_path"] == "non_text_searchable.pdf" assert results["documents"][0].id != Document(content="").id + + def test_run_detect_pages_and_split_by_passage(self, test_files_path): + converter = PDFMinerToDocument() + sources = [test_files_path / "pdf" / "sample_pdf_2.pdf"] + pdf_doc = converter.run(sources=sources) + splitter = DocumentSplitter(split_length=1, split_by="page") + docs = splitter.run(pdf_doc["documents"]) + assert len(docs["documents"]) == 4 + + def test_run_detect_paragraphs_to_be_used_in_split_passage(self, test_files_path): + converter = PDFMinerToDocument() + sources = [test_files_path / "pdf" / "sample_pdf_2.pdf"] + pdf_doc = converter.run(sources=sources) + splitter = DocumentSplitter(split_length=1, split_by="passage") + docs = splitter.run(pdf_doc["documents"]) + + assert len(docs["documents"]) == 29 + + expected = ( + "\nA wiki (/ˈwɪki/ (About this soundlisten) WIK-ee) is a hypertext publication collaboratively" + " \nedited and managed by its own audience directly using a web browser. A typical wiki \ncontains " + "multiple pages for the subjects or scope of the project and may be either open \nto the public or " + "limited to use within an organization for maintaining its internal knowledge \nbase. Wikis are " + "enabled by wiki software, otherwise known as wiki engines. 
A wiki engine, \nbeing a form of a "
+            "content management system, differs from other web-based systems \nsuch as blog software, in that "
+            "the content is created without any defined owner or leader, \nand wikis have little inherent "
+            "structure, allowing structure to emerge according to the \nneeds of the users.[1] \n\n"
+        )
+        assert docs["documents"][6].content == expected
From 6feb3856bb73d0a1fec6b7434ffdd93fd287f05c Mon Sep 17 00:00:00 2001 From: Julian Risch Date: Sun, 19 Jan 2025 17:28:37 +0100 Subject: [PATCH 20/41] chore: Remove FixMe comment from __init__.py (#8749) --- haystack/__init__.py | 2 -- 1 file changed, 2 deletions(-)

diff --git a/haystack/__init__.py b/haystack/__init__.py
index a8712024c8..2f275d3a46 100644
--- a/haystack/__init__.py
+++ b/haystack/__init__.py
@@ -30,5 +30,3 @@
     "GeneratedAnswer",
     "ExtractedAnswer",
 ]
-
-# FIXME: remove before merging PR
From 242138c68b5dc8d704df57ed4466060cf90b4457 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Sun, 19 Jan 2025 20:45:02 +0100 Subject: [PATCH 21/41] chore: update ruff version in pre-commit hook (#8746) --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 306e205e41..ac1a8a7d34 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -17,7 +17,7 @@ repos:
         args: [--markdown-linebreak-ext=md]
 
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.5.0
+    rev: v0.9.2
     hooks:
       - id: ruff
       - id: ruff-format
From 542a7f7ef5638dcec9e53d7a8d66ee04d0e7359c Mon Sep 17 00:00:00 2001 From: Nicola Procopio Date: Mon, 20 Jan 2025 09:51:47 +0100 Subject: [PATCH 22/41] fix: update meta data before initializing new Document in DocumentSplitter (#8745) * updated DocumentSplitter, issue #8741 * release note * updated DocumentSplitter: in the _create_docs_from_splits function, initialize a new variable copied_meta instead of overwriting meta * added test test_duplicate_pages_get_different_doc_id * fix fmt --------- Co-authored-by: Stefano Fiorucci --- haystack/components/preprocessors/document_splitter.py | 10 +++++----- .../updated-documentsplitter-762c4409cbc296e6.yaml | 4 ++++ .../components/preprocessors/test_document_splitter.py | 8 ++++++++ 3 files changed, 17 insertions(+), 5 deletions(-) create mode 100644 releasenotes/notes/updated-documentsplitter-762c4409cbc296e6.yaml

diff --git a/haystack/components/preprocessors/document_splitter.py b/haystack/components/preprocessors/document_splitter.py
index d03897b4b6..949f756ae6 100644
--- a/haystack/components/preprocessors/document_splitter.py
+++ b/haystack/components/preprocessors/document_splitter.py
@@ -323,11 +323,11 @@ def _create_docs_from_splits(
 
         documents: List[Document] = []
 
         for i, (txt, split_idx) in enumerate(zip(text_splits, splits_start_idxs)):
-            meta = deepcopy(meta)
-            doc = Document(content=txt, meta=meta)
-            doc.meta["page_number"] = splits_pages[i]
-            doc.meta["split_id"] = i
-            doc.meta["split_idx_start"] = split_idx
+            copied_meta = deepcopy(meta)
+            copied_meta["page_number"] = splits_pages[i]
+            copied_meta["split_id"] = i
+            copied_meta["split_idx_start"] = split_idx
+            doc = Document(content=txt, meta=copied_meta)
             documents.append(doc)
 
         if self.split_overlap <= 0:
diff --git a/releasenotes/notes/updated-documentsplitter-762c4409cbc296e6.yaml b/releasenotes/notes/updated-documentsplitter-762c4409cbc296e6.yaml
new file mode 100644
index 0000000000..f0d3a3d68e
--- /dev/null
+++ b/releasenotes/notes/updated-documentsplitter-762c4409cbc296e6.yaml
@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    Updated DocumentSplitter to set a Document's meta data before initializing the new Document, as requested in issue #8741
diff --git a/test/components/preprocessors/test_document_splitter.py b/test/components/preprocessors/test_document_splitter.py
index f9096239f2..81e0fa2ae4 100644
--- a/test/components/preprocessors/test_document_splitter.py
+++ b/test/components/preprocessors/test_document_splitter.py
@@ -827,3 +827,11 @@ def test_respect_sentence_boundary_serialization(self):
         assert deserialized.respect_sentence_boundary == True
         assert hasattr(deserialized, "sentence_splitter")
         assert deserialized.language == "de"
+
+    def test_duplicate_pages_get_different_doc_id(self):
+        splitter = DocumentSplitter(split_by="page", split_length=1)
+        doc1 = Document(content="This is some text.\fThis is some text.\fThis is some text.\fThis is some text.")
+        splitter.warm_up()
+        result = splitter.run(documents=[doc1])
+
+        assert len({doc.id for doc in result["documents"]}) == 4
From 2bf6bf6a45d02df588702ac12508da08e27fa7f8 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Tue, 21 Jan 2025 10:07:56 +0100 Subject: [PATCH 23/41] build: add `jsonschema` library to core dependencies (#8753) * add jsonschema to core dependencies * release note --- haystack/components/validators/json_schema.py | 7 ++----- haystack/tools/tool.py | 9 +++------ pyproject.toml | 4 +--- .../jsonschema-core-dependency-d38645d819eb0d2d.yaml | 4 ++++ 4 files changed, 10 insertions(+), 14 deletions(-) create mode 100644 releasenotes/notes/jsonschema-core-dependency-d38645d819eb0d2d.yaml

diff --git a/haystack/components/validators/json_schema.py b/haystack/components/validators/json_schema.py
index 0a449aff42..051773113c 100644
--- a/haystack/components/validators/json_schema.py
+++ b/haystack/components/validators/json_schema.py
@@ -5,12 +5,10 @@
 import json
 from typing import Any, Dict, List, Optional
 
+from jsonschema import ValidationError, validate
+
 from haystack import component
 from haystack.dataclasses import ChatMessage
-from haystack.lazy_imports import LazyImport
-
-with LazyImport(message="Run 'pip install jsonschema'") as jsonschema_import:
-    from jsonschema import ValidationError, validate
 
 
 def is_valid_json(s: str) -> bool:
@@ -110,7 +108,6 @@ def __init__(self, json_schema: Optional[Dict[str, Any]] = None, error_template:
             the messages' content is validated.
         :param error_template: A custom template string for formatting the error message in case of validation failure.
""" - jsonschema_import.check() self.json_schema = json_schema self.error_template = error_template diff --git a/haystack/tools/tool.py b/haystack/tools/tool.py index bdb8f005b6..fd4802879f 100644 --- a/haystack/tools/tool.py +++ b/haystack/tools/tool.py @@ -5,15 +5,13 @@ from dataclasses import asdict, dataclass from typing import Any, Callable, Dict, List, Optional +from jsonschema import Draft202012Validator +from jsonschema.exceptions import SchemaError + from haystack.core.serialization import generate_qualified_class_name, import_class_by_name -from haystack.lazy_imports import LazyImport from haystack.tools.errors import ToolInvocationError from haystack.utils import deserialize_callable, serialize_callable -with LazyImport(message="Run 'pip install jsonschema'") as jsonschema_import: - from jsonschema import Draft202012Validator - from jsonschema.exceptions import SchemaError - @dataclass class Tool: @@ -39,7 +37,6 @@ class Tool: function: Callable def __post_init__(self): - jsonschema_import.check() # Check that the parameters define a valid JSON schema try: Draft202012Validator.check_schema(self.parameters) diff --git a/pyproject.toml b/pyproject.toml index 258e4e2710..481f3546d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ dependencies = [ "requests", "numpy", "python-dateutil", + "jsonschema", # JsonSchemaValidator, Tool "haystack-experimental", ] @@ -116,9 +117,6 @@ extra-dependencies = [ "jsonref", # OpenAPIServiceConnector, OpenAPIServiceToFunctions "openapi3", - # JsonSchemaValidator, Tool - "jsonschema", - # Tracing "opentelemetry-sdk", "ddtrace", diff --git a/releasenotes/notes/jsonschema-core-dependency-d38645d819eb0d2d.yaml b/releasenotes/notes/jsonschema-core-dependency-d38645d819eb0d2d.yaml new file mode 100644 index 0000000000..f8579ebcdd --- /dev/null +++ b/releasenotes/notes/jsonschema-core-dependency-d38645d819eb0d2d.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Add jsonschema library as a core dependency. It is used in Tool and JsonSchemaValidator. From f96839e139c3e4395a847481aa4512343508e602 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Tue, 21 Jan 2025 14:43:27 +0100 Subject: [PATCH 24/41] chore: update `transformers` test dependency (#8752) * update transformers test dependency * add pad_token_id to the mock tokenizer * fix HFLocal test + new test --- pyproject.toml | 2 +- .../generators/chat/test_hugging_face_local.py | 1 + .../test_hugging_face_local_generator.py | 18 +++++++++++++++++- 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 481f3546d9..397ce83583 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,7 +85,7 @@ format-check = "ruff format --check {args}" extra-dependencies = [ "numpy>=2", # Haystack is compatible both with numpy 1.x and 2.x, but we test with 2.x - "transformers[torch,sentencepiece]==4.44.2", # ExtractiveReader, TransformersSimilarityRanker, LocalWhisperTranscriber, HFGenerators... + "transformers[torch,sentencepiece]==4.47.1", # ExtractiveReader, TransformersSimilarityRanker, LocalWhisperTranscriber, HFGenerators... 
"huggingface_hub>=0.27.0", # Hugging Face API Generators and Embedders "sentence-transformers>=3.0.0", # SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder "langdetect", # TextLanguageRouter and DocumentLanguageClassifier diff --git a/test/components/generators/chat/test_hugging_face_local.py b/test/components/generators/chat/test_hugging_face_local.py index c953404912..16d706b35f 100644 --- a/test/components/generators/chat/test_hugging_face_local.py +++ b/test/components/generators/chat/test_hugging_face_local.py @@ -42,6 +42,7 @@ def mock_pipeline_tokenizer(): # Mocking the tokenizer mock_tokenizer = Mock(spec=PreTrainedTokenizer) mock_tokenizer.encode.return_value = ["Berlin", "is", "cool"] + mock_tokenizer.pad_token_id = 100 mock_pipeline.tokenizer = mock_tokenizer return mock_pipeline diff --git a/test/components/generators/test_hugging_face_local_generator.py b/test/components/generators/test_hugging_face_local_generator.py index bded2e8d47..8d64bc44fc 100644 --- a/test/components/generators/test_hugging_face_local_generator.py +++ b/test/components/generators/test_hugging_face_local_generator.py @@ -397,8 +397,12 @@ def test_stop_words_criteria_with_a_mocked_tokenizer(self): # "This is ambiguously, but is unrelated." input_ids_one = torch.LongTensor([[100, 19, 24621, 11937, 6, 68, 19, 73, 3897, 5]]) input_ids_two = torch.LongTensor([[100, 19, 73, 24621, 11937]]) # "This is unambiguously" - stop_words_criteria = StopWordsCriteria(tokenizer=Mock(spec=PreTrainedTokenizerFast), stop_words=["mock data"]) + + mock_tokenizer = Mock(spec=PreTrainedTokenizerFast) + mock_tokenizer.pad_token = "" + stop_words_criteria = StopWordsCriteria(tokenizer=mock_tokenizer, stop_words=["mock data"]) stop_words_criteria.stop_ids = stop_words_id + assert not stop_words_criteria(input_ids_one, scores=None) assert stop_words_criteria(input_ids_two, scores=None) @@ -459,3 +463,15 @@ def test_hf_pipeline_runs_with_our_criteria(self): results = generator.run(prompt="something that triggers something") assert results["replies"] != [] assert generator.stopping_criteria_list is not None + + @pytest.mark.integration + @pytest.mark.flaky(reruns=3, reruns_delay=10) + def test_live_run(self): + llm = HuggingFaceLocalGenerator(model="Qwen/Qwen2.5-0.5B-Instruct", generation_kwargs={"max_new_tokens": 50}) + llm.warm_up() + + result = llm.run(prompt="Please create a summary about the following topic: Climate change") + + assert "replies" in result + assert isinstance(result["replies"][0], str) + assert "climate change" in result["replies"][0].lower() From c3d0643511b5d5a9c2507bad9ad082eb36115076 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 23 Jan 2025 10:24:04 +0100 Subject: [PATCH 25/41] feat: `AzureOpenAIChatGenerator` - support for tools (#8757) * feat: AzureOpenAIChatGenerator - support for tools * release note * feedback --- haystack/components/generators/chat/azure.py | 20 ++- ...echatgenerator-tools-9622b7c96e452404.yaml | 4 + test/components/generators/chat/test_azure.py | 130 +++++++++++++++++- 3 files changed, 146 insertions(+), 8 deletions(-) create mode 100644 releasenotes/notes/azurechatgenerator-tools-9622b7c96e452404.yaml diff --git a/haystack/components/generators/chat/azure.py b/haystack/components/generators/chat/azure.py index 3a27078e57..2cfde29e9c 100644 --- a/haystack/components/generators/chat/azure.py +++ b/haystack/components/generators/chat/azure.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import os -from typing import Any, Callable, Dict, 
Optional +from typing import Any, Callable, Dict, List, Optional # pylint: disable=import-error from openai.lib.azure import AzureOpenAI @@ -11,6 +11,7 @@ from haystack import component, default_from_dict, default_to_dict, logging from haystack.components.generators.chat import OpenAIChatGenerator from haystack.dataclasses import StreamingChunk +from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable logger = logging.getLogger(__name__) @@ -75,6 +76,8 @@ def __init__( # pylint: disable=too-many-positional-arguments max_retries: Optional[int] = None, generation_kwargs: Optional[Dict[str, Any]] = None, default_headers: Optional[Dict[str, str]] = None, + tools: Optional[List[Tool]] = None, + tools_strict: bool = False, ): """ Initialize the Azure OpenAI Chat Generator component. @@ -112,6 +115,11 @@ def __init__( # pylint: disable=too-many-positional-arguments - `logit_bias`: Adds a logit bias to specific tokens. The keys of the dictionary are tokens, and the values are the bias to add to that token. :param default_headers: Default headers to use for the AzureOpenAI client. + :param tools: + A list of tools for which the model can prepare calls. + :param tools_strict: + Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly + the schema provided in the `parameters` field of the tool definition, but this may increase latency. """ # We intentionally do not call super().__init__ here because we only need to instantiate the client to interact # with the API. @@ -142,10 +150,9 @@ def __init__( # pylint: disable=too-many-positional-arguments self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5)) self.default_headers = default_headers or {} - # This ChatGenerator does not yet supports tools. The following workaround ensures that we do not - # get an error when invoking the run method of the parent class (OpenAIChatGenerator). - self.tools = None - self.tools_strict = False + _check_duplicate_tool_names(tools) + self.tools = tools + self.tools_strict = tools_strict self.client = AzureOpenAI( api_version=api_version, @@ -180,6 +187,8 @@ def to_dict(self) -> Dict[str, Any]: api_key=self.api_key.to_dict() if self.api_key is not None else None, azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None, default_headers=self.default_headers, + tools=[tool.to_dict() for tool in self.tools] if self.tools else None, + tools_strict=self.tools_strict, ) @classmethod @@ -192,6 +201,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureOpenAIChatGenerator": The deserialized component instance. """ deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_ad_token"]) + deserialize_tools_inplace(data["init_parameters"], key="tools") init_params = data.get("init_parameters", {}) serialized_callback_handler = init_params.get("streaming_callback") if serialized_callback_handler: diff --git a/releasenotes/notes/azurechatgenerator-tools-9622b7c96e452404.yaml b/releasenotes/notes/azurechatgenerator-tools-9622b7c96e452404.yaml new file mode 100644 index 0000000000..b5bc2fc829 --- /dev/null +++ b/releasenotes/notes/azurechatgenerator-tools-9622b7c96e452404.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Add support for Tools in the Azure OpenAI Chat Generator. 
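
For reference, a minimal sketch of the new tool support. The endpoint is a placeholder and the toy weather function is invented for the example; the API key is read from the AZURE_OPENAI_API_KEY environment variable by default:

    from haystack.components.generators.chat import AzureOpenAIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.tools.tool import Tool

    def get_weather(city: str) -> str:
        # toy stand-in for a real weather lookup
        return f"Sunny in {city}"

    weather_tool = Tool(
        name="weather",
        description="useful to determine the weather in a given location",
        parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
        function=get_weather,
    )

    generator = AzureOpenAIChatGenerator(
        azure_endpoint="https://my-resource.openai.azure.com",  # placeholder endpoint
        tools=[weather_tool],
    )
    message = generator.run([ChatMessage.from_user("What's the weather like in Paris?")])["replies"][0]
    print(message.tool_call)  # expected: a ToolCall for "weather" with {"city": "Paris"}
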
diff --git a/test/components/generators/chat/test_azure.py b/test/components/generators/chat/test_azure.py index c104d0e725..3622013a4e 100644 --- a/test/components/generators/chat/test_azure.py +++ b/test/components/generators/chat/test_azure.py @@ -9,11 +9,25 @@ from haystack import Pipeline from haystack.components.generators.chat import AzureOpenAIChatGenerator from haystack.components.generators.utils import print_streaming_chunk -from haystack.dataclasses import ChatMessage +from haystack.dataclasses import ChatMessage, ToolCall +from haystack.tools.tool import Tool from haystack.utils.auth import Secret -class TestOpenAIChatGenerator: +@pytest.fixture +def tools(): + tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + tool = Tool( + name="weather", + description="useful to determine the weather in a given location", + parameters=tool_parameters, + function=lambda x: x, + ) + + return [tool] + + +class TestAzureOpenAIChatGenerator: def test_init_default(self, monkeypatch): monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key") component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint") @@ -28,17 +42,21 @@ def test_init_fail_wo_api_key(self, monkeypatch): with pytest.raises(OpenAIError): AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint") - def test_init_with_parameters(self): + def test_init_with_parameters(self, tools): component = AzureOpenAIChatGenerator( api_key=Secret.from_token("test-api-key"), azure_endpoint="some-non-existing-endpoint", streaming_callback=print_streaming_chunk, generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + tools=tools, + tools_strict=True, ) assert component.client.api_key == "test-api-key" assert component.azure_deployment == "gpt-4o-mini" assert component.streaming_callback is print_streaming_chunk assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + assert component.tools == tools + assert component.tools_strict def test_to_dict_default(self, monkeypatch): monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key") @@ -58,6 +76,8 @@ def test_to_dict_default(self, monkeypatch): "timeout": 30.0, "max_retries": 5, "default_headers": {}, + "tools": None, + "tools_strict": False, }, } @@ -85,15 +105,94 @@ def test_to_dict_with_parameters(self, monkeypatch): "timeout": 2.5, "max_retries": 10, "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "tools": None, + "tools_strict": False, "default_headers": {}, }, } + def test_from_dict(self, monkeypatch): + monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key") + monkeypatch.setenv("AZURE_OPENAI_AD_TOKEN", "test-ad-token") + data = { + "type": "haystack.components.generators.chat.azure.AzureOpenAIChatGenerator", + "init_parameters": { + "api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"}, + "azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"}, + "api_version": "2023-05-15", + "azure_endpoint": "some-non-existing-endpoint", + "azure_deployment": "gpt-4o-mini", + "organization": None, + "streaming_callback": None, + "generation_kwargs": {}, + "timeout": 30.0, + "max_retries": 5, + "default_headers": {}, + "tools": [ + { + "type": "haystack.tools.tool.Tool", + "data": { + "description": "description", + "function": "builtins.print", + "name": "name", + "parameters": {"x": {"type": "string"}}, + }, + } + ], + "tools_strict": False, + }, + } + + generator = 
AzureOpenAIChatGenerator.from_dict(data) + assert isinstance(generator, AzureOpenAIChatGenerator) + + assert generator.api_key == Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False) + assert generator.azure_ad_token == Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False) + assert generator.api_version == "2023-05-15" + assert generator.azure_endpoint == "some-non-existing-endpoint" + assert generator.azure_deployment == "gpt-4o-mini" + assert generator.organization is None + assert generator.streaming_callback is None + assert generator.generation_kwargs == {} + assert generator.timeout == 30.0 + assert generator.max_retries == 5 + assert generator.default_headers == {} + assert generator.tools == [ + Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) + ] + assert generator.tools_strict == False + def test_pipeline_serialization_deserialization(self, tmp_path, monkeypatch): monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key") generator = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint") p = Pipeline() p.add_component(instance=generator, name="generator") + + assert p.to_dict() == { + "metadata": {}, + "max_runs_per_component": 100, + "components": { + "generator": { + "type": "haystack.components.generators.chat.azure.AzureOpenAIChatGenerator", + "init_parameters": { + "azure_endpoint": "some-non-existing-endpoint", + "azure_deployment": "gpt-4o-mini", + "organization": None, + "api_version": "2023-05-15", + "streaming_callback": None, + "generation_kwargs": {}, + "timeout": 30.0, + "max_retries": 5, + "api_key": {"type": "env_var", "env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False}, + "azure_ad_token": {"type": "env_var", "env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False}, + "default_headers": {}, + "tools": None, + "tools_strict": False, + }, + } + }, + "connections": [], + } p_str = p.dumps() q = Pipeline.loads(p_str) assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization w/ AzureOpenAIChatGenerator failed." @@ -117,4 +216,29 @@ def test_live_run(self): assert "gpt-4o-mini" in message.meta["model"] assert message.meta["finish_reason"] == "stop" + @pytest.mark.integration + @pytest.mark.skipif( + not os.environ.get("AZURE_OPENAI_API_KEY", None) or not os.environ.get("AZURE_OPENAI_ENDPOINT", None), + reason=( + "Please export env variables called AZURE_OPENAI_API_KEY containing " + "the Azure OpenAI key, AZURE_OPENAI_ENDPOINT containing " + "the Azure OpenAI endpoint URL to run this test." 
+ ), + ) + def test_live_run_with_tools(self, tools): + chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")] + component = AzureOpenAIChatGenerator(organization="HaystackCI", tools=tools) + results = component.run(chat_messages) + assert len(results["replies"]) == 1 + message = results["replies"][0] + + assert not message.texts + assert not message.text + assert message.tool_calls + tool_call = message.tool_call + assert isinstance(tool_call, ToolCall) + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} + assert message.meta["finish_reason"] == "tool_calls" + # additional tests intentionally omitted as they are covered by test_openai.py From bf79f0493282f2b9de70a68478b93a3cd07e348a Mon Sep 17 00:00:00 2001 From: tstadel <60758086+tstadel@users.noreply.github.com> Date: Thu, 23 Jan 2025 12:14:32 +0100 Subject: [PATCH 26/41] feat: support streaming_callback as run param for HF Chat generators (#8763) * feat: support streaming_callback as run param for HF Chat generators * add tests --- .../generators/chat/hugging_face_api.py | 18 ++++-- .../generators/chat/hugging_face_local.py | 13 +++- ...r-hf-chat-generators-68aaa7e540ad03ce.yaml | 4 ++ .../generators/chat/test_hugging_face_api.py | 64 +++++++++++++++++++ .../chat/test_hugging_face_local.py | 41 ++++++++++++ 5 files changed, 133 insertions(+), 7 deletions(-) create mode 100644 releasenotes/notes/streaming-callback-run-param-support-for-hf-chat-generators-68aaa7e540ad03ce.yaml diff --git a/haystack/components/generators/chat/hugging_face_api.py b/haystack/components/generators/chat/hugging_face_api.py index 1264272fca..9fb5bda1f3 100644 --- a/haystack/components/generators/chat/hugging_face_api.py +++ b/haystack/components/generators/chat/hugging_face_api.py @@ -220,6 +220,7 @@ def run( messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None, tools: Optional[List[Tool]] = None, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): """ Invoke the text generation inference based on the provided messages and generation parameters. @@ -231,6 +232,9 @@ def run( :param tools: A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set during component initialization. + :param streaming_callback: + An optional callable for handling streaming responses. If set, it will override the `streaming_callback` + parameter set during component initialization. :returns: A dictionary with the following keys: - `replies`: A list containing the generated responses as ChatMessage objects. """ @@ -245,8 +249,9 @@ def run( raise ValueError("Using tools and streaming at the same time is not supported. 
Please choose one.") _check_duplicate_tool_names(tools) - if self.streaming_callback: - return self._run_streaming(formatted_messages, generation_kwargs) + streaming_callback = streaming_callback or self.streaming_callback + if streaming_callback: + return self._run_streaming(formatted_messages, generation_kwargs, streaming_callback) hf_tools = None if tools: @@ -254,7 +259,12 @@ def run( return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools) - def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]): + def _run_streaming( + self, + messages: List[Dict[str, str]], + generation_kwargs: Dict[str, Any], + streaming_callback: Callable[[StreamingChunk], None], + ): api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion( messages, stream=True, **generation_kwargs ) @@ -282,7 +292,7 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict first_chunk_time = datetime.now().isoformat() stream_chunk = StreamingChunk(text, meta) - self.streaming_callback(stream_chunk) # type: ignore # streaming_callback is not None (verified in the run method) + streaming_callback(stream_chunk) meta.update( { diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py index a79a6dcfa8..d5d05ae487 100644 --- a/haystack/components/generators/chat/hugging_face_local.py +++ b/haystack/components/generators/chat/hugging_face_local.py @@ -233,12 +233,18 @@ def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceLocalChatGenerator": return default_from_dict(cls, data) @component.output_types(replies=List[ChatMessage]) - def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None): + def run( + self, + messages: List[ChatMessage], + generation_kwargs: Optional[Dict[str, Any]] = None, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + ): """ Invoke text generation inference based on the provided messages and generation parameters. :param messages: A list of ChatMessage objects representing the input messages. :param generation_kwargs: Additional keyword arguments for text generation. + :param streaming_callback: An optional callable for handling streaming responses. :returns: A list containing the generated responses as ChatMessage instances. 
""" @@ -259,7 +265,8 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, if stop_words_criteria: generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria]) - if self.streaming_callback: + streaming_callback = streaming_callback or self.streaming_callback + if streaming_callback: num_responses = generation_kwargs.get("num_return_sequences", 1) if num_responses > 1: msg = ( @@ -270,7 +277,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, logger.warning(msg, num_responses=num_responses) generation_kwargs["num_return_sequences"] = 1 # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming - generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, self.streaming_callback, stop_words) + generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, streaming_callback, stop_words) hf_messages = [convert_message_to_hf_format(message) for message in messages] diff --git a/releasenotes/notes/streaming-callback-run-param-support-for-hf-chat-generators-68aaa7e540ad03ce.yaml b/releasenotes/notes/streaming-callback-run-param-support-for-hf-chat-generators-68aaa7e540ad03ce.yaml new file mode 100644 index 0000000000..8d5e6007f5 --- /dev/null +++ b/releasenotes/notes/streaming-callback-run-param-support-for-hf-chat-generators-68aaa7e540ad03ce.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Streaming callback run param support for HF chat generators. diff --git a/test/components/generators/chat/test_hugging_face_api.py b/test/components/generators/chat/test_hugging_face_api.py index 6e46e5041b..9edec01212 100644 --- a/test/components/generators/chat/test_hugging_face_api.py +++ b/test/components/generators/chat/test_hugging_face_api.py @@ -395,6 +395,70 @@ def mock_iter(self): assert len(response["replies"]) > 0 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + def test_run_with_streaming_callback_in_run_method( + self, mock_check_valid_model, mock_chat_completion, chat_messages + ): + streaming_call_count = 0 + + # Define the streaming callback function + def streaming_callback_fn(chunk: StreamingChunk): + nonlocal streaming_call_count + streaming_call_count += 1 + assert isinstance(chunk, StreamingChunk) + + generator = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "meta-llama/Llama-2-13b-chat-hf"}, + ) + + # Create a fake streamed response + # self needed here, don't remove + def mock_iter(self): + yield ChatCompletionStreamOutput( + choices=[ + ChatCompletionStreamOutputChoice( + delta=ChatCompletionStreamOutputDelta(content="The", role="assistant"), + index=0, + finish_reason=None, + ) + ], + id="some_id", + model="some_model", + system_fingerprint="some_fingerprint", + created=1710498504, + ) + + yield ChatCompletionStreamOutput( + choices=[ + ChatCompletionStreamOutputChoice( + delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason="length" + ) + ], + id="some_id", + model="some_model", + system_fingerprint="some_fingerprint", + created=1710498504, + ) + + mock_response = Mock(**{"__iter__": mock_iter}) + mock_chat_completion.return_value = mock_response + + # Generate text response with streaming callback + response = generator.run(chat_messages, streaming_callback=streaming_callback_fn) + + # check kwargs passed to text_generation + _, kwargs = mock_chat_completion.call_args + assert kwargs == {"stop": [], "stream": True, 
"max_tokens": 512} + + # Assert that the streaming callback was called twice + assert streaming_call_count == 2 + + # Assert that the response contains the generated replies + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) > 0 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + def test_run_fail_with_tools_and_streaming(self, tools, mock_check_valid_model): component = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, diff --git a/test/components/generators/chat/test_hugging_face_local.py b/test/components/generators/chat/test_hugging_face_local.py index 16d706b35f..7c2e05df3a 100644 --- a/test/components/generators/chat/test_hugging_face_local.py +++ b/test/components/generators/chat/test_hugging_face_local.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 from unittest.mock import Mock, patch +from haystack.dataclasses.streaming_chunk import StreamingChunk import pytest from transformers import PreTrainedTokenizer @@ -233,6 +234,46 @@ def test_run_with_custom_generation_parameters(self, model_info_mock, mock_pipel assert chat_message.is_from(ChatRole.ASSISTANT) assert chat_message.text == "Berlin is cool" + def test_run_with_streaming_callback(self, model_info_mock, mock_pipeline_tokenizer, chat_messages): + # Define the streaming callback function + def streaming_callback_fn(chunk: StreamingChunk): ... + + generator = HuggingFaceLocalChatGenerator( + model="meta-llama/Llama-2-13b-chat-hf", streaming_callback=streaming_callback_fn + ) + + # Use the mocked pipeline from the fixture and simulate warm_up + generator.pipeline = mock_pipeline_tokenizer + + results = generator.run(messages=chat_messages) + + assert "replies" in results + assert isinstance(results["replies"][0], ChatMessage) + chat_message = results["replies"][0] + assert chat_message.is_from(ChatRole.ASSISTANT) + assert chat_message.text == "Berlin is cool" + generator.pipeline.assert_called_once() + generator.pipeline.call_args[1]["streamer"].token_handler == streaming_callback_fn + + def test_run_with_streaming_callback_in_run_method(self, model_info_mock, mock_pipeline_tokenizer, chat_messages): + # Define the streaming callback function + def streaming_callback_fn(chunk: StreamingChunk): ... 
+ + generator = HuggingFaceLocalChatGenerator(model="meta-llama/Llama-2-13b-chat-hf") + + # Use the mocked pipeline from the fixture and simulate warm_up + generator.pipeline = mock_pipeline_tokenizer + + results = generator.run(messages=chat_messages, streaming_callback=streaming_callback_fn) + + assert "replies" in results + assert isinstance(results["replies"][0], ChatMessage) + chat_message = results["replies"][0] + assert chat_message.is_from(ChatRole.ASSISTANT) + assert chat_message.text == "Berlin is cool" + generator.pipeline.assert_called_once() + generator.pipeline.call_args[1]["streamer"].token_handler == streaming_callback_fn + @patch("haystack.components.generators.chat.hugging_face_local.convert_message_to_hf_format") def test_messages_conversion_is_called(self, mock_convert, model_info_mock): generator = HuggingFaceLocalChatGenerator(model="fake-model") From 3119ae1ec9f0a310ea98f034d1e3963d2ae14975 Mon Sep 17 00:00:00 2001 From: tstadel <60758086+tstadel@users.noreply.github.com> Date: Thu, 23 Jan 2025 12:40:19 +0100 Subject: [PATCH 27/41] refactor: raise `PipelineError` when `Pipeline.from_dict` receives an invalid type (#8711) * fix: error on invalid type * add reno * Update releasenotes/notes/fix-invalid-component-type-error-83ee00d820b63cc5.yaml Co-authored-by: Stefano Fiorucci * Update test/core/pipeline/test_pipeline.py Co-authored-by: Stefano Fiorucci * fix reno * fix reno * last reno fix --------- Co-authored-by: Stefano Fiorucci --- haystack/core/pipeline/base.py | 6 ++++-- ...valid-component-type-error-83ee00d820b63cc5.yaml | 5 +++++ test/core/pipeline/test_pipeline.py | 13 ++++++++++++- 3 files changed, 21 insertions(+), 3 deletions(-) create mode 100644 releasenotes/notes/fix-invalid-component-type-error-83ee00d820b63cc5.yaml diff --git a/haystack/core/pipeline/base.py b/haystack/core/pipeline/base.py index d8f2a65932..155ba27fe5 100644 --- a/haystack/core/pipeline/base.py +++ b/haystack/core/pipeline/base.py @@ -167,8 +167,10 @@ def from_dict( f"Successfully imported module {module} but can't find it in the component registry." "This is unexpected and most likely a bug." ) - except (ImportError, PipelineError) as e: - raise PipelineError(f"Component '{component_data['type']}' not imported.") from e + except (ImportError, PipelineError, ValueError) as e: + raise PipelineError( + f"Component '{component_data['type']}' (name: '{name}') not imported." + ) from e # Create a new one component_class = component.registry[component_data["type"]] diff --git a/releasenotes/notes/fix-invalid-component-type-error-83ee00d820b63cc5.yaml b/releasenotes/notes/fix-invalid-component-type-error-83ee00d820b63cc5.yaml new file mode 100644 index 0000000000..64119cfa88 --- /dev/null +++ b/releasenotes/notes/fix-invalid-component-type-error-83ee00d820b63cc5.yaml @@ -0,0 +1,5 @@ +--- +enhancements: + - | + When `Pipeline.from_dict` receives an invalid type (e.g. empty string), an informative `PipelineError` is now + raised. 
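
A short sketch of the new behavior, assuming the usual import path for PipelineError; the empty type string is deliberately invalid:

    from haystack import Pipeline
    from haystack.core.errors import PipelineError

    data = {
        "metadata": {},
        "components": {"add_two": {"type": "", "init_parameters": {"add": 2}}},
        "connections": [],
    }
    try:
        Pipeline.from_dict(data)
    except PipelineError as e:
        print(e)  # Component '' (name: 'add_two') not imported.
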
diff --git a/test/core/pipeline/test_pipeline.py b/test/core/pipeline/test_pipeline.py index 1cd57a5b5b..1f0284301e 100644 --- a/test/core/pipeline/test_pipeline.py +++ b/test/core/pipeline/test_pipeline.py @@ -564,7 +564,7 @@ def test_from_dict_without_component_type(self): err.match("Missing 'type' in component 'add_two'") # UNIT - def test_from_dict_without_registered_component_type(self, request): + def test_from_dict_without_registered_component_type(self): data = { "metadata": {"test": "test"}, "components": {"add_two": {"type": "foo.bar.baz", "init_parameters": {"add": 2}}}, @@ -575,6 +575,17 @@ def test_from_dict_without_registered_component_type(self, request): err.match(r"Component .+ not imported.") + def test_from_dict_with_invalid_type(self): + data = { + "metadata": {"test": "test"}, + "components": {"add_two": {"type": "", "init_parameters": {"add": 2}}}, + "connections": [], + } + with pytest.raises(PipelineError) as err: + Pipeline.from_dict(data) + + err.match(r"Component '' \(name: 'add_two'\) not imported.") + # UNIT def test_from_dict_without_connection_sender(self): data = {"metadata": {"test": "test"}, "components": {}, "connections": [{"receiver": "some.receiver"}]} From 223373eced28772a1f541e82ed60a59df16e37c8 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Fri, 24 Jan 2025 11:17:47 +0100 Subject: [PATCH 28/41] fix: Document Classifiers - fix error messages (#8765) * fix: Document Classifiers - fix docstrings + error messages * grammar * fix --- .../components/classifiers/document_language_classifier.py | 2 +- .../components/classifiers/zero_shot_document_classifier.py | 4 ++-- ...fix-classifiers-docstrings-messages-dcae473d2bd3cb95.yaml | 5 +++++ 3 files changed, 8 insertions(+), 3 deletions(-) create mode 100644 releasenotes/notes/fix-classifiers-docstrings-messages-dcae473d2bd3cb95.yaml diff --git a/haystack/components/classifiers/document_language_classifier.py b/haystack/components/classifiers/document_language_classifier.py index ed9b42b5d5..5b6b3469bc 100644 --- a/haystack/components/classifiers/document_language_classifier.py +++ b/haystack/components/classifiers/document_language_classifier.py @@ -83,7 +83,7 @@ def run(self, documents: List[Document]): if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): raise TypeError( "DocumentLanguageClassifier expects a list of Document as input. " - "In case you want to classify a text, please use the TextLanguageClassifier." + "In case you want to classify and route a text, please use the TextLanguageRouter." ) output: Dict[str, List[Document]] = {language: [] for language in self.languages} diff --git a/haystack/components/classifiers/zero_shot_document_classifier.py b/haystack/components/classifiers/zero_shot_document_classifier.py index 4be0a66d44..017a20f3c5 100644 --- a/haystack/components/classifiers/zero_shot_document_classifier.py +++ b/haystack/components/classifiers/zero_shot_document_classifier.py @@ -211,8 +211,8 @@ def run(self, documents: List[Document], batch_size: int = 1): if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): raise TypeError( - "DocumentLanguageClassifier expects a list of documents as input. " - "In case you want to classify a text, please use the TextLanguageClassifier." + "TransformerZeroShotDocumentClassifier expects a list of documents as input. " + "In case you want to classify and route a text, please use the TransformersZeroShotTextRouter." 
) invalid_doc_ids = [] diff --git a/releasenotes/notes/fix-classifiers-docstrings-messages-dcae473d2bd3cb95.yaml b/releasenotes/notes/fix-classifiers-docstrings-messages-dcae473d2bd3cb95.yaml new file mode 100644 index 0000000000..32d36192ff --- /dev/null +++ b/releasenotes/notes/fix-classifiers-docstrings-messages-dcae473d2bd3cb95.yaml @@ -0,0 +1,5 @@ +--- +fixes: + - | + Fix error messages for Document Classifier components, that suggested using nonexistent components for text + classification. From c989d9c483a4fc74147e17065779fd6103f4a084 Mon Sep 17 00:00:00 2001 From: Night-Quiet Date: Fri, 24 Jan 2025 19:06:09 +0800 Subject: [PATCH 29/41] fix: skip comment blocks in `DOCXToDocument` (#8764) * fix bug #8759 * Apply suggestions from code review * release note --------- Co-authored-by: Stefano Fiorucci --- haystack/components/converters/docx.py | 3 +++ .../notes/docx-skip-comment-blocks-d3a555d0324788c7.yaml | 4 ++++ 2 files changed, 7 insertions(+) create mode 100644 releasenotes/notes/docx-skip-comment-blocks-d3a555d0324788c7.yaml diff --git a/haystack/components/converters/docx.py b/haystack/components/converters/docx.py index b9d59bd564..8f9a58004d 100644 --- a/haystack/components/converters/docx.py +++ b/haystack/components/converters/docx.py @@ -23,6 +23,7 @@ from docx.document import Document as DocxDocument from docx.table import Table from docx.text.paragraph import Paragraph + from lxml.etree import _Comment @dataclass @@ -210,6 +211,8 @@ def _extract_elements(self, document: "DocxDocument") -> List[str]: """ elements = [] for element in document.element.body: + if isinstance(element, _Comment): + continue if element.tag.endswith("p"): paragraph = Paragraph(element, document) if paragraph.contains_page_break: diff --git a/releasenotes/notes/docx-skip-comment-blocks-d3a555d0324788c7.yaml b/releasenotes/notes/docx-skip-comment-blocks-d3a555d0324788c7.yaml new file mode 100644 index 0000000000..e213aa6949 --- /dev/null +++ b/releasenotes/notes/docx-skip-comment-blocks-d3a555d0324788c7.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + The DOCXToDocument component now skips comment blocks in DOCX files that previously caused errors. From 0ac47b00640d6e9c3ac640838c70543005846ca8 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Mon, 27 Jan 2025 11:55:18 +0100 Subject: [PATCH 30/41] pin numba>=0.54.0 (#8773) --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 397ce83583..72b8687c34 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,6 +84,7 @@ format-check = "ruff format --check {args}" [tool.hatch.envs.test] extra-dependencies = [ "numpy>=2", # Haystack is compatible both with numpy 1.x and 2.x, but we test with 2.x + "numba>=0.54.0", # This pin helps uv resolve the dependency tree. See https://github.com/astral-sh/uv/issues/7881 "transformers[torch,sentencepiece]==4.47.1", # ExtractiveReader, TransformersSimilarityRanker, LocalWhisperTranscriber, HFGenerators... "huggingface_hub>=0.27.0", # Hugging Face API Generators and Embedders From 0e6d2a4c39322e08312bbec5d9368bb8284b33db Mon Sep 17 00:00:00 2001 From: Per Lunnemann Hansen Date: Mon, 27 Jan 2025 14:52:24 +0100 Subject: [PATCH 31/41] fix: update component registration to use new class reference (#8715) The pyright language server is now able to resolve the import and provide completions for the component. 
Co-authored-by: Michele Pangrazzi --- haystack/core/component/component.py | 16 +++++++++------- .../component-registration-052467bb409d2e4c.yaml | 4 ++++ 2 files changed, 13 insertions(+), 7 deletions(-) create mode 100644 releasenotes/notes/component-registration-052467bb409d2e4c.yaml diff --git a/haystack/core/component/component.py b/haystack/core/component/component.py index d77fd77593..740b3d154e 100644 --- a/haystack/core/component/component.py +++ b/haystack/core/component/component.py @@ -512,10 +512,12 @@ def copy_class_namespace(namespace): # We must explicitly redefine the type of the class to make sure language servers # and type checkers understand that the class is of the correct type. # mypy doesn't like that we do this though so we explicitly ignore the type check. - cls: cls.__name__ = new_class(cls.__name__, cls.__bases__, {"metaclass": ComponentMeta}, copy_class_namespace) # type: ignore[no-redef] + new_cls: cls.__name__ = new_class( + cls.__name__, cls.__bases__, {"metaclass": ComponentMeta}, copy_class_namespace + ) # type: ignore[no-redef] # Save the component in the class registry (for deserialization) - class_path = f"{cls.__module__}.{cls.__name__}" + class_path = f"{new_cls.__module__}.{new_cls.__name__}" if class_path in self.registry: # Corner case, but it may occur easily in notebooks when re-running cells. logger.debug( @@ -523,15 +525,15 @@ def copy_class_namespace(namespace): new imported from '{new_module_name}'", component=class_path, module_name=self.registry[class_path], - new_module_name=cls, + new_module_name=new_cls, ) - self.registry[class_path] = cls - logger.debug("Registered Component {component}", component=cls) + self.registry[class_path] = new_cls + logger.debug("Registered Component {component}", component=new_cls) # Override the __repr__ method with a default one - cls.__repr__ = _component_repr + new_cls.__repr__ = _component_repr - return cls + return new_cls def __call__(self, cls: Optional[type] = None): # We must wrap the call to the decorator in a function for it to work diff --git a/releasenotes/notes/component-registration-052467bb409d2e4c.yaml b/releasenotes/notes/component-registration-052467bb409d2e4c.yaml new file mode 100644 index 0000000000..a3c575f656 --- /dev/null +++ b/releasenotes/notes/component-registration-052467bb409d2e4c.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Fixes a bug that causes pyright type checker to fail for all component objects. 
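
With this change, the @component decorator returns a class that type checkers can resolve, while the user-facing API stays the same; a minimal sketch:

    from haystack import component

    @component
    class Greeter:
        @component.output_types(greeting=str)
        def run(self, name: str):
            return {"greeting": f"Hello, {name}!"}

    greeter = Greeter()  # pyright now resolves Greeter as a regular class
    print(greeter.run(name="world")["greeting"])  # Hello, world!
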
From e3dc1646255e3c6338bfc4c3cbe025ed280b3046 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Tue, 28 Jan 2025 01:03:23 -0800 Subject: [PATCH 32/41] Update license-header.txt with breaking changes from hawkeye (#8778) --- license-header.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/license-header.txt b/license-header.txt index ec46732299..df5ebe365d 100644 --- a/license-header.txt +++ b/license-header.txt @@ -1,3 +1,3 @@ -SPDX-FileCopyrightText: ${inceptionYear}-present ${copyrightOwner} +SPDX-FileCopyrightText: {{ props["inceptionYear"] }}-present {{ props["copyrightOwner"] }} SPDX-License-Identifier: Apache-2.0 From bba84e551721c1bda779968b3c09247653af2dbd Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Tue, 28 Jan 2025 01:29:55 -0800 Subject: [PATCH 33/41] fix: Fix JSONConverter to properly skip files that are not utf-8 encoded (#8775) * Small fix * Add reno * Trying out license header fix here --- haystack/components/converters/json.py | 1 + ...n-converter-non-utf8-3a755df732a8cbd5.yaml | 4 ++++ test/components/converters/test_json.py | 19 +++++++++++++++++++ 3 files changed, 24 insertions(+) create mode 100644 releasenotes/notes/fix-json-converter-non-utf8-3a755df732a8cbd5.yaml diff --git a/haystack/components/converters/json.py b/haystack/components/converters/json.py index 3a8c6f52f0..6d3781e4e9 100644 --- a/haystack/components/converters/json.py +++ b/haystack/components/converters/json.py @@ -194,6 +194,7 @@ def _get_content_and_meta(self, source: ByteStream) -> List[Tuple[str, Dict[str, source=source.meta["file_path"], error=exc, ) + return [] meta_fields = self._meta_fields or set() diff --git a/releasenotes/notes/fix-json-converter-non-utf8-3a755df732a8cbd5.yaml b/releasenotes/notes/fix-json-converter-non-utf8-3a755df732a8cbd5.yaml new file mode 100644 index 0000000000..2c475d201e --- /dev/null +++ b/releasenotes/notes/fix-json-converter-non-utf8-3a755df732a8cbd5.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Fixed JSONConverter to properly skip converting JSON files that are not utf-8 encoded. diff --git a/test/components/converters/test_json.py b/test/components/converters/test_json.py index f9dcf2fa0c..5419fa812a 100644 --- a/test/components/converters/test_json.py +++ b/test/components/converters/test_json.py @@ -236,6 +236,25 @@ def test_run_with_bad_filter(tmpdir, caplog): assert result == {"documents": []} +def test_run_with_bad_encoding(tmpdir, caplog): + test_file = Path(tmpdir / "test_file.json") + test_file.write_text(json.dumps(test_data[0]), "utf-16") + + sources = [test_file] + converter = JSONConverter(".laureates") + + caplog.clear() + with caplog.at_level(logging.WARNING): + result = converter.run(sources=sources) + + records = caplog.records + assert len(records) == 1 + assert records[0].msg.startswith( + f"Failed to extract text from {test_file}. Skipping it. Error: 'utf-8' codec can't decode byte" + ) + assert result == {"documents": []} + + def test_run_with_single_meta(tmpdir): first_test_file = Path(tmpdir / "first_test_file.json") second_test_file = Path(tmpdir / "second_test_file.json") From d93932150563ddb270e5c531df56c250ca706c97 Mon Sep 17 00:00:00 2001 From: Ulises M <30765968+lbux@users.noreply.github.com> Date: Tue, 28 Jan 2025 03:18:54 -0800 Subject: [PATCH 34/41] fix: compress pipeline graphs before sending to mermaid (#8767) * compress graph data to support pako endpoint * Update haystack/core/pipeline/draw.py Co-authored-by: David S. Batista * Update haystack/core/pipeline/draw.py Co-authored-by: David S. 
Batista --------- Co-authored-by: David S. Batista --- haystack/core/pipeline/draw.py | 13 +++++++++---- .../compress-mermaid-graph-bf23918c3da6e018.yaml | 4 ++++ 2 files changed, 13 insertions(+), 4 deletions(-) create mode 100644 releasenotes/notes/compress-mermaid-graph-bf23918c3da6e018.yaml diff --git a/haystack/core/pipeline/draw.py b/haystack/core/pipeline/draw.py index 2e24bf9acd..b367696d84 100644 --- a/haystack/core/pipeline/draw.py +++ b/haystack/core/pipeline/draw.py @@ -3,6 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 import base64 +import json +import zlib import networkx # type:ignore import requests @@ -68,11 +70,14 @@ def _to_mermaid_image(graph: networkx.MultiDiGraph): """ # Copy the graph to avoid modifying the original graph_styled = _to_mermaid_text(graph.copy()) + json_string = json.dumps({"code": graph_styled}) - graphbytes = graph_styled.encode("ascii") - base64_bytes = base64.b64encode(graphbytes) - base64_string = base64_bytes.decode("ascii") - url = f"https://mermaid.ink/img/{base64_string}?type=png" + # Uses the DEFLATE algorithm at the highest level for smallest size + compressor = zlib.compressobj(level=9) + compressed_data = compressor.compress(json_string.encode("utf-8")) + compressor.flush() + compressed_url_safe_base64 = base64.urlsafe_b64encode(compressed_data).decode("utf-8").strip() + + url = f"https://mermaid.ink/img/pako:{compressed_url_safe_base64}?type=png" logger.debug("Rendering graph at {url}", url=url) try: diff --git a/releasenotes/notes/compress-mermaid-graph-bf23918c3da6e018.yaml b/releasenotes/notes/compress-mermaid-graph-bf23918c3da6e018.yaml new file mode 100644 index 0000000000..2b2d531924 --- /dev/null +++ b/releasenotes/notes/compress-mermaid-graph-bf23918c3da6e018.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Haystack pipelines with Mermaid graphs are now compressed to reduce the size of the encoded base64 and avoid HTTP 400 errors when the graph is too large. From 3ef609a3e865d907379b4f1e03eef692a4e0414e Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Fri, 31 Jan 2025 10:35:15 +0100 Subject: [PATCH 35/41] temporarily pin huggingface_hub<0.28.0 (#8790) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 72b8687c34..1606df3fbf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,7 +87,7 @@ extra-dependencies = [ "numba>=0.54.0", # This pin helps uv resolve the dependency tree. See https://github.com/astral-sh/uv/issues/7881 "transformers[torch,sentencepiece]==4.47.1", # ExtractiveReader, TransformersSimilarityRanker, LocalWhisperTranscriber, HFGenerators... 
- "huggingface_hub>=0.27.0", # Hugging Face API Generators and Embedders + "huggingface_hub>=0.27.0, <0.28.0", # Hugging Face API Generators and Embedders "sentence-transformers>=3.0.0", # SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder "langdetect", # TextLanguageRouter and DocumentLanguageClassifier "openai-whisper>=20231106", # LocalWhisperTranscriber From 80575a7e9c8f7989ce51db69aed5e380f3a0598d Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Fri, 31 Jan 2025 15:03:15 +0100 Subject: [PATCH 36/41] deprecate dataframe and ExtractedTableAnswer (#8789) --- haystack/dataclasses/answer.py | 5 +++++ haystack/dataclasses/document.py | 5 +++++ ...ate-dataframe-ExtractedTableAnswer-3d85649a4a7e222f.yaml | 6 ++++++ 3 files changed, 16 insertions(+) create mode 100644 releasenotes/notes/deprecate-dataframe-ExtractedTableAnswer-3d85649a4a7e222f.yaml diff --git a/haystack/dataclasses/answer.py b/haystack/dataclasses/answer.py index edf0092f55..4ba447f8b4 100644 --- a/haystack/dataclasses/answer.py +++ b/haystack/dataclasses/answer.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import io +import warnings from dataclasses import asdict, dataclass, field from typing import Any, Dict, List, Optional, Protocol, runtime_checkable @@ -98,6 +99,10 @@ class ExtractedTableAnswer: context_cells: List["Cell"] = field(default_factory=list) meta: Dict[str, Any] = field(default_factory=dict) + def __post_init__(self): + msg = "The `ExtractedTableAnswer` dataclass is deprecated and will be removed in Haystack 2.11.0." + warnings.warn(msg, DeprecationWarning) + @dataclass class Cell: row: int diff --git a/haystack/dataclasses/document.py b/haystack/dataclasses/document.py index aed3596411..e29effccd3 100644 --- a/haystack/dataclasses/document.py +++ b/haystack/dataclasses/document.py @@ -4,6 +4,7 @@ import hashlib import io +import warnings from dataclasses import asdict, dataclass, field, fields from typing import Any, Dict, List, Optional @@ -114,6 +115,10 @@ def __post_init__(self): # Generate an id only if not explicitly set self.id = self.id or self._create_id() + if self.dataframe is not None: + msg = "The `dataframe` field is deprecated and will be removed in Haystack 2.11.0." + warnings.warn(msg, DeprecationWarning) + def _create_id(self): """ Creates a hash of the given content that acts as the document's ID. diff --git a/releasenotes/notes/deprecate-dataframe-ExtractedTableAnswer-3d85649a4a7e222f.yaml b/releasenotes/notes/deprecate-dataframe-ExtractedTableAnswer-3d85649a4a7e222f.yaml new file mode 100644 index 0000000000..25318985f1 --- /dev/null +++ b/releasenotes/notes/deprecate-dataframe-ExtractedTableAnswer-3d85649a4a7e222f.yaml @@ -0,0 +1,6 @@ +--- +deprecations: + - | + The `ExtractedTableAnswer` dataclass and the `dataframe` field in the `Document` dataclass are deprecated and + will be removed in Haystack 2.11.0. + Check out the GitHub discussion for motivation and details: https://github.com/deepset-ai/haystack/discussions/8688 From 379711f63ee8e58181194eaf58beea052723a3f4 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Fri, 31 Jan 2025 17:01:00 +0100 Subject: [PATCH 37/41] fix: Pin nltk version for sentence tokenizer (#8786) * Pin nltk version for sentence tokenizer * Update pyproject.toml * Update haystack/components/preprocessors/sentence_tokenizer.py --------- Co-authored-by: David S. 
Batista --- haystack/components/preprocessors/sentence_tokenizer.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/haystack/components/preprocessors/sentence_tokenizer.py b/haystack/components/preprocessors/sentence_tokenizer.py index 9619b851fc..2cb77347d3 100644 --- a/haystack/components/preprocessors/sentence_tokenizer.py +++ b/haystack/components/preprocessors/sentence_tokenizer.py @@ -9,7 +9,7 @@ from haystack import logging from haystack.lazy_imports import LazyImport -with LazyImport("Run 'pip install nltk'") as nltk_imports: +with LazyImport("Run 'pip install nltk>=3.9.1'") as nltk_imports: import nltk logger = logging.getLogger(__name__) diff --git a/pyproject.toml b/pyproject.toml index 1606df3fbf..9a5f15070b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,7 +112,7 @@ extra-dependencies = [ "openpyxl", # XLSXToDocument "tabulate", # XLSXToDocument - "nltk", # NLTKDocumentSplitter + "nltk>=3.9.1", # NLTKDocumentSplitter # OpenAPI "jsonref", # OpenAPIServiceConnector, OpenAPIServiceToFunctions From 1a91365cc8c303d154f5b2774ecb6aa2e100c9cd Mon Sep 17 00:00:00 2001 From: mathislucka Date: Mon, 3 Feb 2025 12:35:37 +0100 Subject: [PATCH 38/41] fix: callables can be deserialized from fully qualified import path (#8788) * fix: callables can be deserialized from fully qualified import path * fix: license header * fix: format * fix: types * fix? types * test: extend test case * format * add release notes --- .../callable_serialization/random_callable.py | 10 ++++++ haystack/utils/callable_serialization.py | 35 +++++++++++-------- ...able-deserialization-5a3ef204a8d07616.yaml | 4 +++ test/utils/test_callable_serialization.py | 9 +++++ 4 files changed, 43 insertions(+), 15 deletions(-) create mode 100644 haystack/testing/callable_serialization/random_callable.py create mode 100644 releasenotes/notes/fix-callable-deserialization-5a3ef204a8d07616.yaml diff --git a/haystack/testing/callable_serialization/random_callable.py b/haystack/testing/callable_serialization/random_callable.py new file mode 100644 index 0000000000..42702ec7e2 --- /dev/null +++ b/haystack/testing/callable_serialization/random_callable.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + + +def callable_to_deserialize(hello: str) -> str: + """ + A function to test callable deserialization. + """ + return f"{hello}, world!" 
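
The helper above gives the test suite a stable function that is addressable by its fully qualified dotted path. As a minimal sketch (not part of the patch; it assumes this branch of Haystack is installed), the intended round trip through the serialization utilities is:

    from haystack.utils import serialize_callable, deserialize_callable
    from haystack.testing.callable_serialization.random_callable import callable_to_deserialize

    # serialize_callable records the fully qualified import path of the function
    path = serialize_callable(callable_to_deserialize)
    # path == "haystack.testing.callable_serialization.random_callable.callable_to_deserialize"

    # the fixed deserialize_callable resolves that path back to the same object
    func = deserialize_callable(path)
    assert func is callable_to_deserialize
    assert func("Hello") == "Hello, world!"
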
diff --git a/haystack/utils/callable_serialization.py b/haystack/utils/callable_serialization.py index 3e4f947e8c..a0f51de048 100644 --- a/haystack/utils/callable_serialization.py +++ b/haystack/utils/callable_serialization.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import inspect -from typing import Callable +from typing import Any, Callable from haystack.core.errors import DeserializationError, SerializationError from haystack.utils.type_serialization import thread_safe_import @@ -50,26 +50,31 @@ def deserialize_callable(callable_handle: str) -> Callable: :return: The callable :raises DeserializationError: If the callable cannot be found """ - module_name, *attribute_chain = callable_handle.split(".") + parts = callable_handle.split(".") - try: - current = thread_safe_import(module_name) - except Exception as e: - raise DeserializationError(f"Could not locate the module: {module_name}") from e - - for attr in attribute_chain: + for i in range(len(parts), 0, -1): + module_name = ".".join(parts[:i]) try: - attr_value = getattr(current, attr) - except AttributeError as e: - raise DeserializationError(f"Could not find attribute '{attr}' in {current.__name__}") from e + mod: Any = thread_safe_import(module_name) + except Exception: + # keep reducing i until we find a valid module import + continue + + attr_value = mod + for part in parts[i:]: + try: + attr_value = getattr(attr_value, part) + except AttributeError as e: + raise DeserializationError(f"Could not find attribute '{part}' in {attr_value.__name__}") from e # when the attribute is a classmethod, we need the underlying function if isinstance(attr_value, (classmethod, staticmethod)): attr_value = attr_value.__func__ - current = attr_value + if not callable(attr_value): + raise DeserializationError(f"The final attribute is not callable: {attr_value}") - if not callable(current): - raise DeserializationError(f"The final attribute is not callable: {current}") + return attr_value - return current + # Fallback if we never find anything + raise DeserializationError(f"Could not import '{callable_handle}' as a module or callable.") diff --git a/releasenotes/notes/fix-callable-deserialization-5a3ef204a8d07616.yaml b/releasenotes/notes/fix-callable-deserialization-5a3ef204a8d07616.yaml new file mode 100644 index 0000000000..0993428fad --- /dev/null +++ b/releasenotes/notes/fix-callable-deserialization-5a3ef204a8d07616.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Callable deserialization now works for all fully qualified import paths. 
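
The resolution strategy introduced above tries the longest importable module prefix first, then falls back to shorter prefixes and walks the remaining dotted parts as attributes. A simplified standalone sketch of the same idea, using importlib directly instead of Haystack's thread_safe_import and omitting the classmethod/staticmethod handling:

    import importlib

    def resolve(handle: str):
        parts = handle.split(".")
        # Try "a.b.c", then "a.b", then "a" as the module prefix.
        for i in range(len(parts), 0, -1):
            try:
                obj = importlib.import_module(".".join(parts[:i]))
            except ImportError:
                continue
            # Walk the leftover parts as attributes of the imported module.
            for part in parts[i:]:
                obj = getattr(obj, part)
            return obj
        raise ImportError(f"Could not import '{handle}'")

    resolve("os.path.join")             # resolves via module "os.path"
    resolve("json.JSONDecoder.decode")  # resolves via module "json", then two attribute lookups
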
diff --git a/test/utils/test_callable_serialization.py b/test/utils/test_callable_serialization.py index 4f75ddd0ad..2176d8c742 100644 --- a/test/utils/test_callable_serialization.py +++ b/test/utils/test_callable_serialization.py @@ -1,10 +1,12 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 + import pytest import requests from haystack.core.errors import DeserializationError, SerializationError from haystack.components.generators.utils import print_streaming_chunk +from haystack.testing.callable_serialization.random_callable import callable_to_deserialize from haystack.utils import serialize_callable, deserialize_callable @@ -40,6 +42,13 @@ def test_callable_serialization_non_local(): assert result == "requests.api.get" +def test_fully_qualified_import_deserialization(): + func = deserialize_callable("haystack.testing.callable_serialization.random_callable.callable_to_deserialize") + + assert func is callable_to_deserialize + assert func("Hello") == "Hello, world!" + + def test_callable_serialization_instance_methods_fail(): with pytest.raises(SerializationError): serialize_callable(TestClass.my_method) From 503d275ade85e04efed787ec421bdc55a5a77abf Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Mon, 3 Feb 2025 12:47:14 +0100 Subject: [PATCH 39/41] chore: remove DocumentSplitter warning related to split_by='sentence' --- haystack/components/preprocessors/document_splitter.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/haystack/components/preprocessors/document_splitter.py b/haystack/components/preprocessors/document_splitter.py index 949f756ae6..96f30b2a55 100644 --- a/haystack/components/preprocessors/document_splitter.py +++ b/haystack/components/preprocessors/document_splitter.py @@ -2,7 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -import warnings from copy import deepcopy from typing import Any, Callable, Dict, List, Literal, Optional, Tuple @@ -112,14 +111,6 @@ def __init__( # pylint: disable=too-many-positional-arguments nltk_imports.check() self.sentence_splitter = None - if split_by == "sentence": - # ToDo: remove this warning in the next major release - msg = ( - "The `split_by='sentence'` no longer splits by '.' and now relies on custom sentence tokenizer " - "based on NLTK. To achieve the previous behaviour use `split_by='period'." - ) - warnings.warn(msg) - def _init_checks( self, *, From f1652121acf1345f157139a67464f8d3b1472688 Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Mon, 3 Feb 2025 15:55:29 +0100 Subject: [PATCH 40/41] feat: Add support for custom (or offline) Mermaid.ink server and support all parameters (#8799) * compress graph data to support pako endpoint * support mermaid.ink parameters and custom servers * dont try to resolve conflicts with the github web ui... 
* avoid double graph copy

* fixing typing, improving docstrings and release notes

* reverting type

* nit - force type checker no cache

* nit - force type checker no cache

---------

Co-authored-by: Ulises M
Co-authored-by: Ulises M <30765968+lbux@users.noreply.github.com>
---
 haystack/core/pipeline/base.py                  |  63 +++++++--
 haystack/core/pipeline/draw.py                  | 133 ++++++++++++++++--
 ...id-server-and-params-b88ca837375c3e0f.yaml   |   5 +
 test/core/pipeline/test_draw.py                 | 112 ++++++++++++++-
 4 files changed, 286 insertions(+), 27 deletions(-)
 create mode 100644 releasenotes/notes/custom-mermaid-server-and-params-b88ca837375c3e0f.yaml

diff --git a/haystack/core/pipeline/base.py b/haystack/core/pipeline/base.py
index 155ba27fe5..8511a9ca38 100644
--- a/haystack/core/pipeline/base.py
+++ b/haystack/core/pipeline/base.py
@@ -34,7 +34,7 @@
 DEFAULT_MARSHALLER = YamlMarshaller()

-# We use a generic type to annotate the return value of classmethods,
+# We use a generic type to annotate the return value of class methods,
 # so that static analyzers won't be confused when derived classes
 # use those methods.
 T = TypeVar("T", bound="PipelineBase")

@@ -619,31 +619,76 @@ def outputs(self, include_components_with_connected_outputs: bool = False) -> Di
         }
         return outputs

-    def show(self) -> None:
+    def show(self, server_url: str = "https://mermaid.ink", params: Optional[dict] = None) -> None:
         """
-        If running in a Jupyter notebook, display an image representing this `Pipeline`.
+        Display an image representing this `Pipeline` in a Jupyter notebook.
+
+        This function generates a diagram of the `Pipeline` using a Mermaid server and displays it directly in
+        the notebook.
+
+        :param server_url:
+            The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink').
+            See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more
+            info on how to set up your own Mermaid server.
+
+        :param params:
+            Dictionary of customization parameters to modify the output. Refer to Mermaid documentation for more details.
+            Supported keys:
+            - format: Output format ('img', 'svg', or 'pdf'). Default: 'img'.
+            - type: Image type for /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'.
+            - theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'.
+            - bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white').
+            - width: Width of the output image (integer).
+            - height: Height of the output image (integer).
+            - scale: Scaling factor (1–3). Only applicable if 'width' or 'height' is specified.
+            - fit: Whether to fit the diagram size to the page (PDF only, boolean).
+            - paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true.
+            - landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true.
+
+        :raises PipelineDrawingError:
+            If the function is called outside of a Jupyter notebook or if there is an issue with rendering.
         """
         if is_in_jupyter():
             from IPython.display import Image, display  # type: ignore

-            image_data = _to_mermaid_image(self.graph)
-
+            image_data = _to_mermaid_image(self.graph, server_url=server_url, params=params)
             display(Image(image_data))
         else:
             msg = "This method is only supported in Jupyter notebooks. Use Pipeline.draw() to save an image locally."
             raise PipelineDrawingError(msg)

-    def draw(self, path: Path) -> None:
+    def draw(self, path: Path, server_url: str = "https://mermaid.ink", params: Optional[dict] = None) -> None:
         """
-        Save an image representing this `Pipeline` to `path`.
+        Save an image representing this `Pipeline` to the specified file path.
+
+        This function generates a diagram of the `Pipeline` using the Mermaid server and saves it to the provided path.

         :param path:
-            The path to save the image to.
+            The file path where the generated image will be saved.
+        :param server_url:
+            The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink').
+            See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more
+            info on how to set up your own Mermaid server.
+        :param params:
+            Dictionary of customization parameters to modify the output. Refer to Mermaid documentation for more details.
+            Supported keys:
+            - format: Output format ('img', 'svg', or 'pdf'). Default: 'img'.
+            - type: Image type for /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'.
+            - theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'.
+            - bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white').
+            - width: Width of the output image (integer).
+            - height: Height of the output image (integer).
+            - scale: Scaling factor (1–3). Only applicable if 'width' or 'height' is specified.
+            - fit: Whether to fit the diagram size to the page (PDF only, boolean).
+            - paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true.
+            - landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true.
+
+        :raises PipelineDrawingError:
+            If there is an issue with rendering or saving the image.
         """
         # Before drawing we edit a bit the graph, to avoid modifying the original that is
         # used for running the pipeline we copy it.
-        image_data = _to_mermaid_image(self.graph)
+        image_data = _to_mermaid_image(self.graph, server_url=server_url, params=params)
         Path(path).write_bytes(image_data)

     def walk(self) -> Iterator[Tuple[str, Component]]:
diff --git a/haystack/core/pipeline/draw.py b/haystack/core/pipeline/draw.py
index 2e24bf9acd..51cec9f6bc 100644
--- a/haystack/core/pipeline/draw.py
+++ b/haystack/core/pipeline/draw.py
@@ -5,6 +5,7 @@
 import base64
 import json
 import zlib
+from typing import Any, Dict, Optional

 import networkx  # type:ignore
 import requests
@@ -54,7 +55,7 @@ def _prepare_for_drawing(graph: networkx.MultiDiGraph) -> networkx.MultiDiGraph:
 ARROWHEAD_MANDATORY = "-->"
 ARROWHEAD_OPTIONAL = ".->"
 MERMAID_STYLED_TEMPLATE = """
-%%{{ init: {{'theme': 'neutral' }} }}%%
+%%{{ init: {params} }}%%

 graph TD;

@@ -64,27 +65,133 @@ def _prepare_for_drawing(graph: networkx.MultiDiGraph) -> networkx.MultiDiGraph:
 """


-def _to_mermaid_image(graph: networkx.MultiDiGraph):
+def _validate_mermaid_params(params: Dict[str, Any]) -> None:
     """
-    Renders a pipeline using Mermaid (hosted version at 'https://mermaid.ink'). Requires Internet access.
+    Validates and sets default values for Mermaid parameters.
+
+    :param params:
+        Dictionary of customization parameters to modify the output. Refer to Mermaid documentation for more details.
+        Supported keys:
+        - format: Output format ('img', 'svg', or 'pdf'). Default: 'img'.
+        - type: Image type for /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'.
+        - theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'.
+        - bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white').
+        - width: Width of the output image (integer).
+        - height: Height of the output image (integer).
+        - scale: Scaling factor (1–3). Only applicable if 'width' or 'height' is specified.
+        - fit: Whether to fit the diagram size to the page (PDF only, boolean).
+        - paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true.
+        - landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true.
+
+    :raises ValueError:
+        If any parameter is invalid or does not match the expected format.
+    """
+    valid_img_types = {"jpeg", "png", "webp"}
+    valid_themes = {"default", "neutral", "dark", "forest"}
+    valid_formats = {"img", "svg", "pdf"}
+
+    params.setdefault("format", "img")
+    params.setdefault("type", "png")
+    params.setdefault("theme", "neutral")
+
+    if params["format"] not in valid_formats:
+        raise ValueError(f"Invalid image format: {params['format']}. Valid options are: {valid_formats}.")
+
+    if params["format"] == "img" and params["type"] not in valid_img_types:
+        raise ValueError(f"Invalid image type: {params['type']}. Valid options are: {valid_img_types}.")
+
+    if params["theme"] not in valid_themes:
+        raise ValueError(f"Invalid theme: {params['theme']}. Valid options are: {valid_themes}.")
+
+    if "width" in params and not isinstance(params["width"], int):
+        raise ValueError("Width must be an integer.")
+    if "height" in params and not isinstance(params["height"], int):
+        raise ValueError("Height must be an integer.")
+
+    if "scale" in params and not 1 <= params["scale"] <= 3:
+        raise ValueError("Scale must be a number between 1 and 3.")
+    if "scale" in params and not ("width" in params or "height" in params):
+        raise ValueError("Scale is only allowed when width or height is set.")
+
+    if "bgColor" in params and not isinstance(params["bgColor"], str):
+        raise ValueError("Background color must be a string.")
+
+    # PDF-specific parameters
+    if params["format"] == "pdf":
+        if "fit" in params and not isinstance(params["fit"], bool):
+            raise ValueError("Fit must be a boolean.")
+        if "paper" in params and not isinstance(params["paper"], str):
+            raise ValueError("Paper size must be a string (e.g., 'a4', 'a3').")
+        if "landscape" in params and not isinstance(params["landscape"], bool):
+            raise ValueError("Landscape must be a boolean.")
+        if "fit" in params and ("paper" in params or "landscape" in params):
+            logger.warning("`fit` overrides `paper` and `landscape` for PDFs. Ignoring `paper` and `landscape`.")
+
+
+def _to_mermaid_image(
+    graph: networkx.MultiDiGraph, server_url: str = "https://mermaid.ink", params: Optional[dict] = None
+) -> bytes:
+    """
+    Renders a pipeline using a Mermaid server.
+
+    :param graph:
+        The graph to render as a Mermaid pipeline.
+    :param server_url:
+        Base URL of the Mermaid server (default: 'https://mermaid.ink').
+    :param params:
+        Dictionary of customization parameters. See `_validate_mermaid_params` for valid keys.
+    :returns:
+        The image, SVG, or PDF data returned by the Mermaid server as bytes.
+    :raises ValueError:
+        If any parameter is invalid or does not match the expected format.
+    :raises PipelineDrawingError:
+        If there is an issue connecting to the Mermaid server or the server returns an error.
     """

+    if params is None:
+        params = {}
+
+    _validate_mermaid_params(params)
+
+    theme = params.get("theme")
+    init_params = json.dumps({"theme": theme})
+
     # Copy the graph to avoid modifying the original
-    graph_styled = _to_mermaid_text(graph.copy())
+    graph_styled = _to_mermaid_text(graph.copy(), init_params)
     json_string = json.dumps({"code": graph_styled})

-    # Uses the DEFLATE algorithm at the highest level for smallest size
-    compressor = zlib.compressobj(level=9)
+    # Compress the JSON string with zlib (RFC 1950)
+    compressor = zlib.compressobj(level=9, wbits=15)
     compressed_data = compressor.compress(json_string.encode("utf-8")) + compressor.flush()
     compressed_url_safe_base64 = base64.urlsafe_b64encode(compressed_data).decode("utf-8").strip()

-    url = f"https://mermaid.ink/img/pako:{compressed_url_safe_base64}?type=png"
+    # Determine the correct endpoint
+    endpoint_format = params.get("format", "img")  # Default to /img endpoint
+    if endpoint_format not in {"img", "svg", "pdf"}:
+        raise ValueError(f"Invalid format: {endpoint_format}. Valid options are 'img', 'svg', or 'pdf'.")
+
+    # Construct the URL without query parameters
+    url = f"{server_url}/{endpoint_format}/pako:{compressed_url_safe_base64}"
+
+    # Add query parameters adhering to mermaid.ink documentation
+    query_params = []
+    for key, value in params.items():
+        if key not in {"theme", "format"}:  # Exclude theme (handled in init_params) and format (endpoint-specific)
+            if value is True:
+                query_params.append(f"{key}")
+            else:
+                query_params.append(f"{key}={value}")
+
+    if query_params:
+        url += "?" + "&".join(query_params)

     logger.debug("Rendering graph at {url}", url=url)
     try:
         resp = requests.get(url, timeout=10)
         if resp.status_code >= 400:
             logger.warning(
-                "Failed to draw the pipeline: https://mermaid.ink/img/ returned status {status_code}",
+                "Failed to draw the pipeline: {server_url} returned status {status_code}",
+                server_url=server_url,
                 status_code=resp.status_code,
             )
             logger.info("Exact URL requested: {url}", url=url)
@@ -93,18 +200,16 @@ def _to_mermaid_image(graph: networkx.MultiDiGraph):

     except Exception as exc:  # pylint: disable=broad-except
         logger.warning(
-            "Failed to draw the pipeline: could not connect to https://mermaid.ink/img/ ({error})", error=exc
+            "Failed to draw the pipeline: could not connect to {server_url} ({error})", server_url=server_url, error=exc
         )
         logger.info("Exact URL requested: {url}", url=url)
         logger.warning("No pipeline diagram will be saved.")
-        raise PipelineDrawingError(
-            "There was an issue with https://mermaid.ink/, see the stacktrace for details."
-        ) from exc
+        raise PipelineDrawingError(f"There was an issue with {server_url}, see the stacktrace for details.") from exc

     return resp.content


-def _to_mermaid_text(graph: networkx.MultiDiGraph) -> str:
+def _to_mermaid_text(graph: networkx.MultiDiGraph, init_params: str) -> str:
     """
     Converts a Networkx graph into Mermaid syntax.
@@ -153,7 +258,7 @@ def _to_mermaid_text(graph: networkx.MultiDiGraph) -> str: ] connections = "\n".join(connections_list + input_connections + output_connections) - graph_styled = MERMAID_STYLED_TEMPLATE.format(connections=connections) + graph_styled = MERMAID_STYLED_TEMPLATE.format(params=init_params, connections=connections) logger.debug("Mermaid diagram:\n{diagram}", diagram=graph_styled) return graph_styled diff --git a/releasenotes/notes/custom-mermaid-server-and-params-b88ca837375c3e0f.yaml b/releasenotes/notes/custom-mermaid-server-and-params-b88ca837375c3e0f.yaml new file mode 100644 index 0000000000..436de3e1b8 --- /dev/null +++ b/releasenotes/notes/custom-mermaid-server-and-params-b88ca837375c3e0f.yaml @@ -0,0 +1,5 @@ +--- + +features: + - | + Drawing pipelines, i.e.: calls to draw() or show(), can now be done using a custom Mermaid server and additional parameters. This allows for more flexibility in how pipelines are rendered. See Mermaid.ink's [documentation](https://github.com/jihchi/mermaid.ink) for more information on how to set up a custom server. diff --git a/test/core/pipeline/test_draw.py b/test/core/pipeline/test_draw.py index f687f6c587..a54bb761bd 100644 --- a/test/core/pipeline/test_draw.py +++ b/test/core/pipeline/test_draw.py @@ -57,7 +57,7 @@ def raise_for_status(self): mock_response.raise_for_status = raise_for_status mock_get.return_value = mock_response - with pytest.raises(PipelineDrawingError, match="There was an issue with https://mermaid.ink/"): + with pytest.raises(PipelineDrawingError, match="There was an issue with https://mermaid.ink"): _to_mermaid_image(pipe.graph) @@ -68,11 +68,12 @@ def test_to_mermaid_text(): pipe.connect("comp1.result", "comp2.value") pipe.connect("comp2.value", "comp1.value") - text = _to_mermaid_text(pipe.graph) + init_params = {"theme": "neutral"} + text = _to_mermaid_text(pipe.graph, init_params) assert ( text == """ -%%{ init: {'theme': 'neutral' } }%% +%%{ init: {'theme': 'neutral'} }%% graph TD; @@ -92,5 +93,108 @@ def test_to_mermaid_text_does_not_edit_graph(): pipe.connect("comp2.value", "comp1.value") expected_pipe = pipe.to_dict() - _to_mermaid_text(pipe.graph) + init_params = {"theme": "neutral"} + _to_mermaid_text(pipe.graph, init_params) assert expected_pipe == pipe.to_dict() + + +@pytest.mark.integration +@pytest.mark.parametrize( + "params", + [ + {"format": "img", "type": "png", "theme": "dark"}, + {"format": "svg", "theme": "forest"}, + {"format": "pdf", "fit": True, "theme": "neutral"}, + ], +) +def test_to_mermaid_image_valid_formats(params): + # Test valid formats + pipe = Pipeline() + pipe.add_component("comp1", Double()) + pipe.add_component("comp2", Double()) + pipe.connect("comp1", "comp2") + + image_data = _to_mermaid_image(pipe.graph, params=params) + assert image_data # Ensure some data is returned + + +def test_to_mermaid_image_invalid_format(): + # Test invalid format + pipe = Pipeline() + pipe.add_component("comp1", Double()) + pipe.add_component("comp2", Double()) + pipe.connect("comp1", "comp2") + + with pytest.raises(ValueError, match="Invalid image format:"): + _to_mermaid_image(pipe.graph, params={"format": "invalid_format"}) + + +@pytest.mark.integration +def test_to_mermaid_image_missing_theme(): + # Test default theme (neutral) + pipe = Pipeline() + pipe.add_component("comp1", Double()) + pipe.add_component("comp2", Double()) + pipe.connect("comp1", "comp2") + + params = {"format": "img"} + image_data = _to_mermaid_image(pipe.graph, params=params) + + assert image_data # Ensure some data is 
returned + + +def test_to_mermaid_image_invalid_scale(): + # Test invalid scale + pipe = Pipeline() + pipe.add_component("comp1", Double()) + pipe.add_component("comp2", Double()) + pipe.connect("comp1", "comp2") + + with pytest.raises(ValueError, match="Scale must be a number between 1 and 3."): + _to_mermaid_image(pipe.graph, params={"format": "img", "scale": 5}) + + +def test_to_mermaid_image_scale_without_dimensions(): + # Test scale without width/height + pipe = Pipeline() + pipe.add_component("comp1", Double()) + pipe.add_component("comp2", Double()) + pipe.connect("comp1", "comp2") + + with pytest.raises(ValueError, match="Scale is only allowed when width or height is set."): + _to_mermaid_image(pipe.graph, params={"format": "img", "scale": 2}) + + +@patch("haystack.core.pipeline.draw.requests.get") +def test_to_mermaid_image_server_error(mock_get): + # Test server failure + pipe = Pipeline() + pipe.add_component("comp1", Double()) + pipe.add_component("comp2", Double()) + pipe.connect("comp1", "comp2") + + def raise_for_status(self): + raise requests.HTTPError() + + mock_response = MagicMock() + mock_response.status_code = 500 + mock_response.content = '{"error": "server error"}' + mock_response.raise_for_status = raise_for_status + mock_get.return_value = mock_response + + with pytest.raises(PipelineDrawingError, match="There was an issue with https://mermaid.ink"): + _to_mermaid_image(pipe.graph) + + +def test_to_mermaid_image_invalid_server_url(): + # Test invalid server URL + pipe = Pipeline() + pipe.add_component("comp1", AddFixedValue(add=3)) + pipe.add_component("comp2", Double()) + pipe.connect("comp1.result", "comp2.value") + pipe.connect("comp2.value", "comp1.value") + + server_url = "https://invalid.server" + + with pytest.raises(PipelineDrawingError, match=f"There was an issue with {server_url}"): + _to_mermaid_image(pipe.graph, server_url=server_url) From 877f826da01730c977f197de8e5cbf5cc7e36538 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Mon, 3 Feb 2025 16:11:16 +0100 Subject: [PATCH 41/41] refactor: HF API Embedders - use `InferenceClient.feature_extraction` instead of `InferenceClient.post` (#8794) * HF API Embedders: refactoring * rename variables * rm leftovers * rm pin * rm unused import * relnote * warning with truncate/normalize and serverless inference API * test that warnings are raised --- .../hugging_face_api_document_embedder.py | 36 ++++++++--- .../hugging_face_api_text_embedder.py | 35 +++++++--- pyproject.toml | 2 +- ...s-feature-extraction-ea0421a8f76052f0.yaml | 5 ++ ...test_hugging_face_api_document_embedder.py | 64 ++++++++++++++----- .../test_hugging_face_api_text_embedder.py | 44 +++++++++---- 6 files changed, 139 insertions(+), 47 deletions(-) create mode 100644 releasenotes/notes/hf-embedders-feature-extraction-ea0421a8f76052f0.yaml diff --git a/haystack/components/embedders/hugging_face_api_document_embedder.py b/haystack/components/embedders/hugging_face_api_document_embedder.py index 459e386976..d3b92fb74c 100644 --- a/haystack/components/embedders/hugging_face_api_document_embedder.py +++ b/haystack/components/embedders/hugging_face_api_document_embedder.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -import json +import warnings from typing import Any, Dict, List, Optional, Union from tqdm import tqdm @@ -96,8 +96,8 @@ def __init__( token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False), prefix: str = "", suffix: str = "", - truncate: bool = True, - normalize: bool = False, + 
truncate: Optional[bool] = True, + normalize: Optional[bool] = False, batch_size: int = 32, progress_bar: bool = True, meta_fields_to_embed: Optional[List[str]] = None, @@ -124,13 +124,11 @@ def __init__( Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS` if the backend uses Text Embeddings Inference. If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored. - It is always set to `True` and cannot be changed. :param normalize: Normalizes the embeddings to unit length. Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS` if the backend uses Text Embeddings Inference. If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored. - It is always set to `False` and cannot be changed. :param batch_size: Number of documents to process at once. :param progress_bar: @@ -239,18 +237,36 @@ def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> List[List[ """ Embed a list of texts in batches. """ + truncate = self.truncate + normalize = self.normalize + + if self.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API: + if truncate is not None: + msg = "`truncate` parameter is not supported for Serverless Inference API. It will be ignored." + warnings.warn(msg) + truncate = None + if normalize is not None: + msg = "`normalize` parameter is not supported for Serverless Inference API. It will be ignored." + warnings.warn(msg) + normalize = None all_embeddings = [] for i in tqdm( range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings" ): batch = texts_to_embed[i : i + batch_size] - response = self._client.post( - json={"inputs": batch, "truncate": self.truncate, "normalize": self.normalize}, - task="feature-extraction", + + np_embeddings = self._client.feature_extraction( + # this method does not officially support list of strings, but works as expected + text=batch, # type: ignore[arg-type] + truncate=truncate, + normalize=normalize, ) - embeddings = json.loads(response.decode()) - all_embeddings.extend(embeddings) + + if np_embeddings.ndim != 2 or np_embeddings.shape[0] != len(batch): + raise ValueError(f"Expected embedding shape ({batch_size}, embedding_dim), got {np_embeddings.shape}") + + all_embeddings.extend(np_embeddings.tolist()) return all_embeddings diff --git a/haystack/components/embedders/hugging_face_api_text_embedder.py b/haystack/components/embedders/hugging_face_api_text_embedder.py index 2cd68d34da..535d3a9430 100644 --- a/haystack/components/embedders/hugging_face_api_text_embedder.py +++ b/haystack/components/embedders/hugging_face_api_text_embedder.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -import json +import warnings from typing import Any, Dict, List, Optional, Union from haystack import component, default_from_dict, default_to_dict, logging @@ -80,8 +80,8 @@ def __init__( token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False), prefix: str = "", suffix: str = "", - truncate: bool = True, - normalize: bool = False, + truncate: Optional[bool] = True, + normalize: Optional[bool] = False, ): # pylint: disable=too-many-positional-arguments """ Creates a HuggingFaceAPITextEmbedder component. @@ -104,13 +104,11 @@ def __init__( Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS` if the backend uses Text Embeddings Inference. If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored. - It is always set to `True` and cannot be changed. 
:param normalize: Normalizes the embeddings to unit length. Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS` if the backend uses Text Embeddings Inference. If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored. - It is always set to `False` and cannot be changed. """ huggingface_hub_import.check() @@ -198,12 +196,29 @@ def run(self, text: str): "In case you want to embed a list of Documents, please use the HuggingFaceAPIDocumentEmbedder." ) + truncate = self.truncate + normalize = self.normalize + + if self.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API: + if truncate is not None: + msg = "`truncate` parameter is not supported for Serverless Inference API. It will be ignored." + warnings.warn(msg) + truncate = None + if normalize is not None: + msg = "`normalize` parameter is not supported for Serverless Inference API. It will be ignored." + warnings.warn(msg) + normalize = None + text_to_embed = self.prefix + text + self.suffix - response = self._client.post( - json={"inputs": [text_to_embed], "truncate": self.truncate, "normalize": self.normalize}, - task="feature-extraction", - ) - embedding = json.loads(response.decode())[0] + np_embedding = self._client.feature_extraction(text=text_to_embed, truncate=truncate, normalize=normalize) + + error_msg = f"Expected embedding shape (1, embedding_dim) or (embedding_dim,), got {np_embedding.shape}" + if np_embedding.ndim > 2: + raise ValueError(error_msg) + if np_embedding.ndim == 2 and np_embedding.shape[0] != 1: + raise ValueError(error_msg) + + embedding = np_embedding.flatten().tolist() return {"embedding": embedding} diff --git a/pyproject.toml b/pyproject.toml index 9a5f15070b..eda943c19a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,7 +87,7 @@ extra-dependencies = [ "numba>=0.54.0", # This pin helps uv resolve the dependency tree. See https://github.com/astral-sh/uv/issues/7881 "transformers[torch,sentencepiece]==4.47.1", # ExtractiveReader, TransformersSimilarityRanker, LocalWhisperTranscriber, HFGenerators... - "huggingface_hub>=0.27.0, <0.28.0", # Hugging Face API Generators and Embedders + "huggingface_hub>=0.27.0", # Hugging Face API Generators and Embedders "sentence-transformers>=3.0.0", # SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder "langdetect", # TextLanguageRouter and DocumentLanguageClassifier "openai-whisper>=20231106", # LocalWhisperTranscriber diff --git a/releasenotes/notes/hf-embedders-feature-extraction-ea0421a8f76052f0.yaml b/releasenotes/notes/hf-embedders-feature-extraction-ea0421a8f76052f0.yaml new file mode 100644 index 0000000000..baf9a890aa --- /dev/null +++ b/releasenotes/notes/hf-embedders-feature-extraction-ea0421a8f76052f0.yaml @@ -0,0 +1,5 @@ +--- +fixes: + - | + In the Hugging Face API embedders, the `InferenceClient.feature_extraction` method is now used instead of + `InferenceClient.post` to compute embeddings. This ensures a more robust and future-proof implementation. 
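
For illustration, the call pattern the embedders now rely on looks roughly as follows; the model name is taken from the tests below and the token is a placeholder, so treat this as a sketch rather than part of the patch:

    from huggingface_hub import InferenceClient

    client = InferenceClient(model="BAAI/bge-small-en-v1.5", token="hf_xxx")  # placeholder token

    # feature_extraction returns a numpy array directly, so the manual
    # json.loads(response.decode()) step needed with InferenceClient.post goes away
    np_embedding = client.feature_extraction(text="The food was delicious")

    # after the shape checks above, the embedders flatten to a plain list of floats
    embedding = np_embedding.flatten().tolist()
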
diff --git a/test/components/embedders/test_hugging_face_api_document_embedder.py b/test/components/embedders/test_hugging_face_api_document_embedder.py index b9332d5363..9d452b02ca 100644 --- a/test/components/embedders/test_hugging_face_api_document_embedder.py +++ b/test/components/embedders/test_hugging_face_api_document_embedder.py @@ -8,6 +8,8 @@ import pytest from huggingface_hub.utils import RepositoryNotFoundError +from numpy import array + from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder from haystack.dataclasses import Document from haystack.utils.auth import Secret @@ -23,8 +25,8 @@ def mock_check_valid_model(): yield mock -def mock_embedding_generation(json, **kwargs): - response = str([[random.random() for _ in range(384)] for _ in range(len(json["inputs"]))]).encode() +def mock_embedding_generation(text, **kwargs): + response = array([[random.random() for _ in range(384)] for _ in range(len(text))]) return response @@ -201,10 +203,10 @@ def test_prepare_texts_to_embed_w_suffix(self, mock_check_valid_model): "my_prefix document number 4 my_suffix", ] - def test_embed_batch(self, mock_check_valid_model): + def test_embed_batch(self, mock_check_valid_model, recwarn): texts = ["text 1", "text 2", "text 3", "text 4", "text 5"] - with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch: + with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch: mock_embedding_patch.side_effect = mock_embedding_generation embedder = HuggingFaceAPIDocumentEmbedder( @@ -223,6 +225,40 @@ def test_embed_batch(self, mock_check_valid_model): assert len(embedding) == 384 assert all(isinstance(x, float) for x in embedding) + # Check that warnings about ignoring truncate and normalize are raised + assert len(recwarn) == 2 + assert "truncate" in str(recwarn[0].message) + assert "normalize" in str(recwarn[1].message) + + def test_embed_batch_wrong_embedding_shape(self, mock_check_valid_model): + texts = ["text 1", "text 2", "text 3", "text 4", "text 5"] + + # embedding ndim != 2 + with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch: + mock_embedding_patch.return_value = array([0.1, 0.2, 0.3]) + + embedder = HuggingFaceAPIDocumentEmbedder( + api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "BAAI/bge-small-en-v1.5"}, + token=Secret.from_token("fake-api-token"), + ) + + with pytest.raises(ValueError): + embedder._embed_batch(texts_to_embed=texts, batch_size=2) + + # embedding ndim == 2 but shape[0] != len(batch) + with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch: + mock_embedding_patch.return_value = array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]) + + embedder = HuggingFaceAPIDocumentEmbedder( + api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "BAAI/bge-small-en-v1.5"}, + token=Secret.from_token("fake-api-token"), + ) + + with pytest.raises(ValueError): + embedder._embed_batch(texts_to_embed=texts, batch_size=2) + def test_run_wrong_input_format(self, mock_check_valid_model): embedder = HuggingFaceAPIDocumentEmbedder( api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"} @@ -252,7 +288,7 @@ def test_run(self, mock_check_valid_model): Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), ] - with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch: + with 
patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch: mock_embedding_patch.side_effect = mock_embedding_generation embedder = HuggingFaceAPIDocumentEmbedder( @@ -268,16 +304,14 @@ def test_run(self, mock_check_valid_model): result = embedder.run(documents=docs) mock_embedding_patch.assert_called_once_with( - json={ - "inputs": [ - "prefix Cuisine | I love cheese suffix", - "prefix ML | A transformer is a deep learning architecture suffix", - ], - "truncate": True, - "normalize": False, - }, - task="feature-extraction", + text=[ + "prefix Cuisine | I love cheese suffix", + "prefix ML | A transformer is a deep learning architecture suffix", + ], + truncate=None, + normalize=None, ) + documents_with_embeddings = result["documents"] assert isinstance(documents_with_embeddings, list) @@ -294,7 +328,7 @@ def test_run_custom_batch_size(self, mock_check_valid_model): Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), ] - with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch: + with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch: mock_embedding_patch.side_effect = mock_embedding_generation embedder = HuggingFaceAPIDocumentEmbedder( diff --git a/test/components/embedders/test_hugging_face_api_text_embedder.py b/test/components/embedders/test_hugging_face_api_text_embedder.py index 6e699fca25..84b2d6e83c 100644 --- a/test/components/embedders/test_hugging_face_api_text_embedder.py +++ b/test/components/embedders/test_hugging_face_api_text_embedder.py @@ -7,7 +7,7 @@ import random import pytest from huggingface_hub.utils import RepositoryNotFoundError - +from numpy import array from haystack.components.embedders import HuggingFaceAPITextEmbedder from haystack.utils.auth import Secret from haystack.utils.hf import HFEmbeddingAPIType @@ -21,11 +21,6 @@ def mock_check_valid_model(): yield mock -def mock_embedding_generation(json, **kwargs): - response = str([[random.random() for _ in range(384)] for _ in range(len(json["inputs"]))]).encode() - return response - - class TestHuggingFaceAPITextEmbedder: def test_init_invalid_api_type(self): with pytest.raises(ValueError): @@ -141,9 +136,9 @@ def test_run_wrong_input_format(self, mock_check_valid_model): with pytest.raises(TypeError): embedder.run(text=list_integers_input) - def test_run(self, mock_check_valid_model): - with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch: - mock_embedding_patch.side_effect = mock_embedding_generation + def test_run(self, mock_check_valid_model, recwarn): + with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch: + mock_embedding_patch.return_value = array([[random.random() for _ in range(384)]]) embedder = HuggingFaceAPITextEmbedder( api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, @@ -156,13 +151,40 @@ def test_run(self, mock_check_valid_model): result = embedder.run(text="The food was delicious") mock_embedding_patch.assert_called_once_with( - json={"inputs": ["prefix The food was delicious suffix"], "truncate": True, "normalize": False}, - task="feature-extraction", + text="prefix The food was delicious suffix", truncate=None, normalize=None ) assert len(result["embedding"]) == 384 assert all(isinstance(x, float) for x in result["embedding"]) + # Check that warnings about ignoring truncate and normalize are raised + assert len(recwarn) == 2 + assert "truncate" in str(recwarn[0].message) + assert "normalize" in str(recwarn[1].message) + + 
def test_run_wrong_embedding_shape(self, mock_check_valid_model): + # embedding ndim > 2 + with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch: + mock_embedding_patch.return_value = array([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]) + + embedder = HuggingFaceAPITextEmbedder( + api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"} + ) + + with pytest.raises(ValueError): + embedder.run(text="The food was delicious") + + # embedding ndim == 2 but shape[0] != 1 + with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch: + mock_embedding_patch.return_value = array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) + + embedder = HuggingFaceAPITextEmbedder( + api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"} + ) + + with pytest.raises(ValueError): + embedder.run(text="The food was delicious") + @pytest.mark.flaky(reruns=5, reruns_delay=5) @pytest.mark.integration @pytest.mark.skipif(