Skip to content

Commit

Permalink
add comments, tests
Browse files Browse the repository at this point in the history
  • Loading branch information
benpankow committed Apr 12, 2023
1 parent a276124 commit a06b828
Show file tree
Hide file tree
Showing 3 changed files with 275 additions and 59 deletions.
108 changes: 56 additions & 52 deletions python_modules/dagster/dagster/_config/pythonic_config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,16 +178,23 @@ class Config:
"""


def _recursively_apply_field_defaults(old_field: Field, additional_default_values: Any) -> Field:
def _apply_defaults_to_schema_field(old_field: Field, additional_default_values: Any) -> Field:
"""Given a config Field and a set of default values (usually a dictionary or raw default value),
return a new Field with the default values applied to it (and recursively to any sub-fields).
"""
# If the field has subfields and the default value is a dictionary, iterate
# over the subfields and apply the defaults to them.
if isinstance(old_field.config_type, _ConfigHasFields) and isinstance(
additional_default_values, dict
):
updated_sub_fields = {
k: _recursively_apply_field_defaults(
k: _apply_defaults_to_schema_field(
sub_field, additional_default_values.get(k, FIELD_NO_DEFAULT_PROVIDED)
)
for k, sub_field in old_field.config_type.fields.items()
}

# We also apply a new default value to the field if all of its subfields have defaults
new_default = (
old_field.default_value if old_field.default_provided else FIELD_NO_DEFAULT_PROVIDED
)
Expand All @@ -210,26 +217,11 @@ def _recursively_apply_field_defaults(old_field: Field, additional_default_value
return copy_with_default(old_field, additional_default_values)


# This is from https://github.com/dagster-io/dagster/pull/11470
def _apply_defaults_to_schema_field(field: Field, additional_default_values: Any) -> Field:
# This work by validating the top-level config and then
# just setting it at that top-level field. Config fields
# can actually take nested values so we only need to set it
# at a single level

# evr = validate_config(field.config_type, additional_default_values)

# if not evr.success:
# raise DagsterInvalidConfigError(
# "Incorrect values passed to .configured",
# evr.errors,
# additional_default_values,
# )

return _recursively_apply_field_defaults(field, additional_default_values)


def copy_with_default(old_field: Field, new_config_value: Any) -> Field:
"""Copies a Field, but replaces the default value with the provided value.
Also updates the is_required flag depending on whether the new config value is
actually specified.
"""
return Field(
config=old_field.config_type,
default_value=old_field.default_value
Expand Down Expand Up @@ -281,6 +273,11 @@ def _resolve_required_resource_keys_for_resource(
return resource.required_resource_keys


class SeparatedResourceParams(NamedTuple):
resources: Dict[str, Any]
non_resources: Dict[str, Any]


class AllowDelayedDependencies:
_nested_partial_resources: Mapping[str, ResourceDefinition] = {}

Expand Down Expand Up @@ -309,7 +306,7 @@ def _resolve_required_resource_keys(
_resolve_required_resource_keys_for_resource(v, resource_mapping)
)

resources, _ = separate_resource_params(self.__class__, self.__dict__)
resources, _ = self.separate_resource_params(self.__dict__)
for v in resources.values():
nested_resource_required_keys.update(
_resolve_required_resource_keys_for_resource(v, resource_mapping)
Expand All @@ -320,6 +317,10 @@ def _resolve_required_resource_keys(
)
return out

@abstractmethod
def separate_resource_params(self, data: Dict[str, Any]) -> SeparatedResourceParams:
raise NotImplementedError()


class InitResourceContextWithKeyMapping(InitResourceContext):
"""Passes along a mapping from ResourceDefinition id to resource key alongside the
Expand Down Expand Up @@ -484,7 +485,7 @@ def asset_that_uses_database(database: ResourceParam[Database]):
"""

def __init__(self, **data: Any):
resource_pointers, data_without_resources = separate_resource_params(self.__class__, data)
resource_pointers, data_without_resources = self.separate_resource_params(data)

schema = infer_schema_from_config_class(
self.__class__, fields_to_omit=set(resource_pointers.keys())
Expand Down Expand Up @@ -601,7 +602,7 @@ def _resolve_and_update_nested_resources(
}

# Also evaluate any resources that are not partial
resources_to_update, _ = separate_resource_params(self.__class__, self.__dict__)
resources_to_update, _ = self.separate_resource_params(self.__dict__)
resources_to_update = {
attr_name: _call_resource_fn_with_default(resource_def, context)
for attr_name, resource_def in resources_to_update.items()
Expand Down Expand Up @@ -640,6 +641,26 @@ def my_resource(context: InitResourceContext) -> MyResource:
"""
return cls(**context.resource_config or {}).create_resource(context)

@classmethod
def separate_resource_params_on_cls(cls, data: Dict[str, Any]) -> SeparatedResourceParams:
"""Separates out the key/value inputs of fields in a structured config Resource class which
are themselves Resources and those which are not.
"""
resources = {}
non_resources = {}
for k, v in data.items():
field = getattr(cls, "__fields__", {}).get(k)
if field and _is_resource_dependency(field.annotation):
resources[k] = v
elif isinstance(v, ResourceDefinition):
resources[k] = v
else:
non_resources[k] = v
return SeparatedResourceParams(resources=resources, non_resources=non_resources)

def separate_resource_params(self, data: Dict[str, Any]) -> SeparatedResourceParams:
return self.__class__.separate_resource_params_on_cls(data)


class ConfigurableResource(ConfigurableResourceFactory[TResValue]):
"""Base class for Dagster resources that utilize structured config.
Expand Down Expand Up @@ -710,12 +731,16 @@ def __init__(
data: Dict[str, Any],
is_partial: bool = False,
):
resource_pointers, data_without_resources = separate_resource_params(self.__class__, data)
self._resource_cls = resource_cls
resource_pointers, data_without_resources = self.separate_resource_params(data)

if not is_partial and data_without_resources:
resource_name = resource_cls.__name__
parameter_names = list(data_without_resources.keys())
raise DagsterInvalidDefinitionError(
f"Resource {resource_cls.__name__} is not marked as partial, but was passed "
f"non-resource parameters: {list(data_without_resources.keys())}."
f"'{resource_name}.configure_at_launch' was called but non-resource parameters"
f" were passed: {parameter_names}. Did you mean to call '{resource_name}.partial'"
" instead?"
)

MakeConfigCacheable.__init__(self, data=data, resource_cls=resource_cls) # type: ignore # extends BaseModel, takes kwargs
Expand Down Expand Up @@ -743,9 +768,12 @@ def resource_fn(context: InitResourceContext):
config_schema=curried_schema,
description=resource_cls.__doc__,
)

self._resource_cls = resource_cls
self._nested_resources = {k: v for k, v in resource_pointers.items()}

def separate_resource_params(self, data: Dict[str, Any]) -> SeparatedResourceParams:
return self._resource_cls.separate_resource_params_on_cls(data)

@property
def nested_resources(self) -> Mapping[str, ResourceDefinition]:
return self._nested_resources
Expand Down Expand Up @@ -1216,11 +1244,6 @@ def infer_schema_from_config_class(
return Field(config=shape_cls(fields), description=description or docstring)


class SeparatedResourceParams(NamedTuple):
resources: Dict[str, Any]
non_resources: Dict[str, Any]


def _is_resource_dependency(typ: Type) -> bool:
return (
safe_is_subclass(typ, ResourceDependency)
Expand All @@ -1230,25 +1253,6 @@ def _is_resource_dependency(typ: Type) -> bool:
)


def separate_resource_params(
cls: Type[AllowDelayedDependencies], data: Dict[str, Any]
) -> SeparatedResourceParams:
"""Separates out the key/value inputs of fields in a structured config Resource class which
are themselves Resources and those which are not.
"""
resources = {}
non_resources = {}
for k, v in data.items():
field = getattr(cls, "__fields__", {}).get(k)
if field and _is_resource_dependency(field.annotation):
resources[k] = v
elif isinstance(v, ResourceDefinition):
resources[k] = v
else:
non_resources[k] = v
return SeparatedResourceParams(resources=resources, non_resources=non_resources)


def _call_resource_fn_with_default(obj: ResourceDefinition, context: InitResourceContext) -> Any:
if isinstance(obj.config_schema, ConfiguredDefinitionConfigSchema):
value = cast(Dict[str, Any], obj.config_schema.resolve_config({}).value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
ConfigurableResource,
PartialResource,
ResourceWithKeyMapping,
separate_resource_params,
)
from dagster._core.definitions.asset_graph import AssetGraph
from dagster._core.definitions.assets_job import (
Expand Down Expand Up @@ -80,8 +79,8 @@ def _env_vars_from_resource_defaults(resource_def: ResourceDefinition) -> Set[st
if isinstance(resource_def, ResourceWithKeyMapping) and isinstance(
resource_def.inner_resource, (ConfigurableResource, PartialResource)
):
nested_resources = separate_resource_params(
resource_def.inner_resource.__class__, resource_def.inner_resource.__dict__
nested_resources = resource_def.inner_resource.separate_resource_params(
resource_def.inner_resource.__dict__
).resources
for nested_resource in nested_resources.values():
env_vars = env_vars.union(_env_vars_from_resource_defaults(nested_resource))
Expand Down
Loading

0 comments on commit a06b828

Please sign in to comment.