Skip to content

Commit

Permalink
[components] Refactor component load behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
OwenKephart committed Jan 23, 2025
1 parent 145aa32 commit 5ee4d47
Show file tree
Hide file tree
Showing 10 changed files with 110 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ def get_schema(cls) -> type[ShellScriptSchema]:
# highlight-end

@classmethod
def load(cls, load_context: ComponentLoadContext) -> "ShellCommand":
return cls(params=load_context.load_params(cls.get_schema()))
def load(
cls, params: ShellScriptSchema, load_context: ComponentLoadContext
) -> "ShellCommand":
return cls(params=params)

def build_defs(self, load_context: ComponentLoadContext) -> dg.Definitions:
resolved_asset_attributes = self.params.asset_attributes.resolve_properties(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
get_component_type_name,
)
from dagster_components.core.component_defs_builder import (
component_type_from_yaml_decl,
load_components_from_context,
path_to_decl_node,
resolve_decl_node_to_yaml_decls,
)
Expand Down Expand Up @@ -186,11 +184,9 @@ def check_component_command(ctx: click.Context, paths: Sequence[str]) -> None:
templated_value_resolver=TemplatedValueResolver.default(),
)
try:
load_components_from_context(clc)
decl_node.load(clc)
except ValidationError as e:
component_type = component_type_from_yaml_decl(
context.component_registry, yaml_decl
)
component_type = yaml_decl.get_component_type(context.component_registry)
validation_errors.append((component_type, e))

if validation_errors:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
from dagster_components.core.schema.resolver import TemplatedValueResolver


class ComponentDeclNode: ...
class ComponentDeclNode(ABC):
@abstractmethod
def load(self, context: "ComponentLoadContext") -> Sequence["Component"]: ...


class Component(ABC):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
from collections.abc import Mapping, Sequence
from pathlib import Path
from typing import Any, Optional, TypeVar, Union
Expand All @@ -12,7 +13,15 @@
from dagster._utils.yaml_utils import parse_yaml_with_source_positions
from pydantic import BaseModel, TypeAdapter

from dagster_components.core.component import ComponentDeclNode, ComponentLoadContext
from dagster_components.core.component import (
Component,
ComponentDeclNode,
ComponentLoadContext,
ComponentTypeRegistry,
get_component_type_name,
is_registered_component_type,
)
from dagster_components.utils import load_module_from_path


class ComponentFileModel(BaseModel):
Expand All @@ -30,7 +39,16 @@ class YamlComponentDecl(ComponentDeclNode):
source_position_tree: Optional[SourcePositionTree] = None

@staticmethod
def from_path(component_file_path: Path) -> "YamlComponentDecl":
def component_file_path(path: Path) -> Path:
return path / "component.yaml"

@staticmethod
def exists_at(path: Path) -> bool:
return YamlComponentDecl.component_file_path(path).exists()

@staticmethod
def from_path(path: Path) -> "YamlComponentDecl":
component_file_path = YamlComponentDecl.component_file_path(path)
parsed = parse_yaml_with_source_positions(
component_file_path.read_text(), str(component_file_path)
)
Expand All @@ -39,11 +57,36 @@ def from_path(component_file_path: Path) -> "YamlComponentDecl":
)

return YamlComponentDecl(
path=component_file_path.parent,
path=path,
component_file_model=obj,
source_position_tree=parsed.source_position_tree,
)

def get_component_type(self, registry: ComponentTypeRegistry) -> type[Component]:
parsed_defs = self.component_file_model
if parsed_defs.type.startswith("."):
component_registry_key = parsed_defs.type[1:]

# Iterate over Python files in the folder
for py_file in self.path.glob("*.py"):
module_name = py_file.stem

module = load_module_from_path(module_name, self.path / f"{module_name}.py")

for _name, obj in inspect.getmembers(module, inspect.isclass):
assert isinstance(obj, type)
if (
is_registered_component_type(obj)
and get_component_type_name(obj) == component_registry_key
):
return obj

raise Exception(
f"Could not find component type {component_registry_key} in {self.path}"
)

return registry.get(parsed_defs.type)

def get_params(self, context: ComponentLoadContext, params_schema: type[T]) -> T:
with pushd(str(self.path)):
raw_params = self.component_file_model.params
Expand All @@ -60,12 +103,26 @@ def get_params(self, context: ComponentLoadContext, params_schema: type[T]) -> T
else:
return TypeAdapter(params_schema).validate_python(preprocessed_params)

def load(self, context: ComponentLoadContext) -> Sequence[Component]:
component_type = self.get_component_type(context.registry)
component_schema = component_type.get_schema()
context = context.with_rendering_scope(component_type.get_additional_scope())
loaded_params = self.get_params(context, component_schema) if component_schema else None
return [component_type.load(loaded_params, context)]


@record
class ComponentFolder(ComponentDeclNode):
path: Path
sub_decls: Sequence[Union[YamlComponentDecl, "ComponentFolder"]]

def load(self, context: ComponentLoadContext) -> Sequence[Component]:
components = []
for sub_decl in self.sub_decls:
sub_context = context.for_decl_node(sub_decl)
components.extend(sub_decl.load(sub_context))
return components


def path_to_decl_node(path: Path) -> Optional[ComponentDeclNode]:
# right now, we only support two types of components, both of which are folders
Expand All @@ -75,10 +132,8 @@ def path_to_decl_node(path: Path) -> Optional[ComponentDeclNode]:
if not path.is_dir():
return None

component_path = path / "component.yaml"

if component_path.exists():
return YamlComponentDecl.from_path(component_path)
if YamlComponentDecl.exists_at(path):
return YamlComponentDecl.from_path(path)

subs = []
for subpath in path.iterdir():
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
import importlib
import importlib.util
import inspect
from collections.abc import Mapping, Sequence
from pathlib import Path
from types import ModuleType
from typing import TYPE_CHECKING, Optional

from dagster._utils.warnings import suppress_dagster_warnings
Expand All @@ -13,8 +9,6 @@
ComponentLoadContext,
ComponentTypeRegistry,
TemplatedValueResolver,
get_component_type_name,
is_registered_component_type,
)
from dagster_components.core.component_decl_builder import (
ComponentDeclNode,
Expand All @@ -28,19 +22,6 @@
from dagster._core.definitions.definitions_class import Definitions


def load_module_from_path(module_name, path) -> ModuleType:
# Create a spec from the file path
spec = importlib.util.spec_from_file_location(module_name, path)
if spec is None:
raise ImportError(f"Cannot create a module spec from path: {path}")

# Create and load the module
module = importlib.util.module_from_spec(spec)
assert spec.loader, "Must have a loader"
spec.loader.exec_module(module)
return module


def resolve_decl_node_to_yaml_decls(decl: ComponentDeclNode) -> list[YamlComponentDecl]:
if isinstance(decl, YamlComponentDecl):
return [decl]
Expand All @@ -53,57 +34,12 @@ def resolve_decl_node_to_yaml_decls(decl: ComponentDeclNode) -> list[YamlCompone
raise NotImplementedError(f"Unknown component type {decl}")


def load_components_from_context(context: ComponentLoadContext) -> Sequence[Component]:
node = context.decl_node
if isinstance(node, YamlComponentDecl):
component_type = component_type_from_yaml_decl(context.registry, node)
component_schema = component_type.get_schema()
context = context.with_rendering_scope(component_type.get_additional_scope())
loaded_params = node.get_params(context, component_schema) if component_schema else None
return [component_type.load(loaded_params, context)]
elif isinstance(node, ComponentFolder):
components = []
for sub_decl in node.sub_decls:
components.extend(load_components_from_context(context.for_decl_node(sub_decl)))
return components

raise NotImplementedError(f"Unknown component type {node}")


def component_type_from_yaml_decl(
registry: ComponentTypeRegistry, decl_node: YamlComponentDecl
) -> type[Component]:
parsed_defs = decl_node.component_file_model
if parsed_defs.type.startswith("."):
component_registry_key = parsed_defs.type[1:]

# Iterate over Python files in the folder
for py_file in decl_node.path.glob("*.py"):
module_name = py_file.stem

module = load_module_from_path(module_name, decl_node.path / f"{module_name}.py")

for _name, obj in inspect.getmembers(module, inspect.isclass):
assert isinstance(obj, type)
if (
is_registered_component_type(obj)
and get_component_type_name(obj) == component_registry_key
):
return obj

raise Exception(
f"Could not find component type {component_registry_key} in {decl_node.path}"
)

return registry.get(parsed_defs.type)


def build_components_from_component_folder(
context: ComponentLoadContext, path: Path
) -> Sequence[Component]:
component_folder = path_to_decl_node(path)
assert isinstance(component_folder, ComponentFolder)
return load_components_from_context(context.for_decl_node(component_folder))
return component_folder.load(context.for_decl_node(component_folder))


def build_defs_from_component_path(
Expand All @@ -122,7 +58,7 @@ def build_defs_from_component_path(
decl_node=decl_node,
templated_value_resolver=TemplatedValueResolver.default(),
)
components = load_components_from_context(context)
components = decl_node.load(context)
return defs_from_components(resources=resources, context=context, components=components)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from types import ModuleType
from typing import Any, Optional

from dagster._core.definitions.asset_key import AssetKey
Expand Down Expand Up @@ -121,3 +122,16 @@ def get_automation_condition(self, obj: Any) -> Optional[AutomationCondition]:
)

return WrappedTranslator


def load_module_from_path(module_name, path) -> ModuleType:
# Create a spec from the file path
spec = importlib.util.spec_from_file_location(module_name, path)
if spec is None:
raise ImportError(f"Cannot create a module spec from path: {path}")

# Create and load the module
module = importlib.util.module_from_spec(spec)
assert spec.loader, "Must have a loader"
spec.loader.exec_module(module)
return module
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,9 @@ def test_render_vars_root(dbt_path: Path) -> None:
},
),
)
comp = DbtProjectComponent.load(context=script_load_context(decl_node))
context = script_load_context(decl_node)
params = decl_node.get_params(context, DbtProjectComponent.get_schema())
comp = DbtProjectComponent.load(params=params, context=context)
assert get_asset_keys(comp) == JAFFLE_SHOP_KEYS
defs: Definitions = comp.build_defs(script_load_context())
for key in get_asset_keys(comp):
Expand All @@ -155,5 +157,7 @@ def test_render_vars_asset_key(dbt_path: Path) -> None:
},
),
)
comp = DbtProjectComponent.load(context=script_load_context(decl_node))
context = script_load_context(decl_node)
params = decl_node.get_params(context, DbtProjectComponent.get_schema())
comp = DbtProjectComponent.load(params=params, context=context)
assert get_asset_keys(comp) == JAFFLE_SHOP_KEYS_WITH_PREFIX
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ class MyComponentSchema(ComponentSchemaBaseModel):
@component_type
class MyComponent(Component):
name = "my_component"
params_schema = MyComponentSchema

@classmethod
def load(cls, context: ComponentLoadContext) -> Self:
context.load_params(cls.params_schema)
def get_schema(cls) -> type[MyComponentSchema]:
return MyComponentSchema

@classmethod
def load(cls, params: MyComponentSchema, context: ComponentLoadContext) -> Self:
return cls()

def build_defs(self, context: ComponentLoadContext) -> Definitions:
Expand All @@ -40,11 +42,13 @@ class MyNestedComponentSchema(ComponentSchemaBaseModel):
@component_type
class MyNestedComponent(Component):
name = "my_nested_component"
params_schema = MyNestedComponentSchema

@classmethod
def load(cls, context: ComponentLoadContext) -> Self:
context.load_params(cls.params_schema)
def get_schema(cls) -> type[MyNestedComponentSchema]:
return MyNestedComponentSchema

@classmethod
def load(cls, params: MyComponentSchema, context: ComponentLoadContext) -> Self:
return cls()

def build_defs(self, context: ComponentLoadContext) -> Definitions:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,8 @@ def get_schema(cls):
return CustomScopeParams

@classmethod
def load(cls, context: ComponentLoadContext):
loaded_params = context.load_params(cls.get_schema())
return cls(attributes=loaded_params.attributes)
def load(cls, params: CustomScopeParams, context: ComponentLoadContext):
return cls(attributes=params.attributes)

def build_defs(self, context: ComponentLoadContext):
return Definitions(assets=[AssetSpec(key="key", **self.attributes)])
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@ class {{ component_type_class_name }}(Component):
return DefaultComponentScaffolder()

@classmethod
def load(cls, context: ComponentLoadContext) -> "{{ component_type_class_name }}":
loaded_params = context.load_params(cls.get_component_schema_type())

def load(
cls,
params: {{ component_type_class_name }}Params,
context: ComponentLoadContext,
) -> "{{ component_type_class_name }}":
# Add logic for mapping schema parameters to constructor args here.
return cls()

Expand Down

0 comments on commit 5ee4d47

Please sign in to comment.