Skip to content

Commit

Permalink
[components] Update resolution deferral logic
Browse files Browse the repository at this point in the history
  • Loading branch information
OwenKephart committed Jan 16, 2025
1 parent b8b3315 commit 58bdbcc
Show file tree
Hide file tree
Showing 9 changed files with 59 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class ShellScriptSchema(ComponentSchemaBaseModel):
script_runner: Annotated[
str,
ResolvableFieldInfo(
output_type=ScriptRunner, additional_scope={"get_script_runner"}
output_type=ScriptRunner, required_scope={"get_script_runner"}
),
]
# highlight-end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def get_registered_component_types_in_module(module: ModuleType) -> Iterable[typ
yield component


T = TypeVar("T")
T = TypeVar("T", bound=BaseModel)


@dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,14 @@

from pydantic import BaseModel, ConfigDict, TypeAdapter

from dagster_components.core.schema.metadata import (
JSON_SCHEMA_EXTRA_DEFER_RENDERING_KEY,
get_resolution_metadata,
)
from dagster_components.core.schema.metadata import get_resolution_metadata
from dagster_components.core.schema.resolver import TemplatedValueResolver


class ComponentSchemaBaseModel(BaseModel):
"""Base class for models that are part of a component schema."""

model_config = ConfigDict(
json_schema_extra={JSON_SCHEMA_EXTRA_DEFER_RENDERING_KEY: True}, extra="forbid"
)
model_config = ConfigDict(extra="forbid")

def resolve_properties(self, value_resolver: TemplatedValueResolver) -> Mapping[str, Any]:
"""Returns a dictionary of resolved properties for this class."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
REF_BASE = "#/$defs/"
REF_TEMPLATE = f"{REF_BASE}{{model}}"
JSON_SCHEMA_EXTRA_DEFER_RENDERING_KEY = "dagster_defer_rendering"
JSON_SCHEMA_EXTRA_AVAILABLE_SCOPE_KEY = "dagster_available_scope"
JSON_SCHEMA_EXTRA_REQUIRED_SCOPE_KEY = "dagster_required_scope"


@dataclass
Expand All @@ -34,7 +34,7 @@ def __init__(
*,
output_type: Optional[type] = None,
post_process_fn: Optional[Callable[[Any], Any]] = None,
additional_scope: Optional[Set[str]] = None,
required_scope: Optional[Set[str]] = None,
):
self.resolution_metadata = (
ResolutionMetadata(output_type=output_type, post_process=post_process_fn)
Expand All @@ -43,8 +43,11 @@ def __init__(
)
super().__init__(
json_schema_extra={
JSON_SCHEMA_EXTRA_AVAILABLE_SCOPE_KEY: list(additional_scope or []),
JSON_SCHEMA_EXTRA_DEFER_RENDERING_KEY: True,
JSON_SCHEMA_EXTRA_REQUIRED_SCOPE_KEY: list(required_scope or []),
# defer resolution if the output type will change
**(
{JSON_SCHEMA_EXTRA_DEFER_RENDERING_KEY: True} if output_type is not None else {}
),
},
)

Expand Down Expand Up @@ -98,33 +101,34 @@ def _subschemas_on_path(
yield from _subschemas_on_path(rest, json_schema, inner)


def _should_defer_render(subschema: Mapping[str, Any]) -> bool:
def _get_should_defer_resolve(subschema: Mapping[str, Any]) -> bool:
raw = check.opt_inst(subschema.get(JSON_SCHEMA_EXTRA_DEFER_RENDERING_KEY), bool)
return raw or False


def _get_available_scope(subschema: Mapping[str, Any]) -> Set[str]:
raw = check.opt_inst(subschema.get(JSON_SCHEMA_EXTRA_AVAILABLE_SCOPE_KEY), list)
def _get_additional_required_scope(subschema: Mapping[str, Any]) -> Set[str]:
raw = check.opt_inst(subschema.get(JSON_SCHEMA_EXTRA_REQUIRED_SCOPE_KEY), list)
return set(raw) if raw else set()


def allow_resolve(
valpath: Sequence[Union[str, int]], json_schema: Mapping[str, Any], subschema: Mapping[str, Any]
) -> bool:
"""Given a valpath and the json schema of a given target type, determines if there is a rendering scope
required to render the value at the given path.
"""Given a valpath and the json schema of a given target type, determines if this value can be
resolved eagerly. This can only happen if the output type of the resolved value is unchanged,
and there is no additional scope required for resolution.
"""
for subschema in _subschemas_on_path(valpath, json_schema, subschema):
if _should_defer_render(subschema):
if _get_should_defer_resolve(subschema) or _get_additional_required_scope(subschema):
return False
return True


def get_available_scope(
valpath: Sequence[Union[str, int]], json_schema: Mapping[str, Any], subschema: Mapping[str, Any]
def get_required_scope(
valpath: Sequence[Union[str, int]], json_schema: Mapping[str, Any]
) -> Set[str]:
"""Given a valpath and the json schema of a given target type, determines the available rendering scope."""
available_scope = set()
for subschema in _subschemas_on_path(valpath, json_schema, subschema):
available_scope |= _get_available_scope(subschema)
return available_scope
required_scope = set()
for subschema in _subschemas_on_path(valpath, json_schema, json_schema):
required_scope |= _get_additional_required_scope(subschema)
return required_scope
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@ def automation_condition_scope() -> Mapping[str, Any]:
}


ShouldResolveFn = Callable[[Sequence[Union[str, int]]], bool]


@record
class TemplatedValueResolver:
scope: Mapping[str, Any]
Expand All @@ -50,23 +47,23 @@ def _resolve_obj(
self,
obj: Any,
valpath: Optional[Sequence[Union[str, int]]],
should_render: Callable[[Sequence[Union[str, int]]], bool],
should_resolve: Callable[[Sequence[Union[str, int]]], bool],
) -> Any:
"""Recursively resolves templated values in a nested object, based on the provided should_render function."""
if valpath is not None and not should_render(valpath):
"""Recursively resolves templated values in a nested object, based on the provided should_resolve function."""
if valpath is not None and not should_resolve(valpath):
return obj
elif isinstance(obj, dict):
# render all values in the dict
# resolve all values in the dict
return {
k: self._resolve_obj(
v, [*valpath, k] if valpath is not None else None, should_render
v, [*valpath, k] if valpath is not None else None, should_resolve
)
for k, v in obj.items()
}
elif isinstance(obj, list):
# render all values in the list
# resolve all values in the list
return [
self._resolve_obj(v, [*valpath, i] if valpath is not None else None, should_render)
self._resolve_obj(v, [*valpath, i] if valpath is not None else None, should_resolve)
for i, v in enumerate(obj)
]
else:
Expand All @@ -76,15 +73,12 @@ def resolve_obj(self, val: Any) -> Any:
"""Recursively resolves templated values in a nested object."""
return self._resolve_obj(val, None, lambda _: True)

def resolve_params(self, val: T, target_type: type) -> T:
"""Given a raw params value, preprocesses it by rendering any templated values that are not marked as deferred in the target_type's json schema."""
json_schema = (
target_type.model_json_schema() if issubclass(target_type, BaseModel) else None
def resolve_params(self, val: T, target_type: type[BaseModel]) -> T:
"""Given a raw params value, preprocesses it by resolving any templated values that are not marked
as deferred in the target_type's json schema.
"""
json_schema = target_type.model_json_schema()
should_resolve = functools.partial(
allow_resolve, json_schema=json_schema, subschema=json_schema
)
if json_schema is None:
should_render = lambda _: True
else:
should_render = functools.partial(
allow_resolve, json_schema=json_schema, subschema=json_schema
)
return self._resolve_obj(val, [], should_render=should_render)
return self._resolve_obj(val, [], should_resolve=should_resolve)
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class DbtProjectParams(BaseModel):
dbt: DbtCliResource
op: Optional[OpSpecBaseModel] = None
asset_attributes: Annotated[
Optional[AssetAttributesModel], ResolvableFieldInfo(additional_scope={"node"})
Optional[AssetAttributesModel], ResolvableFieldInfo(required_scope={"node"})
] = None
transforms: Optional[Sequence[AssetSpecTransformModel]] = None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class SlingReplicationParams(BaseModel):
path: str
op: Optional[OpSpecBaseModel] = None
asset_attributes: Annotated[
Optional[AssetAttributesModel], ResolvableFieldInfo(additional_scope={"stream_definition"})
Optional[AssetAttributesModel], ResolvableFieldInfo(required_scope={"stream_definition"})
] = None


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class ComplexAssetParams(BaseModel):
value: str
op: Optional[OpSpecBaseModel] = None
asset_attributes: Annotated[
Optional[AssetAttributesModel], ResolvableFieldInfo(additional_scope={"node"})
Optional[AssetAttributesModel], ResolvableFieldInfo(required_scope={"node"})
] = None
asset_transforms: Optional[Sequence[AssetSpecTransformModel]] = None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@

import pytest
from dagster_components import ComponentSchemaBaseModel, ResolvableFieldInfo, TemplatedValueResolver
from dagster_components.core.schema.metadata import allow_resolve, get_available_scope
from dagster_components.core.schema.metadata import allow_resolve, get_required_scope
from pydantic import BaseModel, Field, TypeAdapter, ValidationError


class InnerRendered(ComponentSchemaBaseModel):
a: Optional[str] = None
a: Annotated[Optional[str], ResolvableFieldInfo(required_scope={"deferred"})] = None


class Container(BaseModel):
a: str
inner: InnerRendered
inner_scoped: Annotated[InnerRendered, ResolvableFieldInfo(additional_scope={"c", "d"})] = (
Field(default_factory=InnerRendered)
inner_scoped: Annotated[InnerRendered, ResolvableFieldInfo(required_scope={"c", "d"})] = Field(
default_factory=InnerRendered
)


Expand All @@ -25,36 +25,38 @@ class Outer(BaseModel):
container: Container
container_optional: Optional[Container] = None
container_optional_scoped: Annotated[
Optional[Container], ResolvableFieldInfo(additional_scope={"a", "b"})
Optional[Container], ResolvableFieldInfo(required_scope={"a", "b"})
] = None
inner_seq: Sequence[InnerRendered]
inner_optional: Optional[InnerRendered] = None
inner_optional_seq: Optional[Sequence[InnerRendered]] = None
transformed: Annotated[Optional[str], ResolvableFieldInfo(output_type=Optional[int])] = None


@pytest.mark.parametrize(
"path,expected",
[
(["a"], True),
(["inner"], False),
(["inner"], True),
(["inner", "a"], False),
(["container", "a"], True),
(["container", "inner"], False),
(["container", "inner"], True),
(["container", "inner", "a"], False),
(["container_optional", "a"], True),
(["container_optional", "inner"], False),
(["container_optional", "inner"], True),
(["container_optional", "inner", "a"], False),
(["container_optional_scoped"], False),
(["container_optional_scoped", "inner", "a"], False),
(["container_optional_scoped", "inner_scoped", "a"], False),
(["inner_seq"], True),
(["inner_seq", 0], False),
(["inner_seq", 0], True),
(["inner_seq", 0, "a"], False),
(["inner_optional"], True),
(["inner_optional", "a"], False),
(["inner_optional_seq"], True),
(["inner_optional_seq", 0], False),
(["inner_optional_seq", 0], True),
(["inner_optional_seq", 0, "a"], False),
(["transformed"], False),
],
)
def test_allow_render(path, expected: bool) -> None:
Expand All @@ -65,19 +67,17 @@ def test_allow_render(path, expected: bool) -> None:
"path,expected",
[
(["a"], set()),
(["inner", "a"], set()),
(["container_optional", "inner", "a"], set()),
(["inner", "a"], {"deferred"}),
(["container_optional", "inner", "a"], {"deferred"}),
(["inner_seq"], set()),
(["container_optional_scoped"], {"a", "b"}),
(["container_optional_scoped", "inner"], {"a", "b"}),
(["container_optional_scoped", "inner_scoped"], {"a", "b", "c", "d"}),
(["container_optional_scoped", "inner_scoped", "a"], {"a", "b", "c", "d"}),
(["container_optional_scoped", "inner_scoped", "a"], {"a", "b", "c", "d", "deferred"}),
],
)
def test_get_available_scope(path, expected: Set[str]) -> None:
assert (
get_available_scope(path, Outer.model_json_schema(), Outer.model_json_schema()) == expected
)
def test_get_required_scope(path, expected: Set[str]) -> None:
assert get_required_scope(path, Outer.model_json_schema()) == expected


def test_render() -> None:
Expand Down

0 comments on commit 58bdbcc

Please sign in to comment.