diff --git a/python_modules/dagster/dagster/_config/pythonic_config/config.py b/python_modules/dagster/dagster/_config/pythonic_config/config.py index 28cfb74cd559b..5c995e3919790 100644 --- a/python_modules/dagster/dagster/_config/pythonic_config/config.py +++ b/python_modules/dagster/dagster/_config/pythonic_config/config.py @@ -414,6 +414,7 @@ def infer_schema_from_config_class( check.param_invariant( safe_is_subclass(model_cls, Config), + "model_cls", "Config type annotation must inherit from dagster.Config", ) diff --git a/python_modules/dagster/dagster/_config/pythonic_config/conversion_utils.py b/python_modules/dagster/dagster/_config/pythonic_config/conversion_utils.py index cef5010f92e1b..0cc2da5f35208 100644 --- a/python_modules/dagster/dagster/_config/pythonic_config/conversion_utils.py +++ b/python_modules/dagster/dagster/_config/pythonic_config/conversion_utils.py @@ -101,12 +101,15 @@ def _convert_pydantic_field( if get_origin(pydantic_field.annotation) == Literal: return _convert_typing_literal_field(pydantic_field) - field_type = pydantic_field.annotation + field_type = ( + pydantic_field.annotation + ) # here by just passing the field_type is where we lose any default values if safe_is_subclass(field_type, Config): inferred_field = infer_schema_from_config_class( field_type, description=pydantic_field.description, ) + # but once we have the non-default inferred field, i'm not sure how we can then apply the defaults return inferred_field else: if not pydantic_field.is_required() and not is_closed_python_optional_type(field_type): diff --git a/python_modules/dagster/dagster_tests/core_tests/pythonic_config_tests/test_pythonic_config_types.py b/python_modules/dagster/dagster_tests/core_tests/pythonic_config_tests/test_pythonic_config_types.py index 5072094b81bc4..04f7fe9077463 100644 --- a/python_modules/dagster/dagster_tests/core_tests/pythonic_config_tests/test_pythonic_config_types.py +++ b/python_modules/dagster/dagster_tests/core_tests/pythonic_config_tests/test_pythonic_config_types.py @@ -999,3 +999,41 @@ def echo_job(): d = {"test": "test"} result = echo_job.execute_in_process(resources={"my_resource": ConfigWithAlias(alias_name=d)}) assert result.output_for_node("echo_config") == d + + +def test_nested_config_with_defaults() -> None: + class NestedConfig(Config): + a: str + b: int + + class ConfigClassToConvert(Config): + nested: NestedConfig = Field(default=NestedConfig(a="a", b=1)) + an_int: int = 2 + + fields = ConfigClassToConvert.to_fields_dict() + + assert isinstance(fields, dict) + assert set(fields.keys()) == { + "nested", + "an_int", + } + assert fields["nested"].default_value == {"a": "a", "b": 1} + assert fields["an_int"].default_value == 2 + + class NestedConfigWithDefaults(Config): + a: str = "a" + b: int = 1 + + class ConfigClassToConvert(Config): + nested: NestedConfigWithDefaults + an_int: int = 2 + + fields = ConfigClassToConvert.to_fields_dict() + + assert isinstance(fields, dict) + assert set(fields.keys()) == { + "nested", + "an_int", + } + assert fields["nested"].default_value == {"a": "a", "b": 1} + assert fields["an_int"].default_value == 2