Skip to content

Commit cbadd60

Browse files
authored
Use typing.get_type_hints for input/output and resource/config type inference (#14707)
## Summary & Motivation Fixes #14571 Currently, `Parameter.annotation` is used when inferring the dagster type from an op compute function signature. This causes an error if either: (a) a string is used directly by the user for the parameter annotation; (b) `from __future__ import annotations` is used. This is because both cases cause the value of `Parameter.annotation` to be a string, and our dagster type resolution logic does not handle strings. This PR replaces direct access to `Parameter.annotation` with `typing.get_type_hints`, which handles evaluation of a string annotation into an actual class which can be handled by our dagster type resolution logic. Note that this only works if the referents of a string annotation is module-level scope-- it will fail for a local scope. Still this is an upgrade and will fix the problem in most cases. It will: - Allows users to use string annotations or `from __future__ import annotations` together with dagster type signature inference - Allows us to use `from __future__ import annotations` in our own code, which is necessary to allow Ruff to auto-sort typing-only imports-- see #14675. Not clear whether this is a good idea. ## How I Tested These Changes New unit test using string annotations.
1 parent 0c37101 commit cbadd60

File tree

7 files changed

+135
-31
lines changed

7 files changed

+135
-31
lines changed

python_modules/dagster/dagster/_core/decorator_utils.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,26 @@
1+
import functools
2+
import re
13
import textwrap
24
from inspect import Parameter, signature
3-
from typing import Any, Callable, Optional, Sequence, Set, TypeVar, Union
4-
5-
from typing_extensions import Concatenate, ParamSpec, TypeGuard
5+
from typing import (
6+
Any,
7+
Callable,
8+
Mapping,
9+
Optional,
10+
Sequence,
11+
Set,
12+
TypeVar,
13+
Union,
14+
)
15+
16+
from typing_extensions import (
17+
Concatenate,
18+
ParamSpec,
19+
TypeGuard,
20+
get_type_hints as typing_get_type_hints,
21+
)
22+
23+
from dagster._core.errors import DagsterInvalidDefinitionError
624

725
R = TypeVar("R")
826
T = TypeVar("T")
@@ -36,6 +54,26 @@ def get_function_params(fn: Callable[..., Any]) -> Sequence[Parameter]:
3654
return list(signature(fn).parameters.values())
3755

3856

57+
def get_type_hints(fn: Callable) -> Mapping[str, Any]:
58+
target = fn.func if isinstance(fn, functools.partial) else fn
59+
60+
try:
61+
return typing_get_type_hints(target, include_extras=True)
62+
except NameError as e:
63+
match = re.search(r"'(\w+)'", str(e))
64+
assert match
65+
annotation = match[1]
66+
raise DagsterInvalidDefinitionError(
67+
f'Failed to resolve type annotation "{annotation}" in function {target.__name__}. This'
68+
" can occur when the parameter has a string annotation that references either: (1) a"
69+
" type defined in a local scope (2) a type that is defined or imported in an `if"
70+
" TYPE_CHECKING` block. Note that if you are including `from __future__ import"
71+
" annotations`, all annotations in that module are stored as strings. Suggested"
72+
" solutions include: (1) convert the annotation to a non-string annotation; (2) move"
73+
" the type referenced by the annotation out of local scope or a `TYPE_CHECKING` block."
74+
)
75+
76+
3977
def validate_expected_params(
4078
params: Sequence[Parameter], expected_params: Sequence[str]
4179
) -> Optional[str]:

python_modules/dagster/dagster/_core/definitions/decorators/sensor_decorator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def _wrapped_fn(*args, **kwargs) -> Any:
187187

188188
# Preserve any resource arguments from the underlying function, for when we inspect the
189189
# wrapped function later on
190-
_wrapped_fn.__signature__ = inspect.signature(fn)
190+
_wrapped_fn = update_wrapper(_wrapped_fn, wrapped=fn)
191191

192192
return AssetSensorDefinition(
193193
name=sensor_name,

python_modules/dagster/dagster/_core/definitions/inference.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
from inspect import Parameter, Signature, isgeneratorfunction, signature
2-
from typing import Any, Callable, Mapping, NamedTuple, Optional, Sequence
3-
2+
from typing import (
3+
Any,
4+
Callable,
5+
Mapping,
6+
NamedTuple,
7+
Optional,
8+
Sequence,
9+
)
10+
11+
from dagster._core.decorator_utils import get_type_hints
412
from dagster._seven import is_module_available
513

614
from .utils import NoValueSentinel
@@ -55,11 +63,12 @@ def _infer_output_description_from_docstring(fn: Callable) -> Optional[str]:
5563

5664

5765
def infer_output_props(fn: Callable) -> InferredOutputProps:
58-
sig = signature(fn)
59-
60-
annotation = Parameter.empty
61-
if not isgeneratorfunction(fn):
62-
annotation = sig.return_annotation
66+
type_hints = get_type_hints(fn)
67+
annotation = (
68+
type_hints["return"]
69+
if not isgeneratorfunction(fn) and "return" in type_hints
70+
else Parameter.empty
71+
)
6372

6473
return InferredOutputProps(
6574
annotation=annotation,
@@ -74,6 +83,7 @@ def has_explicit_return_type(fn: Callable) -> bool:
7483

7584
def _infer_inputs_from_params(
7685
params: Sequence[Parameter],
86+
type_hints: Mapping[str, object],
7787
descriptions: Optional[Mapping[str, Optional[str]]] = None,
7888
) -> Sequence[InferredInputProps]:
7989
_descriptions: Mapping[str, Optional[str]] = descriptions or {}
@@ -82,14 +92,14 @@ def _infer_inputs_from_params(
8292
if param.default is not Parameter.empty:
8393
input_def = InferredInputProps(
8494
param.name,
85-
param.annotation,
95+
type_hints.get(param.name, param.annotation),
8696
default_value=param.default,
8797
description=_descriptions.get(param.name),
8898
)
8999
else:
90100
input_def = InferredInputProps(
91101
param.name,
92-
param.annotation,
102+
type_hints.get(param.name, param.annotation),
93103
description=_descriptions.get(param.name),
94104
)
95105

@@ -101,7 +111,8 @@ def _infer_inputs_from_params(
101111
def infer_input_props(fn: Callable, context_arg_provided: bool) -> Sequence[InferredInputProps]:
102112
sig = signature(fn)
103113
params = list(sig.parameters.values())
114+
type_hints = get_type_hints(fn)
104115
descriptions = _infer_input_description_from_docstring(fn)
105116
params_to_infer = params[1:] if context_arg_provided else params
106-
defs = _infer_inputs_from_params(params_to_infer, descriptions=descriptions)
117+
defs = _infer_inputs_from_params(params_to_infer, type_hints, descriptions=descriptions)
107118
return defs

python_modules/dagster/dagster/_core/definitions/resource_annotation.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,31 @@
11
from inspect import Parameter
2-
from typing import Sequence, TypeVar
2+
from typing import Any, Optional, Sequence, Type, TypeVar
33

44
from typing_extensions import Annotated
55

6-
from dagster._core.decorator_utils import get_function_params
6+
from dagster._core.decorator_utils import get_function_params, get_type_hints
77
from dagster._core.definitions.resource_definition import ResourceDefinition
88

99

1010
def get_resource_args(fn) -> Sequence[Parameter]:
11-
return [param for param in get_function_params(fn) if _is_resource_annotated(param)]
11+
type_annotations = get_type_hints(fn)
12+
return [
13+
param
14+
for param in get_function_params(fn)
15+
if _is_resource_annotation(type_annotations.get(param.name))
16+
]
1217

1318

1419
RESOURCE_PARAM_METADATA = "resource_param"
1520

1621

17-
def _is_resource_annotated(param: Parameter) -> bool:
22+
def _is_resource_annotation(annotation: Optional[Type[Any]]) -> bool:
1823
from dagster._config.pythonic_config import ConfigurableResourceFactory
1924

2025
extends_resource_definition = False
2126
try:
22-
extends_resource_definition = isinstance(param.annotation, type) and issubclass(
23-
param.annotation, (ResourceDefinition, ConfigurableResourceFactory)
27+
extends_resource_definition = isinstance(annotation, type) and issubclass(
28+
annotation, (ResourceDefinition, ConfigurableResourceFactory)
2429
)
2530
except TypeError:
2631
# Using builtin Python types in python 3.9+ will raise a TypeError when using issubclass
@@ -29,8 +34,8 @@ def _is_resource_annotated(param: Parameter) -> bool:
2934
pass
3035

3136
return (extends_resource_definition) or (
32-
hasattr(param.annotation, "__metadata__")
33-
and getattr(param.annotation, "__metadata__") == (RESOURCE_PARAM_METADATA,)
37+
hasattr(annotation, "__metadata__")
38+
and getattr(annotation, "__metadata__") == (RESOURCE_PARAM_METADATA,)
3439
)
3540

3641

python_modules/dagster/dagster/_core/definitions/source_asset.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
1-
from __future__ import annotations
2-
31
import warnings
42
from typing import (
5-
TYPE_CHECKING,
63
AbstractSet,
74
Any,
85
Callable,
@@ -54,12 +51,6 @@
5451
from dagster._utils.backcompat import ExperimentalWarning, experimental_arg_warning
5552
from dagster._utils.merger import merge_dicts
5653

57-
if TYPE_CHECKING:
58-
from dagster._core.execution.context.compute import (
59-
OpExecutionContext,
60-
)
61-
62-
6354
# Going with this catch-all for the time-being to permit pythonic resources
6455
SourceAssetObserveFunction: TypeAlias = Callable[..., Any]
6556

@@ -193,6 +184,9 @@ def _get_op_def_compute_fn(self, observe_fn: SourceAssetObserveFunction):
193184
DecoratedOpFunction,
194185
is_context_provided,
195186
)
187+
from dagster._core.execution.context.compute import (
188+
OpExecutionContext,
189+
)
196190

197191
observe_fn_has_context = is_context_provided(get_function_params(observe_fn))
198192

python_modules/dagster/dagster_tests/core_tests/resource_tests/pythonic_resources/test_binding.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,3 +585,28 @@ def my_asset(context, my_resource: MyResourceWithDefault):
585585
)
586586

587587
assert executed["yes"]
588+
589+
590+
class MyModuleLevelResource(ConfigurableResource):
591+
str_field: str
592+
593+
594+
# Note that an explicit string annotation has the same effect as defining a resource in a module
595+
# using `from __future__ import annotations`. This test will only work against a module-scoped
596+
# resource-- this is a hard limitation of string annotations in Python as of 2023-07-06 and Python
597+
# 3.11.
598+
def test_bind_with_string_annotation():
599+
@asset
600+
def my_asset(context, my_resource: "MyModuleLevelResource"):
601+
return my_resource.str_field
602+
603+
str_field_value = "foo"
604+
605+
defs = Definitions(
606+
[my_asset], resources={"my_resource": MyModuleLevelResource(str_field=str_field_value)}
607+
)
608+
609+
assert (
610+
defs.get_implicit_global_asset_job_def().execute_in_process().output_for_node("my_asset")
611+
== str_field_value
612+
)

python_modules/dagster/dagster_tests/general_tests/py3_tests/test_inference.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,37 @@ def add_one(num: int) -> int:
111111
assert add_one.output_defs[0].dagster_type.unique_name == "Int"
112112

113113

114+
def test_string_typed_input_and_output():
115+
@op
116+
def add_one(_context, num: "Optional[int]") -> "int":
117+
return num + 1 if num else 1
118+
119+
assert add_one
120+
assert len(add_one.input_defs) == 1
121+
assert add_one.input_defs[0].name == "num"
122+
assert add_one.input_defs[0].dagster_type.display_name == "Int?"
123+
124+
assert len(add_one.output_defs) == 1
125+
assert add_one.output_defs[0].dagster_type.unique_name == "Int"
126+
127+
128+
def _make_foo():
129+
class Foo:
130+
pass
131+
132+
def foo(x: "Foo") -> "Foo":
133+
return x
134+
135+
return foo
136+
137+
138+
def test_invalid_string_typed_input():
139+
with pytest.raises(
140+
DagsterInvalidDefinitionError, match='Failed to resolve type annotation "Foo"'
141+
):
142+
op(_make_foo())
143+
144+
114145
def test_wrapped_input_and_output_lambda():
115146
@op
116147
def add_one(nums: List[int]) -> Optional[List[int]]:

0 commit comments

Comments
 (0)