[components] Templating for asset_attributes #26633

Open · wants to merge 1 commit into master
@@ -31,7 +31,7 @@
from pydantic import TypeAdapter
from typing_extensions import Self

from dagster_components.core.component_rendering import TemplatedValueResolver, preprocess_value
from dagster_components.core.component_rendering import TemplatedValueResolver


class ComponentDeclNode: ...
@@ -254,8 +254,8 @@ def _raw_params(self) -> Optional[Mapping[str, Any]]:

def load_params(self, params_schema: Type[T]) -> T:
with pushd(str(self.path)):
preprocessed_params = preprocess_value(
self.templated_value_resolver, self._raw_params(), params_schema
preprocessed_params = self.templated_value_resolver.resolve_params(
self._raw_params(), params_schema
)
return TypeAdapter(params_schema).validate_python(preprocessed_params)
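
For context, the new load_params path boils down to a resolve-then-validate flow. A minimal standalone sketch, with an invented MyParams model and raw dict (not part of this diff):

```python
# Hypothetical sketch of the resolve-then-validate flow used by load_params above;
# MyParams and the raw values are invented for illustration.
from dagster_components.core.component_rendering import TemplatedValueResolver
from pydantic import BaseModel, TypeAdapter


class MyParams(BaseModel):
    path: str


resolver = TemplatedValueResolver(context={"root": "/data"})
raw = {"path": "{{ root }}/events"}

# Step 1: render template strings against the resolver's context.
preprocessed = resolver.resolve_params(raw, MyParams)
# Step 2: validate the rendered values against the params schema.
validated = TypeAdapter(MyParams).validate_python(preprocessed)
# validated.path == "/data/events"
```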

@@ -1,6 +1,7 @@
import functools
import json
import os
from typing import AbstractSet, Any, Mapping, Optional, Sequence, Type, TypeVar, Union
from typing import AbstractSet, Any, Callable, Mapping, Optional, Sequence, Type, TypeVar, Union

import dagster._check as check
from dagster._record import record
@@ -40,6 +41,9 @@ def _env(key: str) -> Optional[str]:
return os.environ.get(key)


ShouldRenderFn = Callable[[Sequence[Union[str, int]]], bool]


@record
class TemplatedValueResolver:
context: Mapping[str, Any]
@@ -51,11 +55,49 @@ def default() -> "TemplatedValueResolver":
def with_context(self, **additional_context) -> "TemplatedValueResolver":
return TemplatedValueResolver(context={**self.context, **additional_context})

def resolve(self, val: Any) -> Any:
def _resolve_value(self, val: Any) -> Any:
return NativeTemplate(val).render(**self.context) if isinstance(val, str) else val

def _resolve(
self,
val: Any,
valpath: Optional[Sequence[Union[str, int]]],
should_render: Callable[[Sequence[Union[str, int]]], bool],

Member: I'm not following this should_render stuff. Under what conditions will this be true versus false?

) -> Any:
if valpath is not None and not should_render(valpath):
return val
elif isinstance(val, dict):
return {
k: self._resolve(v, [*valpath, k] if valpath is not None else None, should_render)
for k, v in val.items()
}
elif isinstance(val, list):
return [
self._resolve(v, [*valpath, i] if valpath is not None else None, should_render)
for i, v in enumerate(val)
]
else:
return self._resolve_value(val)

def _should_render(
def resolve(self, val: Any) -> Any:
"""Given a raw value, preprocesses it by rendering any templated values."""
return self._resolve(val, None, lambda _: True)

def resolve_params(self, val: T, target_type: Type) -> T:
"""Given a raw value, preprocesses it by rendering any templated values that are not marked as deferred in the target_type's json schema."""
json_schema = (
target_type.model_json_schema() if issubclass(target_type, BaseModel) else None
)
if json_schema is None:
should_render = lambda _: True
else:
should_render = functools.partial(
has_required_scope, json_schema=json_schema, subschema=json_schema
)
return self._resolve(val, [], should_render=should_render)


def has_required_scope(

Member: I think part of my confusion is that this function name is a bit ambiguous. Does it mean that there is a field that has the RenderingScope attribute?

Contributor (author): That's right, yeah -- basically should_render is a function that reads the json schema and determines if anything along the path has some required scope.

Member: has_rendering_scope?

valpath: Sequence[Union[str, int]], json_schema: Mapping[str, Any], subschema: Mapping[str, Any]
) -> bool:
# List[ComplexType] (e.g.) will contain a reference to the complex type schema in the
Expand All @@ -70,7 +112,7 @@ def _should_render(

# Optional[ComplexType] (e.g.) will contain multiple schemas in the "anyOf" field
if "anyOf" in subschema:
return all(_should_render(valpath, json_schema, inner) for inner in subschema["anyOf"])
return all(has_required_scope(valpath, json_schema, inner) for inner in subschema["anyOf"])

el = valpath[0]
if isinstance(el, str):
@@ -89,30 +131,4 @@ def _should_render(
return subschema.get("additionalProperties", True)

_, *rest = valpath
return _should_render(rest, json_schema, inner)


def _render_values(
value_resolver: TemplatedValueResolver,
val: Any,
valpath: Sequence[Union[str, int]],
json_schema: Optional[Mapping[str, Any]],
) -> Any:
if json_schema and not _should_render(valpath, json_schema, json_schema):
return val
elif isinstance(val, dict):
return {
k: _render_values(value_resolver, v, [*valpath, k], json_schema) for k, v in val.items()
}
elif isinstance(val, list):
return [
_render_values(value_resolver, v, [*valpath, i], json_schema) for i, v in enumerate(val)
]
else:
return value_resolver.resolve(val)


def preprocess_value(renderer: TemplatedValueResolver, val: T, target_type: Type) -> T:
"""Given a raw value, preprocesses it by rendering any templated values that are not marked as deferred in the target_type's json schema."""
json_schema = target_type.model_json_schema() if issubclass(target_type, BaseModel) else None
return _render_values(renderer, val, [], json_schema)
return has_required_scope(rest, json_schema, inner)
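
To make the deferral concrete, here is a minimal hypothetical sketch (not part of this diff): the model and field names are invented, and RenderingScope is attached as Annotated metadata in the same way AssetAttributes is annotated later in this PR.

```python
# Hypothetical sketch; ExampleParams and its fields are invented.
from typing import Annotated

from dagster_components.core.component_rendering import RenderingScope, TemplatedValueResolver
from pydantic import BaseModel, Field


class ExampleParams(BaseModel):
    # No required scope: rendered eagerly by resolve_params.
    name: str
    # Marked with required_scope={"asset"}: resolve_params leaves the template
    # untouched so it can be rendered later, once an asset is in context.
    group_tag: Annotated[str, RenderingScope(Field(), required_scope={"asset"})]


resolver = TemplatedValueResolver(context={"env": "prod"})
params = resolver.resolve_params(
    {"name": "{{ env }}_component", "group_tag": "group__{{ asset.group_name }}"},
    ExampleParams,
)
# Expected: {"name": "prod_component", "group_tag": "group__{{ asset.group_name }}"}
```

Per the discussion above, has_required_scope walks the json schema along the value path to decide whether a field is gated behind a required scope like this.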
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from typing import Annotated, Any, Dict, Literal, Mapping, Optional, Sequence, Union
from abc import ABC
from typing import AbstractSet, Annotated, Any, Dict, Literal, Mapping, Optional, Sequence, Union

from dagster._core.definitions.asset_key import AssetKey
from dagster._core.definitions.asset_selection import AssetSelection
from dagster._core.definitions.asset_spec import AssetSpec, map_asset_specs
from dagster._core.definitions.assets import AssetsDefinition
@@ -11,6 +12,8 @@
from dagster._record import replace
from pydantic import BaseModel, Field

from dagster_components.core.component_rendering import RenderingScope, TemplatedValueResolver


class OpSpecBaseModel(BaseModel):
name: Optional[str] = None
@@ -33,26 +36,31 @@ class AssetSpecProcessor(ABC, BaseModel):
tags: Optional[Mapping[str, str]] = None
automation_condition: Optional[AutomationConditionModel] = None

def _attributes(self) -> Mapping[str, Any]:
return {
**self.model_dump(exclude={"target", "operation"}, exclude_unset=True),
**{
"automation_condition": self.automation_condition.to_automation_condition()
if self.automation_condition
else None
},
}
def _apply_to_spec(self, spec: AssetSpec, attributes: Mapping[str, Any]) -> AssetSpec: ...

def apply_to_spec(
self,
spec: AssetSpec,
value_resolver: TemplatedValueResolver,
target_keys: AbstractSet[AssetKey],
) -> AssetSpec:
if spec.key not in target_keys:
return spec

@abstractmethod
def _apply_to_spec(self, spec: AssetSpec) -> AssetSpec: ...
# add the original spec to the context and resolve values
attributes = value_resolver.with_context(asset=spec).resolve(
self.model_dump(exclude={"target", "operation"}, exclude_unset=True)
)
return self._apply_to_spec(spec, attributes)

def apply(self, defs: Definitions) -> Definitions:
def apply(self, defs: Definitions, value_resolver: TemplatedValueResolver) -> Definitions:
target_selection = AssetSelection.from_string(self.target, include_sources=True)
target_keys = target_selection.resolve(defs.get_asset_graph())

mappable = [d for d in defs.assets or [] if isinstance(d, (AssetsDefinition, AssetSpec))]
mapped_assets = map_asset_specs(
lambda spec: self._apply_to_spec(spec) if spec.key in target_keys else spec, mappable
lambda spec: self.apply_to_spec(spec, value_resolver, target_keys),
mappable,
)

assets = [
@@ -66,8 +74,7 @@ class MergeAttributes(AssetSpecProcessor):
# default operation is "merge"
operation: Literal["merge"] = "merge"

def _apply_to_spec(self, spec: AssetSpec) -> AssetSpec:
attributes = self._attributes()
def _apply_to_spec(self, spec: AssetSpec, attributes: Mapping[str, Any]) -> AssetSpec:
mergeable_attributes = {"metadata", "tags"}
merge_attributes = {k: v for k, v in attributes.items() if k in mergeable_attributes}
replace_attributes = {k: v for k, v in attributes.items() if k not in mergeable_attributes}
@@ -78,13 +85,13 @@ class ReplaceAttributes(AssetSpecProcessor):
# operation must be set explicitly
operation: Literal["replace"]

def _apply_to_spec(self, spec: AssetSpec) -> AssetSpec:
return spec.replace_attributes(**self._attributes())
def _apply_to_spec(self, spec: AssetSpec, attributes: Mapping[str, Any]) -> AssetSpec:
return spec.replace_attributes(**attributes)


AssetAttributes = Sequence[
Annotated[
Union[MergeAttributes, ReplaceAttributes],
Field(union_mode="left_to_right"),
RenderingScope(Field(union_mode="left_to_right"), required_scope={"asset"}),
]
]
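
This is the second half of the flow: because AssetAttributes is annotated with required_scope={"asset"}, templates like {{ asset.group_name }} survive param loading and are only rendered in apply_to_spec, where with_context(asset=spec) puts each spec into scope. A minimal hypothetical sketch (asset key, group name, and tag key invented), mirroring test_render_attributes_asset_context further down:

```python
# Hypothetical sketch; the asset key, group name, and tag key are invented.
from dagster import AssetKey, AssetSpec, Definitions
from dagster_components.core.component_rendering import TemplatedValueResolver
from dagster_components.core.dsl_schema import MergeAttributes

defs = Definitions(assets=[AssetSpec("raw_events", group_name="ingest")])

# The asset-scoped template is rendered per spec inside apply_to_spec, via
# value_resolver.with_context(asset=spec).
op = MergeAttributes(tags={"group_tag": "group__{{ asset.group_name }}"})
new_defs = op.apply(defs, TemplatedValueResolver.default())

assert new_defs.get_asset_graph().get(AssetKey("raw_events")).tags == {
    "group_tag": "group__ingest"
}
```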
@@ -120,7 +120,7 @@ def _fn(context: AssetExecutionContext):

defs = Definitions(assets=[_fn])
for transform in self.asset_processors:
defs = transform.apply(defs)
defs = transform.apply(defs, context.templated_value_resolver)
return defs

@classmethod
@@ -60,7 +60,7 @@ def _fn(context: AssetExecutionContext, sling: SlingResource):

defs = Definitions(assets=[_fn], resources={"sling": self.resource})
for transform in self.asset_processors:
defs = transform.apply(defs)
defs = transform.apply(defs, context.templated_value_resolver)
return defs

@classmethod
@@ -4,8 +4,7 @@
from dagster_components.core.component_rendering import (
RenderingScope,
TemplatedValueResolver,
_should_render,
preprocess_value,
has_required_scope,
)
from pydantic import BaseModel, Field, TypeAdapter

@@ -44,7 +43,9 @@ class Outer(BaseModel):
],
)
def test_should_render(path, expected: bool) -> None:
assert _should_render(path, Outer.model_json_schema(), Outer.model_json_schema()) == expected
assert (
has_required_scope(path, Outer.model_json_schema(), Outer.model_json_schema()) == expected
)


def test_render() -> None:
@@ -61,7 +62,7 @@ def test_render() -> None:
}

renderer = TemplatedValueResolver(context={"foo_val": "foo", "bar_val": "bar"})
rendered_data = preprocess_value(renderer, data, Outer)
rendered_data = renderer.resolve_params(data, Outer)

assert rendered_data == {
"a": "foo",
@@ -1,6 +1,11 @@
import pytest
from dagster import AssetKey, AssetSpec, Definitions
from dagster_components.core.dsl_schema import AssetAttributes, MergeAttributes, ReplaceAttributes
from dagster_components.core.dsl_schema import (
AssetAttributes,
MergeAttributes,
ReplaceAttributes,
TemplatedValueResolver,
)
from pydantic import BaseModel, TypeAdapter


@@ -20,7 +25,7 @@ class M(BaseModel):
def test_replace_attributes() -> None:
op = ReplaceAttributes(operation="replace", target="group:g2", tags={"newtag": "newval"})

newdefs = op.apply(defs)
newdefs = op.apply(defs, TemplatedValueResolver.default())
asset_graph = newdefs.get_asset_graph()
assert asset_graph.get(AssetKey("a")).tags == {}
assert asset_graph.get(AssetKey("b")).tags == {"newtag": "newval"}
@@ -30,13 +35,35 @@ def test_replace_attributes() -> None:
def test_merge_attributes() -> None:
op = MergeAttributes(operation="merge", target="group:g2", tags={"newtag": "newval"})

newdefs = op.apply(defs)
newdefs = op.apply(defs, TemplatedValueResolver.default())
asset_graph = newdefs.get_asset_graph()
assert asset_graph.get(AssetKey("a")).tags == {}
assert asset_graph.get(AssetKey("b")).tags == {"newtag": "newval"}
assert asset_graph.get(AssetKey("c")).tags == {"tag": "val", "newtag": "newval"}


def test_render_attributes_asset_context() -> None:
op = MergeAttributes(tags={"group_name_tag": "group__{{ asset.group_name }}"})

newdefs = op.apply(defs, TemplatedValueResolver.default().with_context(foo="theval"))
asset_graph = newdefs.get_asset_graph()
assert asset_graph.get(AssetKey("a")).tags == {"group_name_tag": "group__g1"}
assert asset_graph.get(AssetKey("b")).tags == {"group_name_tag": "group__g2"}
assert asset_graph.get(AssetKey("c")).tags == {"tag": "val", "group_name_tag": "group__g2"}


def test_render_attributes_custom_context() -> None:
op = ReplaceAttributes(
operation="replace", target="group:g2", tags={"a": "{{ foo }}", "b": "prefix_{{ foo }}"}
)

newdefs = op.apply(defs, TemplatedValueResolver.default().with_context(foo="theval"))
asset_graph = newdefs.get_asset_graph()
assert asset_graph.get(AssetKey("a")).tags == {}
assert asset_graph.get(AssetKey("b")).tags == {"a": "theval", "b": "prefix_theval"}
assert asset_graph.get(AssetKey("c")).tags == {"a": "theval", "b": "prefix_theval"}


@pytest.mark.parametrize(
"python,expected",
[