Skip to content

Commit

Permalink
fix: callables can be deserialized from fully qualified import path (#…
Browse files Browse the repository at this point in the history
…8788)

* fix: callables can be deserialized from fully qualified import path

* fix: license header

* fix: format

* fix: types

* fix? types

* test: extend test case

* format

* add release notes
  • Loading branch information
mathislucka authored Feb 3, 2025
1 parent 379711f commit 1a91365
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 15 deletions.
10 changes: 10 additions & 0 deletions haystack/testing/callable_serialization/random_callable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0


def callable_to_deserialize(hello: str) -> str:
"""
A function to test callable deserialization.
"""
return f"{hello}, world!"
35 changes: 20 additions & 15 deletions haystack/utils/callable_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

import inspect
from typing import Callable
from typing import Any, Callable

from haystack.core.errors import DeserializationError, SerializationError
from haystack.utils.type_serialization import thread_safe_import
Expand Down Expand Up @@ -50,26 +50,31 @@ def deserialize_callable(callable_handle: str) -> Callable:
:return: The callable
:raises DeserializationError: If the callable cannot be found
"""
module_name, *attribute_chain = callable_handle.split(".")
parts = callable_handle.split(".")

try:
current = thread_safe_import(module_name)
except Exception as e:
raise DeserializationError(f"Could not locate the module: {module_name}") from e

for attr in attribute_chain:
for i in range(len(parts), 0, -1):
module_name = ".".join(parts[:i])
try:
attr_value = getattr(current, attr)
except AttributeError as e:
raise DeserializationError(f"Could not find attribute '{attr}' in {current.__name__}") from e
mod: Any = thread_safe_import(module_name)
except Exception:
# keep reducing i until we find a valid module import
continue

attr_value = mod
for part in parts[i:]:
try:
attr_value = getattr(attr_value, part)
except AttributeError as e:
raise DeserializationError(f"Could not find attribute '{part}' in {attr_value.__name__}") from e

# when the attribute is a classmethod, we need the underlying function
if isinstance(attr_value, (classmethod, staticmethod)):
attr_value = attr_value.__func__

current = attr_value
if not callable(attr_value):
raise DeserializationError(f"The final attribute is not callable: {attr_value}")

if not callable(current):
raise DeserializationError(f"The final attribute is not callable: {current}")
return attr_value

return current
# Fallback if we never find anything
raise DeserializationError(f"Could not import '{callable_handle}' as a module or callable.")
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
fixes:
- |
Callable deserialization now works for all fully qualified import paths.
9 changes: 9 additions & 0 deletions test/utils/test_callable_serialization.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

import pytest
import requests
from haystack.core.errors import DeserializationError, SerializationError
from haystack.components.generators.utils import print_streaming_chunk
from haystack.testing.callable_serialization.random_callable import callable_to_deserialize
from haystack.utils import serialize_callable, deserialize_callable


Expand Down Expand Up @@ -40,6 +42,13 @@ def test_callable_serialization_non_local():
assert result == "requests.api.get"


def test_fully_qualified_import_deserialization():
func = deserialize_callable("haystack.testing.callable_serialization.random_callable.callable_to_deserialize")

assert func is callable_to_deserialize
assert func("Hello") == "Hello, world!"


def test_callable_serialization_instance_methods_fail():
with pytest.raises(SerializationError):
serialize_callable(TestClass.my_method)
Expand Down

0 comments on commit 1a91365

Please sign in to comment.