Skip to content

Commit

Permalink
[pythonic resources][wip] Partial config support for resources
Browse files Browse the repository at this point in the history
  • Loading branch information
benpankow committed Mar 10, 2023
1 parent 2421fd0 commit 83041e6
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@

from dagster._annotations import experimental
from dagster._config.config_type import Array, ConfigFloatInstance, ConfigType
from dagster._config.field_utils import config_dictionary_from_values
from dagster._config.field_utils import _ConfigHasFields, config_dictionary_from_values
from dagster._config.post_process import resolve_defaults
from dagster._config.source import BoolSource, IntSource, StringSource
from dagster._config.structured_config.typing_utils import TypecheckAllowPartialResourceInitParams
from dagster._config.type_printer import print_config_type_to_string
from dagster._config.validate import process_config, validate_config
from dagster._core.definitions.definition_config_schema import (
ConfiguredDefinitionConfigSchema,
DefinitionConfigSchema,
)
from dagster._core.errors import DagsterInvalidConfigError
from dagster._core.errors import DagsterInvalidConfigError, DagsterInvalidDefinitionError
from dagster._core.execution.context.init import InitResourceContext

try:
Expand Down Expand Up @@ -145,21 +146,45 @@ class Config:
"""


def _recursively_apply_field_defaults(old_field: Field, additional_default_values: Any) -> Field:
if isinstance(old_field.config_type, _ConfigHasFields) and isinstance(
additional_default_values, dict
):
print("DRILLING DOWN", [k for k in old_field.config_type.fields], additional_default_values)
updated_sub_fields = {
k: _recursively_apply_field_defaults(
sub_field, additional_default_values.get(k, FIELD_NO_DEFAULT_PROVIDED)
)
for k, sub_field in old_field.config_type.fields.items()
}
return Field(
config=old_field.config_type.__class__(fields=updated_sub_fields),
default_value=old_field.default_value
if old_field.default_provided
else FIELD_NO_DEFAULT_PROVIDED,
is_required=not old_field.default_provided and old_field.is_required,
description=old_field.description,
)
else:
print("SETTING", additional_default_values)
return copy_with_default(old_field, additional_default_values)


# This is from https://github.com/dagster-io/dagster/pull/11470
def _apply_defaults_to_schema_field(field: Field, additional_default_values: Any) -> Field:
# This work by validating the top-level config and then
# just setting it at that top-level field. Config fields
# can actually take nested values so we only need to set it
# at a single level

evr = validate_config(field.config_type, additional_default_values)
# evr = validate_config(field.config_type, additional_default_values)

if not evr.success:
raise DagsterInvalidConfigError(
"Incorrect values passed to .configured",
evr.errors,
additional_default_values,
)
# if not evr.success:
# raise DagsterInvalidConfigError(
# "Incorrect values passed to .configured",
# evr.errors,
# additional_default_values,
# )

if field.default_provided:
# In the case where there is already a default config value
Expand All @@ -173,16 +198,27 @@ def _apply_defaults_to_schema_field(field: Field, additional_default_values: Any
defaults_processed_evr.success, "Since validation passed, this should always work."
)
default_to_pass = defaults_processed_evr.value
return copy_with_default(field, default_to_pass)

managed = _recursively_apply_field_defaults(field, additional_default_values)
evr = validate_config(field.config_type, additional_default_values)

if not evr.success:
return managed
return copy_with_default(managed, default_to_pass)
else:
return copy_with_default(field, additional_default_values)
managed = _recursively_apply_field_defaults(field, additional_default_values)
evr = validate_config(field.config_type, additional_default_values)

if not evr.success:
return managed
return copy_with_default(managed, additional_default_values)


def copy_with_default(old_field: Field, new_config_value: Any) -> Field:
return Field(
config=old_field.config_type,
default_value=new_config_value,
is_required=False,
is_required=new_config_value == FIELD_NO_DEFAULT_PROVIDED and old_field.is_required,
description=old_field.description,
)

Expand Down Expand Up @@ -440,6 +476,14 @@ def configure_at_launch(cls: "Type[Self]", **kwargs) -> "PartialResource[Self]":
"""
return PartialResource(cls, data=kwargs)

@classmethod
def partial(cls: "Type[Self]", **kwargs) -> "PartialResource[Self]":
"""
Returns a partially initialized copy of the resource, with remaining config fields
set at runtime.
"""
return PartialResource(cls, data=kwargs, is_partial=True)

def _with_updated_values(self, values: Mapping[str, Any]) -> "ConfigurableResource[TResValue]":
"""
Returns a new instance of the resource with the given values.
Expand Down Expand Up @@ -544,9 +588,20 @@ class PartialResource(
data: Dict[str, Any]
resource_cls: Type[ConfigurableResource[TResValue]]

def __init__(self, resource_cls: Type[ConfigurableResource[TResValue]], data: Dict[str, Any]):
def __init__(
self,
resource_cls: Type[ConfigurableResource[TResValue]],
data: Dict[str, Any],
is_partial: bool = False,
):
resource_pointers, data_without_resources = separate_resource_params(data)

if not is_partial and data_without_resources:
raise DagsterInvalidDefinitionError(
f"Resource {resource_cls.__name__} is not marked as partial, but was passed "
f"non-resource parameters: {list(data_without_resources.keys())}."
)

MakeConfigCacheable.__init__(self, data=data, resource_cls=resource_cls) # type: ignore # extends BaseModel, takes kwargs

# We keep track of any resources we depend on which are not fully configured
Expand All @@ -559,14 +614,17 @@ def __init__(self, resource_cls: Type[ConfigurableResource[TResValue]], data: Di
resource_cls, fields_to_omit=set(resource_pointers.keys())
)

resolved_config_dict = config_dictionary_from_values(data_without_resources, schema)
curried_schema = _curry_config_schema(schema, resolved_config_dict)

def resource_fn(context: InitResourceContext):
instantiated = resource_cls(**context.resource_config, **data)
instantiated = resource_cls(**context.resource_config, **(resource_pointers))
return instantiated.initialize_and_run(context)

ResourceDefinition.__init__(
self,
resource_fn=resource_fn,
config_schema=schema,
config_schema=curried_schema,
description=resource_cls.__doc__,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1330,3 +1330,85 @@ def an_asset(writer: ExtendingResource):
)

assert executed["yes"]


def test_structured_resource_partial_config() -> None:
out_txt = []

class WriterResource(ConfigurableResource):
prefix: str
postfix: str

def output(self, text: str) -> None:
out_txt.append(f"{self.prefix}{text}{self.postfix}")

@asset
def hello_world_asset(writer: WriterResource):
writer.output("hello, world!")

# No params set with partial
defs = Definitions(
assets=[hello_world_asset],
resources={"writer": WriterResource.partial()},
)

assert (
defs.get_implicit_global_asset_job_def()
.execute_in_process({"resources": {"writer": {"config": {"prefix": ">", "postfix": "<"}}}})
.success
)
assert out_txt == [">hello, world!<"]

out_txt.clear()

# One param set as partial
defs = Definitions(
assets=[hello_world_asset],
resources={"writer": WriterResource.partial(prefix="(")},
)

assert (
defs.get_implicit_global_asset_job_def()
.execute_in_process({"resources": {"writer": {"config": {"postfix": ")"}}}})
.success
)
assert out_txt == ["(hello, world!)"]
out_txt.clear()

# Two params set as partial
defs = Definitions(
assets=[hello_world_asset],
resources={"writer": WriterResource.partial(prefix="{", postfix="}")},
)

assert defs.get_implicit_global_asset_job_def().execute_in_process().success
assert out_txt == ["{hello, world!}"]
out_txt.clear()

# Overriding partial param
defs = Definitions(
assets=[hello_world_asset],
resources={"writer": WriterResource.partial(prefix="{")},
)

assert (
defs.get_implicit_global_asset_job_def()
.execute_in_process({"resources": {"writer": {"config": {"prefix": "[", "postfix": "]"}}}})
.success
)
assert out_txt == ["[hello, world!]"]
out_txt.clear()

# Overriding both partial params
defs = Definitions(
assets=[hello_world_asset],
resources={"writer": WriterResource.partial(prefix="<", postfix=">")},
)

assert (
defs.get_implicit_global_asset_job_def()
.execute_in_process({"resources": {"writer": {"config": {"prefix": "*", "postfix": "*"}}}})
.success
)
assert out_txt == ["*hello, world!*"]
out_txt.clear()

0 comments on commit 83041e6

Please sign in to comment.