Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[components] Refactor component load behavior #27181

Open
wants to merge 1 commit into
base: 01-16-_components_explicitly_pass_params
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
get_component_type_name,
)
from dagster_components.core.component_defs_builder import (
component_type_from_yaml_decl,
load_components_from_context,
path_to_decl_node,
resolve_decl_node_to_yaml_decls,
)
Expand Down Expand Up @@ -186,11 +184,9 @@ def check_component_command(ctx: click.Context, paths: Sequence[str]) -> None:
templated_value_resolver=TemplatedValueResolver.default(),
)
try:
load_components_from_context(clc)
decl_node.load(clc)
except ValidationError as e:
component_type = component_type_from_yaml_decl(
context.component_registry, yaml_decl
)
component_type = yaml_decl.get_component_type(context.component_registry)
validation_errors.append((component_type, e))

if validation_errors:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
from dagster_components.core.schema.resolver import TemplatedValueResolver


class ComponentDeclNode: ...
class ComponentDeclNode(ABC):
    """Abstract base for declaration nodes that can be loaded into components."""

    @abstractmethod
    def load(self, context: "ComponentLoadContext") -> Sequence["Component"]:
        """Load this node using ``context`` and return the resulting components."""


class Component(ABC):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
from collections.abc import Mapping, Sequence
from pathlib import Path
from typing import Any, Optional, TypeVar, Union
Expand All @@ -12,7 +13,15 @@
from dagster._utils.yaml_utils import parse_yaml_with_source_positions
from pydantic import BaseModel, TypeAdapter

from dagster_components.core.component import ComponentDeclNode, ComponentLoadContext
from dagster_components.core.component import (
Component,
ComponentDeclNode,
ComponentLoadContext,
ComponentTypeRegistry,
get_component_type_name,
is_registered_component_type,
)
from dagster_components.utils import load_module_from_path


class ComponentFileModel(BaseModel):
Expand All @@ -30,7 +39,16 @@ class YamlComponentDecl(ComponentDeclNode):
source_position_tree: Optional[SourcePositionTree] = None

@staticmethod
def from_path(component_file_path: Path) -> "YamlComponentDecl":
def component_file_path(path: Path) -> Path:
return path / "component.yaml"

@staticmethod
def exists_at(path: Path) -> bool:
    """Return whether the component file path derived from ``path`` exists."""
    candidate = YamlComponentDecl.component_file_path(path)
    return candidate.exists()

@staticmethod
def from_path(path: Path) -> "YamlComponentDecl":
component_file_path = YamlComponentDecl.component_file_path(path)
parsed = parse_yaml_with_source_positions(
component_file_path.read_text(), str(component_file_path)
)
Expand All @@ -39,11 +57,36 @@ def from_path(component_file_path: Path) -> "YamlComponentDecl":
)

return YamlComponentDecl(
path=component_file_path.parent,
path=path,
component_file_model=obj,
source_position_tree=parsed.source_position_tree,
)

def get_component_type(self, registry: ComponentTypeRegistry) -> type[Component]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: the `get_` prefix belies the fact that this does some real work (parsing Python files); maybe `resolve_component_type`?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be worth a brief docstring explaining that this does a local lookup in addition to the expected registry lookup.

parsed_defs = self.component_file_model
if parsed_defs.type.startswith("."):
component_registry_key = parsed_defs.type[1:]

# Iterate over Python files in the folder
for py_file in self.path.glob("*.py"):
module_name = py_file.stem

module = load_module_from_path(module_name, self.path / f"{module_name}.py")

for _name, obj in inspect.getmembers(module, inspect.isclass):
assert isinstance(obj, type)
if (
is_registered_component_type(obj)
and get_component_type_name(obj) == component_registry_key
):
return obj

raise Exception(
f"Could not find component type {component_registry_key} in {self.path}"
)

return registry.get(parsed_defs.type)

def get_params(self, context: ComponentLoadContext, params_schema: type[T]) -> T:
with pushd(str(self.path)):
raw_params = self.component_file_model.params
Expand All @@ -60,12 +103,26 @@ def get_params(self, context: ComponentLoadContext, params_schema: type[T]) -> T
else:
return TypeAdapter(params_schema).validate_python(preprocessed_params)

def load(self, context: ComponentLoadContext) -> Sequence[Component]:
    """Resolve this declaration's component type and load a single component.

    The component type is looked up via ``get_component_type``; the context is
    extended with the type's additional rendering scope before params are
    resolved. Params are only resolved when the type exposes a schema.

    Returns:
        A one-element list containing the loaded component.
    """
    resolved_type = self.get_component_type(context.registry)
    schema = resolved_type.get_schema()
    scoped_context = context.with_rendering_scope(resolved_type.get_additional_scope())
    if schema:
        params = self.get_params(scoped_context, schema)
    else:
        # No schema declared: the component is loaded without params.
        params = None
    return [resolved_type.load(params, scoped_context)]


@record
class ComponentFolder(ComponentDeclNode):
    """Declaration node for a folder that groups other declaration nodes."""

    # Location of the folder this node represents.
    path: Path
    # Child declarations; each is either a YAML component or a nested folder.
    sub_decls: Sequence[Union[YamlComponentDecl, "ComponentFolder"]]

    def load(self, context: ComponentLoadContext) -> Sequence[Component]:
        """Load every child declaration and return all components in one flat list.

        Each child is loaded with a context derived for it through
        ``context.for_decl_node``.
        """
        loaded: list[Component] = []
        for child in self.sub_decls:
            child_context = context.for_decl_node(child)
            loaded.extend(child.load(child_context))
        return loaded


def path_to_decl_node(path: Path) -> Optional[ComponentDeclNode]:
# right now, we only support two types of components, both of which are folders
Expand All @@ -75,10 +132,8 @@ def path_to_decl_node(path: Path) -> Optional[ComponentDeclNode]:
if not path.is_dir():
return None

component_path = path / "component.yaml"

if component_path.exists():
return YamlComponentDecl.from_path(component_path)
if YamlComponentDecl.exists_at(path):
return YamlComponentDecl.from_path(path)

subs = []
for subpath in path.iterdir():
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
import importlib
import importlib.util
import inspect
from collections.abc import Mapping, Sequence
from pathlib import Path
from types import ModuleType
from typing import TYPE_CHECKING, Optional

from dagster._utils.warnings import suppress_dagster_warnings
Expand All @@ -13,8 +9,6 @@
ComponentLoadContext,
ComponentTypeRegistry,
TemplatedValueResolver,
get_component_type_name,
is_registered_component_type,
)
from dagster_components.core.component_decl_builder import (
ComponentDeclNode,
Expand All @@ -28,19 +22,6 @@
from dagster._core.definitions.definitions_class import Definitions


def load_module_from_path(module_name, path) -> ModuleType:
    """Execute the Python file at ``path`` and return it as a module object.

    The module is built from a file-location spec and executed; this function
    itself does not insert it into ``sys.modules``.

    Args:
        module_name: Name given to the resulting module.
        path: Filesystem location of the Python file to load.

    Returns:
        The executed module.

    Raises:
        ImportError: If no module spec can be created for ``path``, or the
            spec has no loader.
    """
    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None:
        raise ImportError(f"Cannot create a module spec from path: {path}")
    if spec.loader is None:
        # Explicit check instead of ``assert``: assertions are stripped under
        # ``python -O``, which would turn this into an AttributeError below.
        raise ImportError(f"Module spec has no loader for path: {path}")

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

def resolve_decl_node_to_yaml_decls(decl: ComponentDeclNode) -> list[YamlComponentDecl]:
if isinstance(decl, YamlComponentDecl):
return [decl]
Expand All @@ -53,57 +34,12 @@ def resolve_decl_node_to_yaml_decls(decl: ComponentDeclNode) -> list[YamlCompone
raise NotImplementedError(f"Unknown component type {decl}")


def load_components_from_context(context: ComponentLoadContext) -> Sequence[Component]:
    """Load the components described by ``context.decl_node``.

    A ``YamlComponentDecl`` yields a one-element list; a ``ComponentFolder``
    is walked recursively and the results of its sub-declarations are
    flattened into one list.

    Raises:
        NotImplementedError: If the declaration node is of neither kind.
    """
    decl = context.decl_node

    if isinstance(decl, YamlComponentDecl):
        resolved_type = component_type_from_yaml_decl(context.registry, decl)
        schema = resolved_type.get_schema()
        scoped_context = context.with_rendering_scope(resolved_type.get_additional_scope())
        # Params are only resolved when the component type declares a schema.
        params = decl.get_params(scoped_context, schema) if schema else None
        return [resolved_type.load(params, scoped_context)]

    if isinstance(decl, ComponentFolder):
        return [
            component
            for child in decl.sub_decls
            for component in load_components_from_context(context.for_decl_node(child))
        ]

    raise NotImplementedError(f"Unknown component type {decl}")


def component_type_from_yaml_decl(
    registry: ComponentTypeRegistry, decl_node: YamlComponentDecl
) -> type[Component]:
    """Resolve the component class named by a YAML declaration.

    A type name starting with ``"."`` is looked up locally: every ``*.py``
    file directly inside ``decl_node.path`` is loaded as a module and its
    classes are scanned for a registered component type whose name matches
    the remainder of the type string. Any other type name is fetched from
    ``registry``.

    Raises:
        Exception: If a local (``"."``-prefixed) type is not found in any
            of the folder's Python files.
    """
    type_name = decl_node.component_file_model.type
    if not type_name.startswith("."):
        return registry.get(type_name)

    local_key = type_name[1:]

    # Load each Python file in the declaration's folder and scan its classes.
    for source_file in decl_node.path.glob("*.py"):
        stem = source_file.stem
        module = load_module_from_path(stem, decl_node.path / f"{stem}.py")

        for _, candidate in inspect.getmembers(module, inspect.isclass):
            assert isinstance(candidate, type)
            if not is_registered_component_type(candidate):
                continue
            if get_component_type_name(candidate) == local_key:
                return candidate

    raise Exception(f"Could not find component type {local_key} in {decl_node.path}")


def build_components_from_component_folder(
context: ComponentLoadContext, path: Path
) -> Sequence[Component]:
component_folder = path_to_decl_node(path)
assert isinstance(component_folder, ComponentFolder)
return load_components_from_context(context.for_decl_node(component_folder))
return component_folder.load(context.for_decl_node(component_folder))


def build_defs_from_component_path(
Expand All @@ -122,7 +58,7 @@ def build_defs_from_component_path(
decl_node=decl_node,
templated_value_resolver=TemplatedValueResolver.default(),
)
components = load_components_from_context(context)
components = decl_node.load(context)
return defs_from_components(resources=resources, context=context, components=components)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from types import ModuleType
from typing import Any, Optional

from dagster._core.definitions.asset_key import AssetKey
Expand Down Expand Up @@ -121,3 +122,16 @@ def get_automation_condition(self, obj: Any) -> Optional[AutomationCondition]:
)

return WrappedTranslator


def load_module_from_path(module_name, path) -> ModuleType:
    """Execute the Python file at ``path`` and return it as a module object.

    The module is built from a file-location spec and executed; this function
    itself does not insert it into ``sys.modules``.

    Args:
        module_name: Name given to the resulting module.
        path: Filesystem location of the Python file to load.

    Returns:
        The executed module.

    Raises:
        ImportError: If no module spec can be created for ``path``, or the
            spec has no loader.
    """
    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None:
        raise ImportError(f"Cannot create a module spec from path: {path}")
    if spec.loader is None:
        # Explicit check instead of ``assert``: assertions are stripped under
        # ``python -O``, which would turn this into an AttributeError below.
        raise ImportError(f"Module spec has no loader for path: {path}")

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module