Skip to content

Commit

Permalink
update partial thing
Browse files Browse the repository at this point in the history
  • Loading branch information
benpankow committed May 9, 2024
1 parent 8830211 commit 3844b89
Show file tree
Hide file tree
Showing 4 changed files with 376 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,13 @@
ConfigType,
Noneable,
)
from dagster._config.post_process import resolve_defaults
from dagster._config.field_utils import _ConfigHasFields
from dagster._config.source import BoolSource, IntSource, StringSource
from dagster._config.validate import validate_config
from dagster._core.definitions.definition_config_schema import (
DefinitionConfigSchema,
)
from dagster._core.errors import (
DagsterInvalidConfigDefinitionError,
DagsterInvalidConfigError,
DagsterInvalidDefinitionError,
DagsterInvalidPythonicConfigDefinitionError,
)
Expand Down Expand Up @@ -59,44 +57,56 @@ class cached_property:


# This is from https://github.com/dagster-io/dagster/pull/11470
def _apply_defaults_to_schema_field(field: Field, additional_default_values: Any) -> Field:
# This work by validating the top-level config and then
# just setting it at that top-level field. Config fields
# can actually take nested values so we only need to set it
# at a single level

evr = validate_config(field.config_type, additional_default_values)

if not evr.success:
raise DagsterInvalidConfigError(
"Incorrect values passed to .configured",
evr.errors,
additional_default_values,
def _apply_defaults_to_schema_field(old_field: Field, additional_default_values: Any) -> Field:
"""Given a config Field and a set of default values (usually a dictionary or raw default value),
return a new Field with the default values applied to it (and recursively to any sub-fields).
"""
# If the field has subfields and the default value is a dictionary, iterate
# over the subfields and apply the defaults to them.
if isinstance(old_field.config_type, _ConfigHasFields) and isinstance(
additional_default_values, dict
):
updated_sub_fields = {
k: _apply_defaults_to_schema_field(
sub_field, additional_default_values.get(k, FIELD_NO_DEFAULT_PROVIDED)
)
for k, sub_field in old_field.config_type.fields.items()
}

# We also apply a new default value to the field if all of its subfields have defaults
new_default = (
old_field.default_value if old_field.default_provided else FIELD_NO_DEFAULT_PROVIDED
)
if all(
sub_field.default_provided or not sub_field.is_required
for sub_field in updated_sub_fields.values()
):
new_default = {
**additional_default_values,
**{k: v.default_value for k, v in updated_sub_fields.items() if v.default_provided},
}

if field.default_provided:
# In the case where there is already a default config value
# we can apply "additional" defaults by actually invoking
# the config machinery. Meaning we pass the new_additional_default_values
# and then resolve the existing defaults over them. This preserves the default
# values that are not specified in new_additional_default_values and then
# applies the new value as the default value of the field in question.
defaults_processed_evr = resolve_defaults(field.config_type, additional_default_values)
check.invariant(
defaults_processed_evr.success,
"Since validation passed, this should always work.",
return Field(
config=old_field.config_type.__class__(fields=updated_sub_fields),
default_value=new_default,
is_required=new_default == FIELD_NO_DEFAULT_PROVIDED and old_field.is_required,
description=old_field.description,
)
default_to_pass = defaults_processed_evr.value
return copy_with_default(field, default_to_pass)
else:
return copy_with_default(field, additional_default_values)
return copy_with_default(old_field, additional_default_values)


def copy_with_default(old_field: Field, new_config_value: Any) -> Field:
"""Copies a Field, but replaces the default value with the provided value.
Also updates the is_required flag depending on whether the new config value is
actually specified.
"""
return Field(
config=old_field.config_type,
default_value=new_config_value,
is_required=False,
default_value=old_field.default_value
if new_config_value == FIELD_NO_DEFAULT_PROVIDED and old_field.default_provided
else new_config_value,
is_required=new_config_value == FIELD_NO_DEFAULT_PROVIDED and old_field.is_required,
description=old_field.description,
)

Expand Down
31 changes: 26 additions & 5 deletions python_modules/dagster/dagster/_config/pythonic_config/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
ConfiguredDefinitionConfigSchema,
DefinitionConfigSchema,
)
from dagster._core.errors import DagsterInvalidConfigError
from dagster._core.errors import DagsterInvalidConfigError, DagsterInvalidDefinitionError
from dagster._core.execution.context.init import InitResourceContext, build_init_resource_context
from dagster._model.pydantic_compat_layer import (
model_fields,
Expand Down Expand Up @@ -445,6 +445,13 @@ def configure_at_launch(cls: "Type[T_Self]", **kwargs) -> "PartialResource[T_Sel
"""
return PartialResource(cls, data=kwargs)

@classmethod
def partial(cls: "Type[T_Self]", **kwargs) -> "PartialResource[T_Self]":
"""Returns a partially initialized copy of the resource, with remaining config fields
set at runtime.
"""
return PartialResource(cls, data=kwargs, is_partial=True)

def _with_updated_values(
self, values: Optional[Mapping[str, Any]]
) -> "ConfigurableResourceFactory[TResValue]":
Expand Down Expand Up @@ -759,8 +766,18 @@ def __init__(
self,
resource_cls: Type[ConfigurableResourceFactory[TResValue]],
data: Dict[str, Any],
is_partial: bool = False,
):
resource_pointers, _data_without_resources = separate_resource_params(resource_cls, data)
resource_pointers, data_without_resources = separate_resource_params(resource_cls, data)

if not is_partial and data_without_resources:
resource_name = resource_cls.__name__
parameter_names = list(data_without_resources.keys())
raise DagsterInvalidDefinitionError(
f"'{resource_name}.configure_at_launch' was called but non-resource parameters"
f" were passed: {parameter_names}. Did you mean to call '{resource_name}.partial'"
" instead?"
)

super().__init__(data=data, resource_cls=resource_cls) # type: ignore # extends BaseModel, takes kwargs

Expand All @@ -773,15 +790,19 @@ def resource_fn(context: InitResourceContext):
) # So that collisions are resolved in favor of the latest provided run config
return instantiated._get_initialize_and_run_fn()(context) # noqa: SLF001

schema = infer_schema_from_config_class(
resource_cls, fields_to_omit=set(resource_pointers.keys())
)
resolved_config_dict = config_dictionary_from_values(data_without_resources, schema)
curried_schema = _curry_config_schema(schema, resolved_config_dict)

self._state__internal__ = PartialResourceState(
# We keep track of any resources we depend on which are not fully configured
# so that we can retrieve them at runtime
nested_partial_resources={
k: v for k, v in resource_pointers.items() if (not _is_fully_configured(v))
},
config_schema=infer_schema_from_config_class(
resource_cls, fields_to_omit=set(resource_pointers.keys())
),
config_schema=curried_schema.as_field(),
resource_fn=resource_fn,
description=resource_cls.__doc__,
nested_resources={k: v for k, v in resource_pointers.items()},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ def asset_with_resource(context, my_resource: MyResource):

result_two = materialize(
[asset_with_resource],
resources={"my_resource": MyResource.configure_at_launch(my_enum=AnotherEnum.A)},
resources={"my_resource": MyResource.partial(my_enum=AnotherEnum.A)},
run_config={"resources": {"my_resource": {"config": {"my_enum": "B"}}}},
)

Expand Down
Loading

0 comments on commit 3844b89

Please sign in to comment.