From e94ac7686cceb5678f42f23f80226af31afc1ddb Mon Sep 17 00:00:00 2001 From: Michael Ball Date: Tue, 7 Feb 2023 10:37:20 +0100 Subject: [PATCH] . --- hydra/_internal/instantiate/_instantiate2.py | 20 +++++++++++++------- tests/instantiate/__init__.py | 5 +++++ tests/instantiate/test_instantiate.py | 11 +++++++++++ 3 files changed, 29 insertions(+), 7 deletions(-) diff --git a/hydra/_internal/instantiate/_instantiate2.py b/hydra/_internal/instantiate/_instantiate2.py index 2f09ece868b..5d6823b445c 100644 --- a/hydra/_internal/instantiate/_instantiate2.py +++ b/hydra/_internal/instantiate/_instantiate2.py @@ -354,18 +354,24 @@ def instantiate_node( ): dict_items = {} for key, value in node.items(): - # list items inherits recursive flag from the containing dict. - dict_items[key] = instantiate_node( - value, convert=convert, recursive=recursive - ) + if recursive: + # list items inherits recursive flag from the containing dict. + dict_items[key] = instantiate_node( + value, convert=convert, recursive=recursive + ) + else: + dict_items[key] = value return dict_items else: # Otherwise use DictConfig and resolve interpolations lazily. cfg = OmegaConf.create({}, flags={"allow_objects": True}) for key, value in node.items(): - cfg[key] = instantiate_node( - value, convert=convert, recursive=recursive - ) + if recursive: + cfg[key] = instantiate_node( + value, convert=convert, recursive=recursive + ) + else: + cfg[key] = value cfg._set_parent(node) cfg._metadata.object_type = node._metadata.object_type if convert == ConvertMode.OBJECT: diff --git a/tests/instantiate/__init__.py b/tests/instantiate/__init__.py index 257d117d711..a73c30f5b37 100644 --- a/tests/instantiate/__init__.py +++ b/tests/instantiate/__init__.py @@ -420,6 +420,11 @@ class NestedConf: b: Any = field(default_factory=lambda: User(name="b", age=2)) +@dataclass +class NestedConfNoTarget: + a: Any = field(default_factory=lambda: SimpleClassDefaultPrimitiveConf) + + def recisinstance(got: Any, expected: Any) -> bool: """Compare got with expected type, recursively on dict and list.""" if not isinstance(got, type(expected)): diff --git a/tests/instantiate/test_instantiate.py b/tests/instantiate/test_instantiate.py index cde77d8af78..87a32fa9df9 100644 --- a/tests/instantiate/test_instantiate.py +++ b/tests/instantiate/test_instantiate.py @@ -33,6 +33,7 @@ Mapping, MappingConf, NestedConf, + NestedConfNoTarget, NestingClass, OuterClass, Parameters, @@ -2096,3 +2097,13 @@ class DictValuesConf: cfg = OmegaConf.structured(DictValuesConf) obj = instantiate_func(config=cfg) assert obj.d is None + + + +def test_non_target_recursive(instantiate_func: Any) -> None: + cfg = OmegaConf.structured(NestedConfNoTarget) + obj = instantiate_func(config=cfg, _recursive_=False) + assert isinstance(obj.a, DictConfig) + + obj = instantiate_func(config=cfg, _recursive_=True) + assert isinstance(obj.a, SimpleClass)