Skip to content

Commit ad4ec6e

Browse files
authored
Fixes updating configclass parameter with a list of objects (#847)
# Description When configclass dicts are nested inside lists, the list is treated as an Iterable object and assigned directly to the outer configclass when updating configclass data with dicts. This overwrites the configclass object in the list with a dict object and causes undesired behavior. This change checks for nested dictionaries inside Iterables and updates the values inside the dictionary individually without overwiting the full Iterable. Fixes #843 ## Type of change - Bug fix (non-breaking change which fixes an issue) ## Checklist - [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with `./isaaclab.sh --format` - [x] I have made corresponding changes to the documentation - [x] My changes generate no new warnings - [x] I have added tests that prove my fix is effective or that my feature works - [ ] I have updated the changelog and the corresponding version in the extension's `config/extension.toml` file - [x] I have added my name to the `CONTRIBUTORS.md` or my name already exists there
1 parent d906c4a commit ad4ec6e

File tree

4 files changed

+98
-11
lines changed

4 files changed

+98
-11
lines changed

source/extensions/omni.isaac.lab/omni/isaac/lab/utils/dict.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,16 @@ def update_class_from_dict(obj, data: dict[str, Any], _ns: str = "") -> None:
9595
)
9696
if isinstance(obj_mem, tuple):
9797
value = tuple(value)
98+
else:
99+
set_obj = True
100+
# recursively call if iterable contains dictionaries
101+
for i in range(len(obj_mem)):
102+
if isinstance(value[i], dict):
103+
update_class_from_dict(obj_mem[i], value[i], _ns=key_ns)
104+
set_obj = False
105+
# do not set value to obj, otherwise it overwrites the cfg class with the dict
106+
if not set_obj:
107+
continue
98108
elif callable(obj_mem):
99109
# update function name
100110
value = string_to_callable(value)

source/extensions/omni.isaac.lab/test/utils/test_configclass.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -298,10 +298,11 @@ def set_a(self, a: int):
298298

299299

300300
@configclass
301-
class NestedDictCfg:
302-
"""Dummy configuration class with nested dictionaries."""
301+
class NestedDictAndListCfg:
302+
"""Dummy configuration class with nested dictionaries and lists."""
303303

304304
dict_1: dict = {"dict_2": {"func": dummy_function1}}
305+
list_1: list[EnvCfg] = [EnvCfg(), EnvCfg()]
305306

306307

307308
"""
@@ -341,10 +342,14 @@ class NestedDictCfg:
341342
"device_id": 0,
342343
}
343344

344-
basic_demo_cfg_nested_dict = {
345+
basic_demo_cfg_nested_dict_and_list = {
345346
"dict_1": {
346347
"dict_2": {"func": dummy_function2},
347348
},
349+
"list_1": [
350+
{"num_envs": 23, "episode_length": 3000, "viewer": {"eye": [5.0, 5.0, 5.0], "lookat": [0.0, 0.0, 0.0]}},
351+
{"num_envs": 24, "episode_length": 2000, "viewer": {"eye": [6.0, 6.0, 6.0], "lookat": [0.0, 0.0, 0.0]}},
352+
],
348353
}
349354

350355
basic_demo_post_init_cfg_correct = {
@@ -456,6 +461,10 @@ def test_config_update_dict(self):
456461
update_class_from_dict(cfg, cfg_dict)
457462
self.assertDictEqual(asdict(cfg), basic_demo_cfg_change_correct)
458463

464+
# check types are also correct
465+
self.assertIsInstance(cfg.env.viewer, ViewerCfg)
466+
self.assertIsInstance(cfg.env.viewer.eye, tuple)
467+
459468
def test_config_update_dict_with_none(self):
460469
"""Test updating configclass using a dictionary that contains None."""
461470
cfg = BasicDemoCfg()
@@ -464,11 +473,23 @@ def test_config_update_dict_with_none(self):
464473
self.assertDictEqual(asdict(cfg), basic_demo_cfg_change_with_none_correct)
465474

466475
def test_config_update_nested_dict(self):
467-
"""Test updating configclass with sub-dictionnaries."""
468-
cfg = NestedDictCfg()
469-
cfg_dict = {"dict_1": {"dict_2": {"func": "__main__:dummy_function2"}}}
476+
"""Test updating configclass with sub-dictionaries."""
477+
cfg = NestedDictAndListCfg()
478+
cfg_dict = {
479+
"dict_1": {"dict_2": {"func": "__main__:dummy_function2"}},
480+
"list_1": [
481+
{"num_envs": 23, "episode_length": 3000, "viewer": {"eye": [5.0, 5.0, 5.0]}},
482+
{"num_envs": 24, "viewer": {"eye": [6.0, 6.0, 6.0]}},
483+
],
484+
}
470485
update_class_from_dict(cfg, cfg_dict)
471-
self.assertDictEqual(asdict(cfg), basic_demo_cfg_nested_dict)
486+
self.assertDictEqual(asdict(cfg), basic_demo_cfg_nested_dict_and_list)
487+
488+
# check types are also correct
489+
self.assertIsInstance(cfg.list_1[0], EnvCfg)
490+
self.assertIsInstance(cfg.list_1[1], EnvCfg)
491+
self.assertIsInstance(cfg.list_1[0].viewer, ViewerCfg)
492+
self.assertIsInstance(cfg.list_1[1].viewer, ViewerCfg)
472493

473494
def test_config_update_dict_using_internal(self):
474495
"""Test updating configclass from a dictionary using configclass method."""

source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/utils/hydra.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def wrapper(*args, **kwargs):
7474
# register the task to Hydra
7575
env_cfg, agent_cfg = register_task_to_hydra(task_name, agent_cfg_entry_point)
7676

77-
# define thr new Hydra main function
77+
# define the new Hydra main function
7878
@hydra.main(config_path=None, config_name=task_name, version_base="1.3")
7979
def hydra_main(hydra_env_cfg: DictConfig, env_cfg=env_cfg, agent_cfg=agent_cfg):
8080
# convert to a native dictionary

source/extensions/omni.isaac.lab_tasks/test/test_hydra.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,50 @@
1616

1717
"""Rest everything follows."""
1818

19-
19+
import functools
2020
import unittest
21+
from collections.abc import Callable
22+
23+
import hydra
24+
from hydra import compose, initialize
25+
from omegaconf import OmegaConf
26+
27+
from omni.isaac.lab.utils import replace_strings_with_slices
2128

2229
import omni.isaac.lab_tasks # noqa: F401
23-
from omni.isaac.lab_tasks.utils.hydra import hydra_task_config
30+
from omni.isaac.lab_tasks.utils.hydra import register_task_to_hydra
31+
32+
33+
def hydra_task_config_test(task_name: str, agent_cfg_entry_point: str) -> Callable:
34+
"""Copied from hydra.py hydra_task_config, since hydra.main requires a single point of entry,
35+
which will not work with multiple tests. Here, we replace hydra.main with hydra initialize
36+
and compose."""
37+
38+
def decorator(func):
39+
@functools.wraps(func)
40+
def wrapper(*args, **kwargs):
41+
# register the task to Hydra
42+
env_cfg, agent_cfg = register_task_to_hydra(task_name, agent_cfg_entry_point)
43+
44+
# replace hydra.main with initialize and compose
45+
with initialize(config_path=None, version_base="1.3"):
46+
hydra_env_cfg = compose(config_name=task_name, overrides=sys.argv[1:])
47+
# convert to a native dictionary
48+
hydra_env_cfg = OmegaConf.to_container(hydra_env_cfg, resolve=True)
49+
# replace string with slices because OmegaConf does not support slices
50+
hydra_env_cfg = replace_strings_with_slices(hydra_env_cfg)
51+
# update the configs with the Hydra command line arguments
52+
env_cfg.from_dict(hydra_env_cfg["env"])
53+
if isinstance(agent_cfg, dict):
54+
agent_cfg = hydra_env_cfg["agent"]
55+
else:
56+
agent_cfg.from_dict(hydra_env_cfg["agent"])
57+
# call the original function
58+
func(env_cfg, agent_cfg, *args, **kwargs)
59+
60+
return wrapper
61+
62+
return decorator
2463

2564

2665
class TestHydra(unittest.TestCase):
@@ -39,7 +78,7 @@ def test_hydra(self):
3978
"agent.max_iterations=3", # test simple agent modification
4079
]
4180

42-
@hydra_task_config("Isaac-Velocity-Flat-H1-v0", "rsl_rl_cfg_entry_point")
81+
@hydra_task_config_test("Isaac-Velocity-Flat-H1-v0", "rsl_rl_cfg_entry_point")
4382
def main(env_cfg, agent_cfg, self):
4483
# env
4584
self.assertEqual(env_cfg.decimation, 42)
@@ -50,6 +89,23 @@ def main(env_cfg, agent_cfg, self):
5089
self.assertEqual(agent_cfg.max_iterations, 3)
5190

5291
main(self)
92+
# clean up
93+
sys.argv = [sys.argv[0]]
94+
hydra.core.global_hydra.GlobalHydra.instance().clear()
95+
96+
def test_nested_iterable_dict(self):
97+
"""Test the hydra configuration system when dict is nested in an Iterable."""
98+
99+
@hydra_task_config_test("Isaac-Lift-Cube-Franka-v0", "rsl_rl_cfg_entry_point")
100+
def main(env_cfg, agent_cfg, self):
101+
# env
102+
self.assertEqual(env_cfg.scene.ee_frame.target_frames[0].name, "end_effector")
103+
self.assertEqual(env_cfg.scene.ee_frame.target_frames[0].offset.pos[2], 0.1034)
104+
105+
main(self)
106+
# clean up
107+
sys.argv = [sys.argv[0]]
108+
hydra.core.global_hydra.GlobalHydra.instance().clear()
53109

54110

55111
if __name__ == "__main__":

0 commit comments

Comments
 (0)