Skip to content

Commit

Permalink
[components] Explicitly pass params
Browse files: browse the repository at this point in the history
  • Loading branch information
OwenKephart committed Jan 23, 2025
1 parent 3c8e74a commit 145aa32
Show file tree
Hide file tree
Showing 11 changed files with 113 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,14 @@
from dataclasses import dataclass
from pathlib import Path
from types import ModuleType
from typing import Any, ClassVar, Optional, TypedDict, TypeVar, Union, cast
from typing import Any, ClassVar, Optional, TypedDict, TypeVar, Union

import click
from dagster import _check as check
from dagster._core.definitions.definitions_class import Definitions
from dagster._core.errors import DagsterError
from dagster._utils import pushd, snakecase
from dagster._utils.pydantic_yaml import enrich_validation_errors_with_source_position
from pydantic import BaseModel, TypeAdapter
from dagster._utils import snakecase
from pydantic import BaseModel
from typing_extensions import Self

from dagster_components.core.component_scaffolder import (
Expand Down Expand Up @@ -59,7 +58,7 @@ def build_defs(self, context: "ComponentLoadContext") -> Definitions: ...

@classmethod
@abstractmethod
def load(cls, context: "ComponentLoadContext") -> Self: ...
def load(cls, params: Optional[BaseModel], context: "ComponentLoadContext") -> Self: ...

@classmethod
def get_metadata(cls) -> "ComponentTypeInternalMetadata":
Expand Down Expand Up @@ -253,31 +252,6 @@ def with_rendering_scope(self, rendering_scope: Mapping[str, Any]) -> "Component
def for_decl_node(self, decl_node: ComponentDeclNode) -> "ComponentLoadContext":
return dataclasses.replace(self, decl_node=decl_node)

def _raw_params(self) -> Optional[Mapping[str, Any]]:
    """Return the raw ``params`` mapping of the declaration currently being loaded.

    Only ``YamlComponentDecl`` nodes are supported; any other ``decl_node`` type
    is reported through ``check.failed``. The result may be ``None`` because the
    declaration's ``params`` field is optional.
    """
    # Imported inside the function, as before -- presumably to avoid a circular
    # import with component_decl_builder (TODO confirm).
    from dagster_components.core.component_decl_builder import YamlComponentDecl

    node = self.decl_node
    if isinstance(node, YamlComponentDecl):
        return node.component_file_model.params
    check.failed(f"Unsupported decl_node type {type(node)}")

def load_params(self, params_schema: type[T]) -> T:
    """Resolve templated values in the raw params and validate them against ``params_schema``.

    Args:
        params_schema: The type the resolved params are validated into.

    Returns:
        The validated instance of ``params_schema``.

    Raises:
        Whatever ``TypeAdapter.validate_python`` raises when the params do not
        match the schema. When source positions for the ``params`` section are
        available, the error is raised inside
        ``enrich_validation_errors_with_source_position`` so it can be annotated
        with them.
    """
    from dagster_components.core.component_decl_builder import YamlComponentDecl

    # Resolution and validation both run with the component path as the
    # working directory.
    with pushd(str(self.path)):
        preprocessed_params = self.templated_value_resolver.resolve_params(
            self._raw_params(), params_schema
        )
        yaml_decl = cast(YamlComponentDecl, self.decl_node)
        adapter = TypeAdapter(params_schema)

        # ``params`` is optional in the component file, so the position tree may
        # have no "params" child. Look it up leniently so that case surfaces as
        # a normal validation error rather than a bare KeyError.
        tree = yaml_decl.source_position_tree
        params_tree = tree.children.get("params") if tree else None
        if params_tree is None:
            return adapter.validate_python(preprocessed_params)

        with enrich_validation_errors_with_source_position(params_tree, ["params"]):
            return adapter.validate_python(preprocessed_params)


COMPONENT_REGISTRY_KEY_ATTR = "__dagster_component_registry_key"

Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,28 @@
from collections.abc import Mapping, Sequence
from pathlib import Path
from typing import Any, Optional, Union
from typing import Any, Optional, TypeVar, Union

from dagster._record import record
from dagster._utils.pydantic_yaml import _parse_and_populate_model_with_annotated_errors
from dagster._utils import pushd
from dagster._utils.pydantic_yaml import (
_parse_and_populate_model_with_annotated_errors,
enrich_validation_errors_with_source_position,
)
from dagster._utils.source_position import SourcePositionTree
from dagster._utils.yaml_utils import parse_yaml_with_source_positions
from pydantic import BaseModel
from pydantic import BaseModel, TypeAdapter

from dagster_components.core.component import ComponentDeclNode
from dagster_components.core.component import ComponentDeclNode, ComponentLoadContext


# Pydantic model holding the two top-level fields of a component declaration.
# (Documented with comments rather than a docstring so the model's generated
# schema description is left untouched.)
class ComponentFileModel(BaseModel):
    # Required string identifying the component type.
    type: str
    # Optional raw, not-yet-validated parameters; None when none are given.
    params: Optional[Mapping[str, Any]] = None


T = TypeVar("T", bound=BaseModel)


@record
class YamlComponentDecl(ComponentDeclNode):
path: Path
Expand All @@ -37,6 +44,22 @@ def from_path(component_file_path: Path) -> "YamlComponentDecl":
source_position_tree=parsed.source_position_tree,
)

def get_params(self, context: ComponentLoadContext, params_schema: type[T]) -> T:
    """Resolve this declaration's raw params and validate them against ``params_schema``.

    Args:
        context: Load context whose ``templated_value_resolver`` is used to
            resolve templated values in the raw params.
        params_schema: The type the resolved params are validated into.

    Returns:
        The validated instance of ``params_schema``.

    Raises:
        Whatever ``TypeAdapter.validate_python`` raises when the params do not
        match the schema. When source positions for the ``params`` section are
        available, the error is raised inside
        ``enrich_validation_errors_with_source_position`` so it can be annotated
        with them.
    """
    # Resolution and validation both run with this declaration's path as the
    # working directory.
    with pushd(str(self.path)):
        preprocessed_params = context.templated_value_resolver.resolve_params(
            self.component_file_model.params, params_schema
        )
        adapter = TypeAdapter(params_schema)

        # ``params`` is optional in the component file, so the position tree may
        # have no "params" child. Look it up leniently so that case surfaces as
        # a normal validation error rather than a bare KeyError.
        tree = self.source_position_tree
        params_tree = tree.children.get("params") if tree else None
        if params_tree is None:
            return adapter.validate_python(preprocessed_params)

        with enrich_validation_errors_with_source_position(params_tree, ["params"]):
            return adapter.validate_python(preprocessed_params)


@record
class ComponentFolder(ComponentDeclNode):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,20 @@ def resolve_decl_node_to_yaml_decls(decl: ComponentDeclNode) -> list[YamlCompone


def load_components_from_context(context: ComponentLoadContext) -> Sequence[Component]:
if isinstance(context.decl_node, YamlComponentDecl):
component_type = component_type_from_yaml_decl(context.registry, context.decl_node)
node = context.decl_node
if isinstance(node, YamlComponentDecl):
component_type = component_type_from_yaml_decl(context.registry, node)
component_schema = component_type.get_schema()
context = context.with_rendering_scope(component_type.get_additional_scope())
return [component_type.load(context)]
elif isinstance(context.decl_node, ComponentFolder):
loaded_params = node.get_params(context, component_schema) if component_schema else None
return [component_type.load(loaded_params, context)]
elif isinstance(node, ComponentFolder):
components = []
for sub_decl in context.decl_node.sub_decls:
for sub_decl in node.sub_decls:
components.extend(load_components_from_context(context.for_decl_node(sub_decl)))
return components

raise NotImplementedError(f"Unknown component type {context.decl_node}")
raise NotImplementedError(f"Unknown component type {node}")


def component_type_from_yaml_decl(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,16 @@ def get_scaffolder(cls) -> "DbtProjectComponentScaffolder":
return DbtProjectComponentScaffolder()

@classmethod
def get_schema(cls):
def get_schema(cls) -> type[DbtProjectParams]:
return DbtProjectParams

@classmethod
def load(cls, context: ComponentLoadContext) -> Self:
loaded_params = context.load_params(cls.get_schema())

def load(cls, params: DbtProjectParams, context: ComponentLoadContext) -> Self:
return cls(
dbt_resource=loaded_params.dbt,
op_spec=loaded_params.op,
asset_attributes=loaded_params.asset_attributes,
transforms=loaded_params.transforms or [],
dbt_resource=params.dbt,
op_spec=params.op,
asset_attributes=params.asset_attributes,
transforms=params.transforms or [],
value_resolver=context.templated_value_resolver,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,12 @@ def get_scaffolder(cls) -> DefinitionsComponentScaffolder:
return DefinitionsComponentScaffolder()

@classmethod
def get_schema(cls):
def get_schema(cls) -> type[DefinitionsParamSchema]:
return DefinitionsParamSchema

@classmethod
def load(cls, context: ComponentLoadContext) -> Self:
# all paths should be resolved relative to the directory we're in
loaded_params = context.load_params(cls.get_schema())

return cls(definitions_path=Path(loaded_params.definitions_path or "definitions.py"))
def load(cls, params: DefinitionsParamSchema, context: ComponentLoadContext) -> Self:
return cls(definitions_path=Path(params.definitions_path or "definitions.py"))

def build_defs(self, context: ComponentLoadContext) -> Definitions:
with pushd(str(context.path)):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ def introspect_from_path(path: Path) -> "PipesSubprocessScriptCollection":
return PipesSubprocessScriptCollection(dirpath=path, path_specs=path_specs)

@classmethod
def get_schema(cls):
def get_schema(cls) -> type[PipesSubprocessScriptCollectionParams]:
return PipesSubprocessScriptCollectionParams

@classmethod
def load(cls, context: ComponentLoadContext) -> "PipesSubprocessScriptCollection":
loaded_params = context.load_params(cls.get_schema())

def load(
cls, params: PipesSubprocessScriptCollectionParams, context: ComponentLoadContext
) -> "PipesSubprocessScriptCollection":
path_specs = {}
for script in loaded_params.scripts:
for script in params.scripts:
script_path = context.path / script.path
if not script_path.exists():
raise FileNotFoundError(f"Script {script_path} does not exist")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,16 @@ def get_scaffolder(cls) -> ComponentScaffolder:
return SlingReplicationComponentScaffolder()

@classmethod
def get_schema(cls):
def get_schema(cls) -> type[SlingReplicationCollectionParams]:
return SlingReplicationCollectionParams

@classmethod
def load(cls, context: ComponentLoadContext) -> Self:
loaded_params = context.load_params(cls.get_schema())
def load(cls, params: SlingReplicationCollectionParams, context: ComponentLoadContext) -> Self:
return cls(
dirpath=context.path,
resource=loaded_params.sling or SlingResource(),
sling_replications=loaded_params.replications,
transforms=loaded_params.transforms or [],
resource=params.sling or SlingResource(),
sling_replications=params.replications,
transforms=params.transforms or [],
)

def build_replication_asset(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,11 @@ def test_python_params(dbt_path: Path) -> None:
},
),
)
component = DbtProjectComponent.load(context=script_load_context(decl_node))
context = script_load_context(decl_node)
component = DbtProjectComponent.load(
params=decl_node.get_params(context, DbtProjectComponent.get_schema()),
context=context,
)
assert get_asset_keys(component) == JAFFLE_SHOP_KEYS
defs = component.build_defs(script_load_context())
assert defs.get_assets_def("stg_customers").op.name == "some_op"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def test_python_params(sling_path: Path) -> None:
),
)
context = script_load_context(decl_node)
component = SlingReplicationCollectionComponent.load(context)
params = decl_node.get_params(context, SlingReplicationCollectionComponent.get_schema())
component = SlingReplicationCollectionComponent.load(params, context)

replications = component.sling_replications
assert len(replications) == 1
Expand Down Expand Up @@ -108,7 +109,8 @@ def test_python_params_op_name(sling_path: Path) -> None:
),
)
context = script_load_context(decl_node)
component = SlingReplicationCollectionComponent.load(context=context)
params = decl_node.get_params(context, SlingReplicationCollectionComponent.get_schema())
component = SlingReplicationCollectionComponent.load(params, context=context)

replications = component.sling_replications
assert len(replications) == 1
Expand Down Expand Up @@ -137,7 +139,8 @@ def test_python_params_op_tags(sling_path: Path) -> None:
),
)
context = script_load_context(decl_node)
component = SlingReplicationCollectionComponent.load(context=context)
params = decl_node.get_params(context, SlingReplicationCollectionComponent.get_schema())
component = SlingReplicationCollectionComponent.load(params=params, context=context)
replications = component.sling_replications
assert len(replications) == 1
op_spec = replications[0].op
Expand Down Expand Up @@ -175,7 +178,11 @@ def execute(
params={"sling": {}, "replications": [{"path": "./replication.yaml"}]},
),
)
params = decl_node.get_params(
script_load_context(decl_node), DebugSlingReplicationComponent.get_schema()
)
component_inst = DebugSlingReplicationComponent.load(
params=params,
context=script_load_context(decl_node),
)
assert get_asset_keys(component_inst) == {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ def test_python_params_node_rename(dbt_path: Path) -> None:
},
),
)
component = DbtProjectComponent.load(
context=script_load_context(decl_node),
)
context = script_load_context(decl_node)
params = decl_node.get_params(context, DbtProjectComponent.get_schema())
component = DbtProjectComponent.load(params=params, context=context)
assert get_asset_keys(component) == JAFFLE_SHOP_KEYS_WITH_PREFIX


Expand All @@ -93,7 +93,9 @@ def test_python_params_group(dbt_path: Path) -> None:
},
),
)
comp = DbtProjectComponent.load(context=script_load_context(decl_node))
context = script_load_context(decl_node)
params = decl_node.get_params(context, DbtProjectComponent.get_schema())
comp = DbtProjectComponent.load(params=params, context=context)
assert get_asset_keys(comp) == JAFFLE_SHOP_KEYS
defs: Definitions = comp.build_defs(script_load_context(None))
for key in get_asset_keys(comp):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from pathlib import Path

from dagster import AssetKey
from dagster_components.core.component_decl_builder import ComponentFileModel
from dagster_components import AssetAttributesModel
from dagster_components.core.component_decl_builder import ComponentFileModel, YamlComponentDecl
from dagster_components.core.component_defs_builder import (
YamlComponentDecl,
build_components_from_component_folder,
build_defs_from_component_path,
defs_from_components,
)
from dagster_components.lib.pipes_subprocess_script_collection import (
PipesSubprocessScriptCollection,
PipesSubprocessScriptCollectionParams,
PipesSubprocessScriptParams,
)

from dagster_components_tests.utils import assert_assets, get_asset_keys, script_load_context
Expand All @@ -25,32 +27,37 @@ def test_python_native() -> None:


def test_python_params() -> None:
component_decl = YamlComponentDecl(
path=LOCATION_PATH / "components" / "scripts",
component_file_model=ComponentFileModel(
type="pipes_subprocess_script_collection",
params={
"scripts": [
{
"path": "script_one.py",
"assets": [
{
"key": "a",
"automation_condition": "{{ automation_condition.eager() }}",
},
{
"key": "b",
"automation_condition": "{{ automation_condition.on_cron('@daily') }}",
"deps": ["up1", "up2"],
},
],
},
{"path": "subdir/script_three.py", "assets": [{"key": "key_override"}]},
]
},
params = PipesSubprocessScriptCollectionParams(
scripts=[
PipesSubprocessScriptParams(
path="script_one.py",
assets=[
AssetAttributesModel(
key="a", automation_condition="{{ automation_condition.eager() }}"
),
AssetAttributesModel(
key="b",
automation_condition="{{ automation_condition.on_cron('@daily') }}",
deps=["up1", "up2"],
),
],
),
PipesSubprocessScriptParams(
path="subdir/script_three.py",
assets=[AssetAttributesModel(key="key_override")],
),
]
)
component = PipesSubprocessScriptCollection.load(
params=params,
# TODO: we should use a PythonComponentDecl here instead
context=script_load_context(
YamlComponentDecl(
path=Path(LOCATION_PATH / "components" / "scripts"),
component_file_model=ComponentFileModel(type="."),
)
),
)
component = PipesSubprocessScriptCollection.load(context=script_load_context(component_decl))
assert get_asset_keys(component) == {
AssetKey("a"),
AssetKey("b"),
Expand Down

0 comments on commit 145aa32

Please sign in to comment.