Skip to content

Commit ff6d5ef

Browse files
skrawczelijahbenizzy
authored and committed
Refactors converting output values into strings
Both the driver and the materializer need to convert function names and variables into strings. This is a little messy, but it centralizes the logic in common. I didn't bother with another file because I didn't know what to call it, so putting it under common seems fine. I also added tests to ensure that the new functionality works, and left the existing tests in place to ensure nothing broke.
1 parent 9c6a45b commit ff6d5ef

File tree

5 files changed

+166
-54
lines changed

5 files changed

+166
-54
lines changed

hamilton/common/__init__.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# code in this module should not depend on much
2+
from typing import Any, Callable, List, Optional, Set, Tuple, Union
3+
4+
5+
def convert_output_value(
    output_value: Union[str, Callable, Any], module_set: Set[str]
) -> Tuple[Optional[str], Optional[str]]:
    """Converts a single output value that one can request into a string.

    Strings are returned unchanged, objects exposing a ``name`` attribute contribute that name,
    and named callables contribute their ``__name__`` as long as their ``__module__`` is in the
    passed in module set. Anything else produces an error message instead of a value.

    :param output_value: the value we want to convert into a string. We don't annotate driver.Variable here for
        import reasons.
    :param module_set: the set of module names functions could come from.
    :return: a tuple, (string value, string error). One or the other is returned, never both.
    """
    if isinstance(output_value, str):
        return output_value, None
    if hasattr(output_value, "name"):
        # Duck-typed on purpose: driver.Variable can't be imported here, so we only look for `name`.
        return output_value.name, None
    # Callables without a __name__ (e.g. functools.partial objects or callable instances) cannot be
    # mapped to a function name. They fall through to the generic error below rather than raising
    # AttributeError, which keeps the "(value, error)" contract intact.
    if callable(output_value) and hasattr(output_value, "__name__"):
        module_name = getattr(output_value, "__module__", None)
        if module_name in module_set:
            return output_value.__name__, None
        return None, (
            f"Function {module_name}.{output_value.__name__} is a function not "
            f"in a "
            f"module given to the materializer. Valid choices are {module_set}."
        )
    return None, (
        f"Materializer dependency {output_value} is not a string, a function, or a driver.Variable."
    )
34+
35+
36+
def convert_output_values(
    output_values: List[Union[str, Callable, Any]], module_set: Set[str]
) -> List[str]:
    """Checks & converts output values to strings. This is used in building dependencies for the DAG.

    Every value is attempted, so that all problems are reported together in a single exception
    rather than one at a time.

    :param output_values: the values to convert.
    :param module_set: the modules any functions could come from.
    :return: the final values, in the order the output values were provided.
    :raises ValueError: if there are values that can't be used/converted.
    """
    final_values = []
    errors = []
    for output_value in output_values:
        _val, _error = convert_output_value(output_value, module_set)
        # Decide on the error slot, not on the truthiness of the value: a falsy conversion such as
        # an empty string is still a result and must not be silently dropped from the output.
        if _error is not None:
            errors.append(_error)
        else:
            final_values.append(_val)
    if errors:
        # Sorted so the combined message is stable regardless of input order.
        errors.sort()
        error_str = f"{len(errors)} errors encountered:\n " + "\n ".join(errors)
        raise ValueError(error_str)
    return final_values

hamilton/driver.py

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import pandas as pd
1616

17+
from hamilton import common
1718
from hamilton.execution import executors, graph_functions, grouping, state
1819
from hamilton.io import materialization
1920

@@ -419,31 +420,8 @@ def _create_final_vars(self, final_vars: List[Union[str, Callable, Variable]]) -
419420
:param final_vars:
420421
:return: list of strings in the order that final_vars was provided.
421422
"""
422-
_final_vars = []
423-
errors = []
424-
module_set = {_module.__name__ for _module in self.graph_modules}
425-
for final_var in final_vars:
426-
if isinstance(final_var, str):
427-
_final_vars.append(final_var)
428-
elif isinstance(final_var, Variable):
429-
_final_vars.append(final_var.name)
430-
elif isinstance(final_var, Callable):
431-
if final_var.__module__ in module_set:
432-
_final_vars.append(final_var.__name__)
433-
else:
434-
errors.append(
435-
f"Function {final_var.__module__}.{final_var.__name__} is a function not "
436-
f"in a "
437-
f"module given to the driver. Valid choices are {module_set}."
438-
)
439-
else:
440-
errors.append(
441-
f"Final var {final_var} is not a string, a function, or a driver.Variable."
442-
)
443-
if errors:
444-
errors.sort()
445-
error_str = f"{len(errors)} errors encountered:\n " + "\n ".join(errors)
446-
raise ValueError(error_str)
423+
_module_set = {_module.__name__ for _module in self.graph_modules}
424+
_final_vars = common.convert_output_values(final_vars, _module_set)
447425
return _final_vars
448426

449427
def capture_execute_telemetry(

hamilton/io/materialization.py

Lines changed: 5 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import sys
22
import typing
3-
from types import ModuleType
4-
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
3+
from typing import Any, Dict, List, Optional, Set, Type, Union
54

6-
from hamilton import base, graph, node
5+
from hamilton import base, common, graph, node
76
from hamilton.function_modifiers.adapters import SaveToDecorator
87
from hamilton.function_modifiers.dependencies import SingleDependency, value
98
from hamilton.graph import FunctionGraph
@@ -68,7 +67,7 @@ def __init__(
6867
self.dependencies = dependencies
6968
self.data_saver_kwargs = self._process_kwargs(data_saver_kwargs)
7069

71-
def sanitize_dependencies(self, module_set: Set[ModuleType]) -> "MaterializerFactory":
70+
def sanitize_dependencies(self, module_set: Set[str]) -> "MaterializerFactory":
7271
"""Sanitizes the dependencies to ensure they're strings.
7372
7473
This replaces the internal value for self.dependencies and returns a new object.
@@ -77,32 +76,9 @@ def sanitize_dependencies(self, module_set: Set[ModuleType]) -> "MaterializerFac
7776
:param module_set: modules that "functions" could come from if that's passed in.
7877
:return: new object with sanitized_dependencies.
7978
"""
80-
_final_vars = []
81-
errors = []
82-
for final_var in self.dependencies:
83-
if isinstance(final_var, str):
84-
_final_vars.append(final_var)
85-
elif hasattr(final_var, "name"):
86-
_final_vars.append(final_var.name)
87-
elif isinstance(final_var, Callable):
88-
if final_var.__module__ in module_set:
89-
_final_vars.append(final_var.__name__)
90-
else:
91-
errors.append(
92-
f"Function {final_var.__module__}.{final_var.__name__} is a function not "
93-
f"in a "
94-
f"module given to the materializer. Valid choices are {module_set}."
95-
)
96-
else:
97-
errors.append(
98-
f"Materializer dependency {final_var} is not a string, a function, or a driver.Variable."
99-
)
100-
if errors:
101-
errors.sort()
102-
error_str = f"{len(errors)} errors encountered:\n " + "\n ".join(errors)
103-
raise ValueError(error_str)
79+
final_vars = common.convert_output_values(self.dependencies, module_set)
10480
return MaterializerFactory(
105-
self.id, self.savers, self.result_builder, _final_vars, **self.data_saver_kwargs
81+
self.id, self.savers, self.result_builder, final_vars, **self.data_saver_kwargs
10682
)
10783

10884
@staticmethod

tests/io/test_materialization.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import dataclasses
22
from typing import Any, Collection, Dict, List, Optional, Type
33

4+
import pytest
5+
6+
import tests.resources.cyclic_functions
7+
import tests.resources.test_default_args
48
from hamilton import base, graph, node
59
from hamilton.io import materialization
610
from hamilton.io.data_adapters import DataSaver
@@ -152,3 +156,39 @@ def second_node() -> dict:
152156
assert "materializer_2" in fn_graph_modified.nodes
153157
assert "first_node" in fn_graph_modified.nodes
154158
assert "second_node" in fn_graph_modified.nodes
159+
160+
161+
def test_sanitize_materializer_dependencies_happy():
    """Sanitizing returns a new factory whose function dependencies were converted to name strings."""
    resource_module = tests.resources.test_default_args
    original = MaterializerFactory(
        "materializer_1",
        [MockDataSaver],
        dependencies=[resource_module.A, resource_module.B, "C"],
        result_builder=JoinBuilder(),
        storage_key="test_modify_function_graph_2",
    )
    sanitized = original.sanitize_dependencies({resource_module.__name__})
    # A distinct object must come back; everything but the dependencies is carried over.
    assert sanitized is not original
    assert sanitized.id == original.id
    assert sanitized.savers == original.savers
    assert sanitized.result_builder == original.result_builder
    # Functions are replaced by their names, strings pass through, and order is preserved.
    assert sanitized.dependencies == ["A", "B", "C"]
181+
182+
183+
def test_sanitize_materializer_dependencies_error():
    """Sanitizing raises ValueError when a function dependency is not from an allowed module."""
    # cyclic_functions.A lives in a module that is deliberately left out of the allowed set below.
    factory = MaterializerFactory(
        "materializer_1",
        [MockDataSaver],
        dependencies=["B", tests.resources.cyclic_functions.A],
        result_builder=JoinBuilder(),
        storage_key="test_modify_function_graph_2",
    )
    allowed_modules = {tests.resources.test_default_args.__name__}
    with pytest.raises(ValueError):
        factory.sanitize_dependencies(allowed_modules)

tests/test_common.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import pytest
2+
3+
import tests.resources.cyclic_functions
4+
import tests.resources.test_default_args
5+
from hamilton import common, driver
6+
7+
8+
class Object:
    """Stand-in type with no ``name`` attribute that is neither a string nor a callable."""

    def __repr__(self) -> str:
        # Fixed repr so the error message asserted in the parametrized test is deterministic.
        return "'object'"
13+
14+
15+
@pytest.mark.parametrize(
    "value_to_convert, module_set, expected_value, expected_error",
    [
        # plain string passes through untouched
        ("a", {"amodule"}, "a", None),
        # function from an allowed module becomes its name
        (
            tests.resources.test_default_args.A,
            {tests.resources.test_default_args.__name__},
            "A",
            None,
        ),
        # driver.Variable contributes its name
        (driver.Variable("A", int), {"amodule"}, "A", None),
        # arbitrary object is rejected
        (
            Object(),
            {"amodule"},
            None,
            "Materializer dependency 'object' is not a string, a function, or a driver.Variable.",
        ),
        # function from a module outside the allowed set is rejected
        (
            tests.resources.cyclic_functions.A,
            {tests.resources.test_default_args.__name__},
            None,
            "Function tests.resources.cyclic_functions.A is a function not in a module given to the materializer. Valid choices are {'tests.resources.test_default_args'}.",
        ),
    ],
)
def test_convert_output_value(value_to_convert, module_set, expected_value, expected_error):
    """Checks the (value, error) pair produced for each supported and unsupported input kind."""
    converted, problem = common.convert_output_value(value_to_convert, module_set)
    assert (converted, problem) == (expected_value, expected_error)
44+
45+
46+
def test_convert_output_values_happy():
    """A function and a string both convert, and come back as names in the order given."""
    allowed_modules = {tests.resources.test_default_args.__name__}
    requested = [tests.resources.test_default_args.A, "B"]
    converted = common.convert_output_values(requested, allowed_modules)
    assert converted == ["A", "B"]
52+
53+
54+
def test_convert_output_values_error():
    """A function from a module outside the allowed set makes the batch conversion raise."""
    allowed_modules = {tests.resources.test_default_args.__name__}
    # The second entry comes from cyclic_functions, which is not in the allowed set.
    requested = [tests.resources.test_default_args.A, tests.resources.cyclic_functions.A]
    with pytest.raises(ValueError):
        common.convert_output_values(requested, allowed_modules)

0 commit comments

Comments
 (0)