Skip to content

Commit

Permalink
fix: make from dict conditional router more resilient (#8343)
Browse files Browse the repository at this point in the history
* fix: make from dict conditional router more resilient

* refactor: remove

* dos: add release notes

* fix: format
  • Loading branch information
ArzelaAscoIi authored and silvanocerza committed Sep 10, 2024
1 parent 29546bb commit c3796a1
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 2 deletions.
9 changes: 7 additions & 2 deletions haystack/components/routers/conditional_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,13 @@ def from_dict(cls, data: Dict[str, Any]) -> "ConditionalRouter":
for route in routes:
# output_type needs to be deserialized from a string to a type
route["output_type"] = deserialize_type(route["output_type"])
for name, filter_func in init_params.get("custom_filters", {}).items():
init_params["custom_filters"][name] = deserialize_callable(filter_func) if filter_func else None

# Since the custom_filters are typed as optional in the init signature, we catch the
# case where they are not present in the serialized data and set them to an empty dict.
custom_filters = init_params.get("custom_filters", {})
if custom_filters is not None:
for name, filter_func in custom_filters.items():
init_params["custom_filters"][name] = deserialize_callable(filter_func) if filter_func else None
return default_from_dict(cls, data)

def run(self, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
fixes:
- |
The `from_dict` method of `ConditionalRouter` now correctly handles
the case where the `dict` passed to it contains the key `custom_filters` explicitly
set to `None`. Previously this was causing an `AttributeError`
30 changes: 30 additions & 0 deletions test/components/routers/test_conditional_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,36 @@ def test_router_de_serialization(self):
# check that the result is the same and correct
assert result1 == result2 and result1 == {"streams": [1, 2, 3]}

def test_router_de_serialization_with_none_argument(self):
new_router = ConditionalRouter.from_dict(
{
"type": "haystack.components.routers.conditional_router.ConditionalRouter",
"init_parameters": {
"routes": [
{
"condition": "{{streams|length < 2}}",
"output": "{{query}}",
"output_type": "str",
"output_name": "query",
},
{
"condition": "{{streams|length >= 2}}",
"output": "{{streams}}",
"output_type": "typing.List[int]",
"output_name": "streams",
},
],
"custom_filters": None,
"unsafe": False,
},
}
)

# now use both routers with the same input
kwargs = {"streams": [1, 2, 3], "query": "Haystack"}
result2 = new_router.run(**kwargs)
assert result2 == {"streams": [1, 2, 3]}

def test_router_serialization_idempotence(self):
routes = [
{
Expand Down

0 comments on commit c3796a1

Please sign in to comment.