diff --git a/haystack/utils/callable_serialization.py b/haystack/utils/callable_serialization.py index 3c6003135e..3e4f947e8c 100644 --- a/haystack/utils/callable_serialization.py +++ b/haystack/utils/callable_serialization.py @@ -3,9 +3,9 @@ # SPDX-License-Identifier: Apache-2.0 import inspect -from typing import Callable, Optional +from typing import Callable -from haystack import DeserializationError +from haystack.core.errors import DeserializationError, SerializationError from haystack.utils.type_serialization import thread_safe_import @@ -16,17 +16,33 @@ def serialize_callable(callable_handle: Callable) -> str: :param callable_handle: The callable to serialize :return: The full path of the callable """ - module = inspect.getmodule(callable_handle) + try: + full_arg_spec = inspect.getfullargspec(callable_handle) + is_instance_method = bool(full_arg_spec.args and full_arg_spec.args[0] == "self") + except TypeError: + is_instance_method = False + if is_instance_method: + raise SerializationError("Serialization of instance methods is not supported.") + + # __qualname__ contains the fully qualified path we need for classmethods and staticmethods + qualname = getattr(callable_handle, "__qualname__", "") + if "" in qualname: + raise SerializationError("Serialization of lambdas is not supported.") + if "" in qualname: + raise SerializationError("Serialization of nested functions is not supported.") + + name = qualname or callable_handle.__name__ # Get the full package path of the function + module = inspect.getmodule(callable_handle) if module is not None: - full_path = f"{module.__name__}.{callable_handle.__name__}" + full_path = f"{module.__name__}.{name}" else: - full_path = callable_handle.__name__ + full_path = name return full_path -def deserialize_callable(callable_handle: str) -> Optional[Callable]: +def deserialize_callable(callable_handle: str) -> Callable: """ Deserializes a callable given its full import path as a string. @@ -34,14 +50,26 @@ def deserialize_callable(callable_handle: str) -> Optional[Callable]: :return: The callable :raises DeserializationError: If the callable cannot be found """ - parts = callable_handle.split(".") - module_name = ".".join(parts[:-1]) - function_name = parts[-1] + module_name, *attribute_chain = callable_handle.split(".") + try: - module = thread_safe_import(module_name) + current = thread_safe_import(module_name) except Exception as e: - raise DeserializationError(f"Could not locate the module of the callable: {module_name}") from e - deserialized_callable = getattr(module, function_name, None) - if not deserialized_callable: - raise DeserializationError(f"Could not locate the callable: {function_name}") - return deserialized_callable + raise DeserializationError(f"Could not locate the module: {module_name}") from e + + for attr in attribute_chain: + try: + attr_value = getattr(current, attr) + except AttributeError as e: + raise DeserializationError(f"Could not find attribute '{attr}' in {current.__name__}") from e + + # when the attribute is a classmethod, we need the underlying function + if isinstance(attr_value, (classmethod, staticmethod)): + attr_value = attr_value.__func__ + + current = attr_value + + if not callable(current): + raise DeserializationError(f"The final attribute is not callable: {current}") + + return current diff --git a/releasenotes/notes/improve-callables-serde-6aa1e23408063247.yaml b/releasenotes/notes/improve-callables-serde-6aa1e23408063247.yaml new file mode 100644 index 0000000000..54b0783e3d --- /dev/null +++ b/releasenotes/notes/improve-callables-serde-6aa1e23408063247.yaml @@ -0,0 +1,6 @@ +--- +enhancements: + - | + Improve serialization and deserialization of callables. + We now allow serialization of classmethods and staticmethods + and explicitly prohibit serialization of instance methods, lambdas, and nested functions. diff --git a/test/components/generators/chat/test_hugging_face_local.py b/test/components/generators/chat/test_hugging_face_local.py index 9b01acb134..c953404912 100644 --- a/test/components/generators/chat/test_hugging_face_local.py +++ b/test/components/generators/chat/test_hugging_face_local.py @@ -142,7 +142,7 @@ def test_to_dict(self, model_info_mock): token=Secret.from_env_var("ENV_VAR", strict=False), generation_kwargs={"n": 5}, stop_words=["stop", "words"], - streaming_callback=lambda x: x, + streaming_callback=streaming_callback_handler, chat_template="irrelevant", ) @@ -155,6 +155,7 @@ def test_to_dict(self, model_info_mock): assert init_params["huggingface_pipeline_kwargs"]["model"] == "NousResearch/Llama-2-7b-chat-hf" assert "token" not in init_params["huggingface_pipeline_kwargs"] assert init_params["generation_kwargs"] == {"max_new_tokens": 512, "n": 5, "stop_sequences": ["stop", "words"]} + assert init_params["streaming_callback"] == "chat.test_hugging_face_local.streaming_callback_handler" assert init_params["chat_template"] == "irrelevant" def test_from_dict(self, model_info_mock): diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py index 243eb36c89..677dfa812b 100644 --- a/test/components/generators/chat/test_openai.py +++ b/test/components/generators/chat/test_openai.py @@ -1,13 +1,12 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest -from typing import Iterator + import logging import os -import json from datetime import datetime from openai import OpenAIError @@ -15,12 +14,11 @@ from openai.types.chat.chat_completion import Choice from openai.types.chat.chat_completion_message_tool_call import Function from openai.types.chat import chat_completion_chunk -from openai import Stream from haystack.components.generators.utils import print_streaming_chunk from haystack.dataclasses import StreamingChunk from haystack.utils.auth import Secret -from haystack.dataclasses import ChatMessage, Tool, ToolCall, ChatRole, TextContent +from haystack.dataclasses import ChatMessage, Tool, ToolCall from haystack.components.generators.chat.openai import OpenAIChatGenerator @@ -212,31 +210,6 @@ def test_to_dict_with_parameters(self, monkeypatch): }, } - def test_to_dict_with_lambda_streaming_callback(self, monkeypatch): - monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") - component = OpenAIChatGenerator( - model="gpt-4o-mini", - streaming_callback=lambda x: x, - api_base_url="test-base-url", - generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, - ) - data = component.to_dict() - assert data == { - "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator", - "init_parameters": { - "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"}, - "model": "gpt-4o-mini", - "organization": None, - "api_base_url": "test-base-url", - "max_retries": None, - "timeout": None, - "streaming_callback": "chat.test_openai.", - "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, - "tools": None, - "tools_strict": False, - }, - } - def test_from_dict(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key") data = { diff --git a/test/components/generators/test_openai.py b/test/components/generators/test_openai.py index 32628f7c45..e1d865c95f 100644 --- a/test/components/generators/test_openai.py +++ b/test/components/generators/test_openai.py @@ -90,28 +90,6 @@ def test_to_dict_with_parameters(self, monkeypatch): }, } - def test_to_dict_with_lambda_streaming_callback(self, monkeypatch): - monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") - component = OpenAIGenerator( - model="gpt-4o-mini", - streaming_callback=lambda x: x, - api_base_url="test-base-url", - generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, - ) - data = component.to_dict() - assert data == { - "type": "haystack.components.generators.openai.OpenAIGenerator", - "init_parameters": { - "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"}, - "model": "gpt-4o-mini", - "system_prompt": None, - "organization": None, - "api_base_url": "test-base-url", - "streaming_callback": "test_openai.", - "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, - }, - } - def test_from_dict(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key") data = { diff --git a/test/components/preprocessors/test_document_splitter.py b/test/components/preprocessors/test_document_splitter.py index 094c17eeea..f9096239f2 100644 --- a/test/components/preprocessors/test_document_splitter.py +++ b/test/components/preprocessors/test_document_splitter.py @@ -467,9 +467,6 @@ def test_from_dict_with_splitting_function(self): Test the from_dict class method of the DocumentSplitter class when a custom splitting function is provided. """ - def custom_split(text): - return text.split(".") - data = { "type": "haystack.components.preprocessors.document_splitter.DocumentSplitter", "init_parameters": {"split_by": "function", "splitting_function": serialize_callable(custom_split)}, diff --git a/test/utils/test_callable_serialization.py b/test/utils/test_callable_serialization.py index 941aa14cdf..4f75ddd0ad 100644 --- a/test/utils/test_callable_serialization.py +++ b/test/utils/test_callable_serialization.py @@ -3,8 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import pytest import requests - -from haystack import DeserializationError +from haystack.core.errors import DeserializationError, SerializationError from haystack.components.generators.utils import print_streaming_chunk from haystack.utils import serialize_callable, deserialize_callable @@ -13,6 +12,19 @@ def some_random_callable_for_testing(some_ignored_arg: str): pass +class TestClass: + @classmethod + def class_method(cls): + pass + + @staticmethod + def static_method(): + pass + + def my_method(self): + pass + + def test_callable_serialization(): result = serialize_callable(some_random_callable_for_testing) assert result == "test_callable_serialization.some_random_callable_for_testing" @@ -28,6 +40,28 @@ def test_callable_serialization_non_local(): assert result == "requests.api.get" +def test_callable_serialization_instance_methods_fail(): + with pytest.raises(SerializationError): + serialize_callable(TestClass.my_method) + + instance = TestClass() + with pytest.raises(SerializationError): + serialize_callable(instance.my_method) + + +def test_lambda_serialization_fail(): + with pytest.raises(SerializationError): + serialize_callable(lambda x: x) + + +def test_nested_function_serialization_fail(): + def my_fun(): + pass + + with pytest.raises(SerializationError): + serialize_callable(my_fun) + + def test_callable_deserialization(): result = serialize_callable(some_random_callable_for_testing) fn = deserialize_callable(result) @@ -40,8 +74,27 @@ def test_callable_deserialization_non_local(): assert fn is requests.api.get -def test_callable_deserialization_error(): +def test_classmethod_serialization_deserialization(): + result = serialize_callable(TestClass.class_method) + fn = deserialize_callable(result) + assert fn == TestClass.class_method + + +def test_staticmethod_serialization_deserialization(): + result = serialize_callable(TestClass.static_method) + fn = deserialize_callable(result) + assert fn == TestClass.static_method + + +def test_callable_deserialization_errors(): + # module does not exist with pytest.raises(DeserializationError): - deserialize_callable("this.is.not.a.valid.module") + deserialize_callable("nonexistent_module.function") + + # function does not exist + with pytest.raises(DeserializationError): + deserialize_callable("os.nonexistent_function") + + # attribute is not callable with pytest.raises(DeserializationError): - deserialize_callable("sys.foobar") + deserialize_callable("os.name")