From 67dc5d07c66af08c09d8a0a65b529d000056fb5a Mon Sep 17 00:00:00 2001 From: Owen Kephart Date: Fri, 20 Dec 2024 09:37:28 -0500 Subject: [PATCH] [components] Templating for asset_attributes --- .../dagster_components/core/component.py | 6 +- .../core/component_rendering.py | 78 +++++++++++-------- .../dagster_components/core/dsl_schema.py | 47 ++++++----- .../dagster_components/lib/dbt_project.py | 2 +- .../lib/sling_replication.py | 2 +- .../test_component_rendering.py | 9 ++- .../unit_tests/test_spec_processing.py | 33 +++++++- 7 files changed, 114 insertions(+), 63 deletions(-) diff --git a/python_modules/libraries/dagster-components/dagster_components/core/component.py b/python_modules/libraries/dagster-components/dagster_components/core/component.py index 2c68dbfb00e16..51dd0d84ef3b2 100644 --- a/python_modules/libraries/dagster-components/dagster_components/core/component.py +++ b/python_modules/libraries/dagster-components/dagster_components/core/component.py @@ -31,7 +31,7 @@ from pydantic import TypeAdapter from typing_extensions import Self -from dagster_components.core.component_rendering import TemplatedValueResolver, preprocess_value +from dagster_components.core.component_rendering import TemplatedValueResolver class ComponentDeclNode: ... @@ -254,8 +254,8 @@ def _raw_params(self) -> Optional[Mapping[str, Any]]: def load_params(self, params_schema: Type[T]) -> T: with pushd(str(self.path)): - preprocessed_params = preprocess_value( - self.templated_value_resolver, self._raw_params(), params_schema + preprocessed_params = self.templated_value_resolver.resolve_params( + self._raw_params(), params_schema ) return TypeAdapter(params_schema).validate_python(preprocessed_params) diff --git a/python_modules/libraries/dagster-components/dagster_components/core/component_rendering.py b/python_modules/libraries/dagster-components/dagster_components/core/component_rendering.py index b766edb09186d..6742debd01315 100644 --- a/python_modules/libraries/dagster-components/dagster_components/core/component_rendering.py +++ b/python_modules/libraries/dagster-components/dagster_components/core/component_rendering.py @@ -1,6 +1,7 @@ +import functools import json import os -from typing import AbstractSet, Any, Mapping, Optional, Sequence, Type, TypeVar, Union +from typing import AbstractSet, Any, Callable, Mapping, Optional, Sequence, Type, TypeVar, Union import dagster._check as check from dagster._record import record @@ -40,6 +41,9 @@ def _env(key: str) -> Optional[str]: return os.environ.get(key) +ShouldRenderFn = Callable[[Sequence[Union[str, int]]], bool] + + @record class TemplatedValueResolver: context: Mapping[str, Any] @@ -51,11 +55,49 @@ def default() -> "TemplatedValueResolver": def with_context(self, **additional_context) -> "TemplatedValueResolver": return TemplatedValueResolver(context={**self.context, **additional_context}) - def resolve(self, val: Any) -> Any: + def _resolve_value(self, val: Any) -> Any: return NativeTemplate(val).render(**self.context) if isinstance(val, str) else val + def _resolve( + self, + val: Any, + valpath: Optional[Sequence[Union[str, int]]], + should_render: Callable[[Sequence[Union[str, int]]], bool], + ) -> Any: + if valpath is not None and not should_render(valpath): + return val + elif isinstance(val, dict): + return { + k: self._resolve(v, [*valpath, k] if valpath is not None else None, should_render) + for k, v in val.items() + } + elif isinstance(val, list): + return [ + self._resolve(v, [*valpath, i] if valpath is not None else None, should_render) + for i, v in enumerate(val) + ] + else: + return self._resolve_value(val) -def _should_render( + def resolve(self, val: Any) -> Any: + """Given a raw value, preprocesses it by rendering any templated values.""" + return self._resolve(val, None, lambda _: True) + + def resolve_params(self, val: T, target_type: Type) -> T: + """Given a raw value, preprocesses it by rendering any templated values that are not marked as deferred in the target_type's json schema.""" + json_schema = ( + target_type.model_json_schema() if issubclass(target_type, BaseModel) else None + ) + if json_schema is None: + should_render = lambda _: True + else: + should_render = functools.partial( + has_required_scope, json_schema=json_schema, subschema=json_schema + ) + return self._resolve(val, [], should_render=should_render) + + +def has_required_scope( valpath: Sequence[Union[str, int]], json_schema: Mapping[str, Any], subschema: Mapping[str, Any] ) -> bool: # List[ComplexType] (e.g.) will contain a reference to the complex type schema in the @@ -70,7 +112,7 @@ def _should_render( # Optional[ComplexType] (e.g.) will contain multiple schemas in the "anyOf" field if "anyOf" in subschema: - return all(_should_render(valpath, json_schema, inner) for inner in subschema["anyOf"]) + return all(has_required_scope(valpath, json_schema, inner) for inner in subschema["anyOf"]) el = valpath[0] if isinstance(el, str): @@ -89,30 +131,4 @@ def _should_render( return subschema.get("additionalProperties", True) _, *rest = valpath - return _should_render(rest, json_schema, inner) - - -def _render_values( - value_resolver: TemplatedValueResolver, - val: Any, - valpath: Sequence[Union[str, int]], - json_schema: Optional[Mapping[str, Any]], -) -> Any: - if json_schema and not _should_render(valpath, json_schema, json_schema): - return val - elif isinstance(val, dict): - return { - k: _render_values(value_resolver, v, [*valpath, k], json_schema) for k, v in val.items() - } - elif isinstance(val, list): - return [ - _render_values(value_resolver, v, [*valpath, i], json_schema) for i, v in enumerate(val) - ] - else: - return value_resolver.resolve(val) - - -def preprocess_value(renderer: TemplatedValueResolver, val: T, target_type: Type) -> T: - """Given a raw value, preprocesses it by rendering any templated values that are not marked as deferred in the target_type's json schema.""" - json_schema = target_type.model_json_schema() if issubclass(target_type, BaseModel) else None - return _render_values(renderer, val, [], json_schema) + return has_required_scope(rest, json_schema, inner) diff --git a/python_modules/libraries/dagster-components/dagster_components/core/dsl_schema.py b/python_modules/libraries/dagster-components/dagster_components/core/dsl_schema.py index 39af228db176f..d2a2f1e05c26f 100644 --- a/python_modules/libraries/dagster-components/dagster_components/core/dsl_schema.py +++ b/python_modules/libraries/dagster-components/dagster_components/core/dsl_schema.py @@ -1,6 +1,7 @@ -from abc import ABC, abstractmethod -from typing import Annotated, Any, Dict, Literal, Mapping, Optional, Sequence, Union +from abc import ABC +from typing import AbstractSet, Annotated, Any, Dict, Literal, Mapping, Optional, Sequence, Union +from dagster._core.definitions.asset_key import AssetKey from dagster._core.definitions.asset_selection import AssetSelection from dagster._core.definitions.asset_spec import AssetSpec, map_asset_specs from dagster._core.definitions.assets import AssetsDefinition @@ -11,6 +12,8 @@ from dagster._record import replace from pydantic import BaseModel, Field +from dagster_components.core.component_rendering import RenderingScope, TemplatedValueResolver + class OpSpecBaseModel(BaseModel): name: Optional[str] = None @@ -33,26 +36,31 @@ class AssetSpecProcessor(ABC, BaseModel): tags: Optional[Mapping[str, str]] = None automation_condition: Optional[AutomationConditionModel] = None - def _attributes(self) -> Mapping[str, Any]: - return { - **self.model_dump(exclude={"target", "operation"}, exclude_unset=True), - **{ - "automation_condition": self.automation_condition.to_automation_condition() - if self.automation_condition - else None - }, - } + def _apply_to_spec(self, spec: AssetSpec, attributes: Mapping[str, Any]) -> AssetSpec: ... + + def apply_to_spec( + self, + spec: AssetSpec, + value_resolver: TemplatedValueResolver, + target_keys: AbstractSet[AssetKey], + ) -> AssetSpec: + if spec.key not in target_keys: + return spec - @abstractmethod - def _apply_to_spec(self, spec: AssetSpec) -> AssetSpec: ... + # add the original spec to the context and resolve values + attributes = value_resolver.with_context(asset=spec).resolve( + self.model_dump(exclude={"target", "operation"}, exclude_unset=True) + ) + return self._apply_to_spec(spec, attributes) - def apply(self, defs: Definitions) -> Definitions: + def apply(self, defs: Definitions, value_resolver: TemplatedValueResolver) -> Definitions: target_selection = AssetSelection.from_string(self.target, include_sources=True) target_keys = target_selection.resolve(defs.get_asset_graph()) mappable = [d for d in defs.assets or [] if isinstance(d, (AssetsDefinition, AssetSpec))] mapped_assets = map_asset_specs( - lambda spec: self._apply_to_spec(spec) if spec.key in target_keys else spec, mappable + lambda spec: self.apply_to_spec(spec, value_resolver, target_keys), + mappable, ) assets = [ @@ -66,8 +74,7 @@ class MergeAttributes(AssetSpecProcessor): # default operation is "merge" operation: Literal["merge"] = "merge" - def _apply_to_spec(self, spec: AssetSpec) -> AssetSpec: - attributes = self._attributes() + def _apply_to_spec(self, spec: AssetSpec, attributes: Mapping[str, Any]) -> AssetSpec: mergeable_attributes = {"metadata", "tags"} merge_attributes = {k: v for k, v in attributes.items() if k in mergeable_attributes} replace_attributes = {k: v for k, v in attributes.items() if k not in mergeable_attributes} @@ -78,13 +85,13 @@ class ReplaceAttributes(AssetSpecProcessor): # operation must be set explicitly operation: Literal["replace"] - def _apply_to_spec(self, spec: AssetSpec) -> AssetSpec: - return spec.replace_attributes(**self._attributes()) + def _apply_to_spec(self, spec: AssetSpec, attributes: Mapping[str, Any]) -> AssetSpec: + return spec.replace_attributes(**attributes) AssetAttributes = Sequence[ Annotated[ Union[MergeAttributes, ReplaceAttributes], - Field(union_mode="left_to_right"), + RenderingScope(Field(union_mode="left_to_right"), required_scope={"asset"}), ] ] diff --git a/python_modules/libraries/dagster-components/dagster_components/lib/dbt_project.py b/python_modules/libraries/dagster-components/dagster_components/lib/dbt_project.py index 4222ef7ab9f89..db107e9a95d0a 100644 --- a/python_modules/libraries/dagster-components/dagster_components/lib/dbt_project.py +++ b/python_modules/libraries/dagster-components/dagster_components/lib/dbt_project.py @@ -120,7 +120,7 @@ def _fn(context: AssetExecutionContext): defs = Definitions(assets=[_fn]) for transform in self.asset_processors: - defs = transform.apply(defs) + defs = transform.apply(defs, context.templated_value_resolver) return defs @classmethod diff --git a/python_modules/libraries/dagster-components/dagster_components/lib/sling_replication.py b/python_modules/libraries/dagster-components/dagster_components/lib/sling_replication.py index 63c466930deb9..ea131a9e8f71b 100644 --- a/python_modules/libraries/dagster-components/dagster_components/lib/sling_replication.py +++ b/python_modules/libraries/dagster-components/dagster_components/lib/sling_replication.py @@ -60,7 +60,7 @@ def _fn(context: AssetExecutionContext, sling: SlingResource): defs = Definitions(assets=[_fn], resources={"sling": self.resource}) for transform in self.asset_processors: - defs = transform.apply(defs) + defs = transform.apply(defs, context.templated_value_resolver) return defs @classmethod diff --git a/python_modules/libraries/dagster-components/dagster_components_tests/rendering_tests/test_component_rendering.py b/python_modules/libraries/dagster-components/dagster_components_tests/rendering_tests/test_component_rendering.py index e93fbd4a9dbe6..4bf8892908c6e 100644 --- a/python_modules/libraries/dagster-components/dagster_components_tests/rendering_tests/test_component_rendering.py +++ b/python_modules/libraries/dagster-components/dagster_components_tests/rendering_tests/test_component_rendering.py @@ -4,8 +4,7 @@ from dagster_components.core.component_rendering import ( RenderingScope, TemplatedValueResolver, - _should_render, - preprocess_value, + has_required_scope, ) from pydantic import BaseModel, Field, TypeAdapter @@ -44,7 +43,9 @@ class Outer(BaseModel): ], ) def test_should_render(path, expected: bool) -> None: - assert _should_render(path, Outer.model_json_schema(), Outer.model_json_schema()) == expected + assert ( + has_required_scope(path, Outer.model_json_schema(), Outer.model_json_schema()) == expected + ) def test_render() -> None: @@ -61,7 +62,7 @@ def test_render() -> None: } renderer = TemplatedValueResolver(context={"foo_val": "foo", "bar_val": "bar"}) - rendered_data = preprocess_value(renderer, data, Outer) + rendered_data = renderer.resolve_params(data, Outer) assert rendered_data == { "a": "foo", diff --git a/python_modules/libraries/dagster-components/dagster_components_tests/unit_tests/test_spec_processing.py b/python_modules/libraries/dagster-components/dagster_components_tests/unit_tests/test_spec_processing.py index 6af75bf8544e4..d7384d5bf53c9 100644 --- a/python_modules/libraries/dagster-components/dagster_components_tests/unit_tests/test_spec_processing.py +++ b/python_modules/libraries/dagster-components/dagster_components_tests/unit_tests/test_spec_processing.py @@ -1,6 +1,11 @@ import pytest from dagster import AssetKey, AssetSpec, Definitions -from dagster_components.core.dsl_schema import AssetAttributes, MergeAttributes, ReplaceAttributes +from dagster_components.core.dsl_schema import ( + AssetAttributes, + MergeAttributes, + ReplaceAttributes, + TemplatedValueResolver, +) from pydantic import BaseModel, TypeAdapter @@ -20,7 +25,7 @@ class M(BaseModel): def test_replace_attributes() -> None: op = ReplaceAttributes(operation="replace", target="group:g2", tags={"newtag": "newval"}) - newdefs = op.apply(defs) + newdefs = op.apply(defs, TemplatedValueResolver.default()) asset_graph = newdefs.get_asset_graph() assert asset_graph.get(AssetKey("a")).tags == {} assert asset_graph.get(AssetKey("b")).tags == {"newtag": "newval"} @@ -30,13 +35,35 @@ def test_replace_attributes() -> None: def test_merge_attributes() -> None: op = MergeAttributes(operation="merge", target="group:g2", tags={"newtag": "newval"}) - newdefs = op.apply(defs) + newdefs = op.apply(defs, TemplatedValueResolver.default()) asset_graph = newdefs.get_asset_graph() assert asset_graph.get(AssetKey("a")).tags == {} assert asset_graph.get(AssetKey("b")).tags == {"newtag": "newval"} assert asset_graph.get(AssetKey("c")).tags == {"tag": "val", "newtag": "newval"} +def test_render_attributes_asset_context() -> None: + op = MergeAttributes(tags={"group_name_tag": "group__{{ asset.group_name }}"}) + + newdefs = op.apply(defs, TemplatedValueResolver.default().with_context(foo="theval")) + asset_graph = newdefs.get_asset_graph() + assert asset_graph.get(AssetKey("a")).tags == {"group_name_tag": "group__g1"} + assert asset_graph.get(AssetKey("b")).tags == {"group_name_tag": "group__g2"} + assert asset_graph.get(AssetKey("c")).tags == {"tag": "val", "group_name_tag": "group__g2"} + + +def test_render_attributes_custom_context() -> None: + op = ReplaceAttributes( + operation="replace", target="group:g2", tags={"a": "{{ foo }}", "b": "prefix_{{ foo }}"} + ) + + newdefs = op.apply(defs, TemplatedValueResolver.default().with_context(foo="theval")) + asset_graph = newdefs.get_asset_graph() + assert asset_graph.get(AssetKey("a")).tags == {} + assert asset_graph.get(AssetKey("b")).tags == {"a": "theval", "b": "prefix_theval"} + assert asset_graph.get(AssetKey("c")).tags == {"a": "theval", "b": "prefix_theval"} + + @pytest.mark.parametrize( "python,expected", [