From 58bdbcc657bfceb03e6d1c578380b78ba9d9e1e8 Mon Sep 17 00:00:00 2001 From: Owen Kephart Date: Thu, 16 Jan 2025 08:37:15 -0800 Subject: [PATCH] [components] Update resolution deferral logic --- .../defining-resolvable-field.py | 2 +- .../dagster_components/core/component.py | 2 +- .../dagster_components/core/schema/base.py | 9 ++--- .../core/schema/metadata.py | 36 ++++++++++--------- .../core/schema/resolver.py | 36 ++++++++----------- .../lib/dbt_project/component.py | 2 +- .../sling_replication_collection/component.py | 2 +- .../lib/test/complex_schema_asset.py | 2 +- .../rendering_tests/test_schema_resolution.py | 34 +++++++++--------- 9 files changed, 59 insertions(+), 66 deletions(-) diff --git a/examples/docs_beta_snippets/docs_beta_snippets/guides/components/shell-script-component/defining-resolvable-field.py b/examples/docs_beta_snippets/docs_beta_snippets/guides/components/shell-script-component/defining-resolvable-field.py index 77b57213f0a29..be7ce9853fde7 100644 --- a/examples/docs_beta_snippets/docs_beta_snippets/guides/components/shell-script-component/defining-resolvable-field.py +++ b/examples/docs_beta_snippets/docs_beta_snippets/guides/components/shell-script-component/defining-resolvable-field.py @@ -14,7 +14,7 @@ class ShellScriptSchema(ComponentSchemaBaseModel): script_runner: Annotated[ str, ResolvableFieldInfo( - output_type=ScriptRunner, additional_scope={"get_script_runner"} + output_type=ScriptRunner, required_scope={"get_script_runner"} ), ] # highlight-end diff --git a/python_modules/libraries/dagster-components/dagster_components/core/component.py b/python_modules/libraries/dagster-components/dagster_components/core/component.py index 8fe784c6a1ddc..aa48f9ed2e09d 100644 --- a/python_modules/libraries/dagster-components/dagster_components/core/component.py +++ b/python_modules/libraries/dagster-components/dagster_components/core/component.py @@ -211,7 +211,7 @@ def get_registered_component_types_in_module(module: ModuleType) -> Iterable[typ yield component -T = TypeVar("T") +T = TypeVar("T", bound=BaseModel) @dataclass diff --git a/python_modules/libraries/dagster-components/dagster_components/core/schema/base.py b/python_modules/libraries/dagster-components/dagster_components/core/schema/base.py index 056c90101195f..b1ad0af86b223 100644 --- a/python_modules/libraries/dagster-components/dagster_components/core/schema/base.py +++ b/python_modules/libraries/dagster-components/dagster_components/core/schema/base.py @@ -3,19 +3,14 @@ from pydantic import BaseModel, ConfigDict, TypeAdapter -from dagster_components.core.schema.metadata import ( - JSON_SCHEMA_EXTRA_DEFER_RENDERING_KEY, - get_resolution_metadata, -) +from dagster_components.core.schema.metadata import get_resolution_metadata from dagster_components.core.schema.resolver import TemplatedValueResolver class ComponentSchemaBaseModel(BaseModel): """Base class for models that are part of a component schema.""" - model_config = ConfigDict( - json_schema_extra={JSON_SCHEMA_EXTRA_DEFER_RENDERING_KEY: True}, extra="forbid" - ) + model_config = ConfigDict(extra="forbid") def resolve_properties(self, value_resolver: TemplatedValueResolver) -> Mapping[str, Any]: """Returns a dictionary of resolved properties for this class.""" diff --git a/python_modules/libraries/dagster-components/dagster_components/core/schema/metadata.py b/python_modules/libraries/dagster-components/dagster_components/core/schema/metadata.py index fb64dcc8761c2..cb61a8fb9dde0 100644 --- a/python_modules/libraries/dagster-components/dagster_components/core/schema/metadata.py +++ b/python_modules/libraries/dagster-components/dagster_components/core/schema/metadata.py @@ -8,7 +8,7 @@ REF_BASE = "#/$defs/" REF_TEMPLATE = f"{REF_BASE}{{model}}" JSON_SCHEMA_EXTRA_DEFER_RENDERING_KEY = "dagster_defer_rendering" -JSON_SCHEMA_EXTRA_AVAILABLE_SCOPE_KEY = "dagster_available_scope" +JSON_SCHEMA_EXTRA_REQUIRED_SCOPE_KEY = "dagster_required_scope" @dataclass @@ -34,7 +34,7 @@ def __init__( *, output_type: Optional[type] = None, post_process_fn: Optional[Callable[[Any], Any]] = None, - additional_scope: Optional[Set[str]] = None, + required_scope: Optional[Set[str]] = None, ): self.resolution_metadata = ( ResolutionMetadata(output_type=output_type, post_process=post_process_fn) @@ -43,8 +43,11 @@ def __init__( ) super().__init__( json_schema_extra={ - JSON_SCHEMA_EXTRA_AVAILABLE_SCOPE_KEY: list(additional_scope or []), - JSON_SCHEMA_EXTRA_DEFER_RENDERING_KEY: True, + JSON_SCHEMA_EXTRA_REQUIRED_SCOPE_KEY: list(required_scope or []), + # defer resolution if the output type will change + **( + {JSON_SCHEMA_EXTRA_DEFER_RENDERING_KEY: True} if output_type is not None else {} + ), }, ) @@ -98,33 +101,34 @@ def _subschemas_on_path( yield from _subschemas_on_path(rest, json_schema, inner) -def _should_defer_render(subschema: Mapping[str, Any]) -> bool: +def _get_should_defer_resolve(subschema: Mapping[str, Any]) -> bool: raw = check.opt_inst(subschema.get(JSON_SCHEMA_EXTRA_DEFER_RENDERING_KEY), bool) return raw or False -def _get_available_scope(subschema: Mapping[str, Any]) -> Set[str]: - raw = check.opt_inst(subschema.get(JSON_SCHEMA_EXTRA_AVAILABLE_SCOPE_KEY), list) +def _get_additional_required_scope(subschema: Mapping[str, Any]) -> Set[str]: + raw = check.opt_inst(subschema.get(JSON_SCHEMA_EXTRA_REQUIRED_SCOPE_KEY), list) return set(raw) if raw else set() def allow_resolve( valpath: Sequence[Union[str, int]], json_schema: Mapping[str, Any], subschema: Mapping[str, Any] ) -> bool: - """Given a valpath and the json schema of a given target type, determines if there is a rendering scope - required to render the value at the given path. + """Given a valpath and the json schema of a given target type, determines if this value can be + resolved eagerly. This can only happen if the output type of the resolved value is unchanged, + and there is no additional scope required for resolution. """ for subschema in _subschemas_on_path(valpath, json_schema, subschema): - if _should_defer_render(subschema): + if _get_should_defer_resolve(subschema) or _get_additional_required_scope(subschema): return False return True -def get_available_scope( - valpath: Sequence[Union[str, int]], json_schema: Mapping[str, Any], subschema: Mapping[str, Any] +def get_required_scope( + valpath: Sequence[Union[str, int]], json_schema: Mapping[str, Any] ) -> Set[str]: """Given a valpath and the json schema of a given target type, determines the available rendering scope.""" - available_scope = set() - for subschema in _subschemas_on_path(valpath, json_schema, subschema): - available_scope |= _get_available_scope(subschema) - return available_scope + required_scope = set() + for subschema in _subschemas_on_path(valpath, json_schema, json_schema): + required_scope |= _get_additional_required_scope(subschema) + return required_scope diff --git a/python_modules/libraries/dagster-components/dagster_components/core/schema/resolver.py b/python_modules/libraries/dagster-components/dagster_components/core/schema/resolver.py index 1bd9ed00d5007..8bfa67a7b2e2c 100644 --- a/python_modules/libraries/dagster-components/dagster_components/core/schema/resolver.py +++ b/python_modules/libraries/dagster-components/dagster_components/core/schema/resolver.py @@ -26,9 +26,6 @@ def automation_condition_scope() -> Mapping[str, Any]: } -ShouldResolveFn = Callable[[Sequence[Union[str, int]]], bool] - - @record class TemplatedValueResolver: scope: Mapping[str, Any] @@ -50,23 +47,23 @@ def _resolve_obj( self, obj: Any, valpath: Optional[Sequence[Union[str, int]]], - should_render: Callable[[Sequence[Union[str, int]]], bool], + should_resolve: Callable[[Sequence[Union[str, int]]], bool], ) -> Any: - """Recursively resolves templated values in a nested object, based on the provided should_render function.""" - if valpath is not None and not should_render(valpath): + """Recursively resolves templated values in a nested object, based on the provided should_resolve function.""" + if valpath is not None and not should_resolve(valpath): return obj elif isinstance(obj, dict): - # render all values in the dict + # resolve all values in the dict return { k: self._resolve_obj( - v, [*valpath, k] if valpath is not None else None, should_render + v, [*valpath, k] if valpath is not None else None, should_resolve ) for k, v in obj.items() } elif isinstance(obj, list): - # render all values in the list + # resolve all values in the list return [ - self._resolve_obj(v, [*valpath, i] if valpath is not None else None, should_render) + self._resolve_obj(v, [*valpath, i] if valpath is not None else None, should_resolve) for i, v in enumerate(obj) ] else: @@ -76,15 +73,12 @@ def resolve_obj(self, val: Any) -> Any: """Recursively resolves templated values in a nested object.""" return self._resolve_obj(val, None, lambda _: True) - def resolve_params(self, val: T, target_type: type) -> T: - """Given a raw params value, preprocesses it by rendering any templated values that are not marked as deferred in the target_type's json schema.""" - json_schema = ( - target_type.model_json_schema() if issubclass(target_type, BaseModel) else None + def resolve_params(self, val: T, target_type: type[BaseModel]) -> T: + """Given a raw params value, preprocesses it by resolving any templated values that are not marked + as deferred in the target_type's json schema. + """ + json_schema = target_type.model_json_schema() + should_resolve = functools.partial( + allow_resolve, json_schema=json_schema, subschema=json_schema ) - if json_schema is None: - should_render = lambda _: True - else: - should_render = functools.partial( - allow_resolve, json_schema=json_schema, subschema=json_schema - ) - return self._resolve_obj(val, [], should_render=should_render) + return self._resolve_obj(val, [], should_resolve=should_resolve) diff --git a/python_modules/libraries/dagster-components/dagster_components/lib/dbt_project/component.py b/python_modules/libraries/dagster-components/dagster_components/lib/dbt_project/component.py index ecdff88746179..ed62c0016a972 100644 --- a/python_modules/libraries/dagster-components/dagster_components/lib/dbt_project/component.py +++ b/python_modules/libraries/dagster-components/dagster_components/lib/dbt_project/component.py @@ -30,7 +30,7 @@ class DbtProjectParams(BaseModel): dbt: DbtCliResource op: Optional[OpSpecBaseModel] = None asset_attributes: Annotated[ - Optional[AssetAttributesModel], ResolvableFieldInfo(additional_scope={"node"}) + Optional[AssetAttributesModel], ResolvableFieldInfo(required_scope={"node"}) ] = None transforms: Optional[Sequence[AssetSpecTransformModel]] = None diff --git a/python_modules/libraries/dagster-components/dagster_components/lib/sling_replication_collection/component.py b/python_modules/libraries/dagster-components/dagster_components/lib/sling_replication_collection/component.py index ceec349a01ece..612cd9fbe82e2 100644 --- a/python_modules/libraries/dagster-components/dagster_components/lib/sling_replication_collection/component.py +++ b/python_modules/libraries/dagster-components/dagster_components/lib/sling_replication_collection/component.py @@ -27,7 +27,7 @@ class SlingReplicationParams(BaseModel): path: str op: Optional[OpSpecBaseModel] = None asset_attributes: Annotated[ - Optional[AssetAttributesModel], ResolvableFieldInfo(additional_scope={"stream_definition"}) + Optional[AssetAttributesModel], ResolvableFieldInfo(required_scope={"stream_definition"}) ] = None diff --git a/python_modules/libraries/dagster-components/dagster_components/lib/test/complex_schema_asset.py b/python_modules/libraries/dagster-components/dagster_components/lib/test/complex_schema_asset.py index 67cbc5f9499ab..a194dffb968ef 100644 --- a/python_modules/libraries/dagster-components/dagster_components/lib/test/complex_schema_asset.py +++ b/python_modules/libraries/dagster-components/dagster_components/lib/test/complex_schema_asset.py @@ -21,7 +21,7 @@ class ComplexAssetParams(BaseModel): value: str op: Optional[OpSpecBaseModel] = None asset_attributes: Annotated[ - Optional[AssetAttributesModel], ResolvableFieldInfo(additional_scope={"node"}) + Optional[AssetAttributesModel], ResolvableFieldInfo(required_scope={"node"}) ] = None asset_transforms: Optional[Sequence[AssetSpecTransformModel]] = None diff --git a/python_modules/libraries/dagster-components/dagster_components_tests/rendering_tests/test_schema_resolution.py b/python_modules/libraries/dagster-components/dagster_components_tests/rendering_tests/test_schema_resolution.py index ee71751c847b6..b6339a4c6c587 100644 --- a/python_modules/libraries/dagster-components/dagster_components_tests/rendering_tests/test_schema_resolution.py +++ b/python_modules/libraries/dagster-components/dagster_components_tests/rendering_tests/test_schema_resolution.py @@ -3,19 +3,19 @@ import pytest from dagster_components import ComponentSchemaBaseModel, ResolvableFieldInfo, TemplatedValueResolver -from dagster_components.core.schema.metadata import allow_resolve, get_available_scope +from dagster_components.core.schema.metadata import allow_resolve, get_required_scope from pydantic import BaseModel, Field, TypeAdapter, ValidationError class InnerRendered(ComponentSchemaBaseModel): - a: Optional[str] = None + a: Annotated[Optional[str], ResolvableFieldInfo(required_scope={"deferred"})] = None class Container(BaseModel): a: str inner: InnerRendered - inner_scoped: Annotated[InnerRendered, ResolvableFieldInfo(additional_scope={"c", "d"})] = ( - Field(default_factory=InnerRendered) + inner_scoped: Annotated[InnerRendered, ResolvableFieldInfo(required_scope={"c", "d"})] = Field( + default_factory=InnerRendered ) @@ -25,36 +25,38 @@ class Outer(BaseModel): container: Container container_optional: Optional[Container] = None container_optional_scoped: Annotated[ - Optional[Container], ResolvableFieldInfo(additional_scope={"a", "b"}) + Optional[Container], ResolvableFieldInfo(required_scope={"a", "b"}) ] = None inner_seq: Sequence[InnerRendered] inner_optional: Optional[InnerRendered] = None inner_optional_seq: Optional[Sequence[InnerRendered]] = None + transformed: Annotated[Optional[str], ResolvableFieldInfo(output_type=Optional[int])] = None @pytest.mark.parametrize( "path,expected", [ (["a"], True), - (["inner"], False), + (["inner"], True), (["inner", "a"], False), (["container", "a"], True), - (["container", "inner"], False), + (["container", "inner"], True), (["container", "inner", "a"], False), (["container_optional", "a"], True), - (["container_optional", "inner"], False), + (["container_optional", "inner"], True), (["container_optional", "inner", "a"], False), (["container_optional_scoped"], False), (["container_optional_scoped", "inner", "a"], False), (["container_optional_scoped", "inner_scoped", "a"], False), (["inner_seq"], True), - (["inner_seq", 0], False), + (["inner_seq", 0], True), (["inner_seq", 0, "a"], False), (["inner_optional"], True), (["inner_optional", "a"], False), (["inner_optional_seq"], True), - (["inner_optional_seq", 0], False), + (["inner_optional_seq", 0], True), (["inner_optional_seq", 0, "a"], False), + (["transformed"], False), ], ) def test_allow_render(path, expected: bool) -> None: @@ -65,19 +67,17 @@ def test_allow_render(path, expected: bool) -> None: "path,expected", [ (["a"], set()), - (["inner", "a"], set()), - (["container_optional", "inner", "a"], set()), + (["inner", "a"], {"deferred"}), + (["container_optional", "inner", "a"], {"deferred"}), (["inner_seq"], set()), (["container_optional_scoped"], {"a", "b"}), (["container_optional_scoped", "inner"], {"a", "b"}), (["container_optional_scoped", "inner_scoped"], {"a", "b", "c", "d"}), - (["container_optional_scoped", "inner_scoped", "a"], {"a", "b", "c", "d"}), + (["container_optional_scoped", "inner_scoped", "a"], {"a", "b", "c", "d", "deferred"}), ], ) -def test_get_available_scope(path, expected: Set[str]) -> None: - assert ( - get_available_scope(path, Outer.model_json_schema(), Outer.model_json_schema()) == expected - ) +def test_get_required_scope(path, expected: Set[str]) -> None: + assert get_required_scope(path, Outer.model_json_schema()) == expected def test_render() -> None: