Skip to content

Commit

Permalink
Fixes updating configclass parameter with a list of objects (#847)
Browse files Browse the repository at this point in the history
# Description

When configclass dicts are nested inside lists, the list is treated as
an Iterable object and assigned directly to the outer configclass when
updating configclass data with dicts. This overwrites the configclass
object in the list with a dict object and causes undesired behavior.

This change checks for nested dictionaries inside Iterables and updates
the values inside the dictionary individually without overwiting the
full Iterable.

Fixes #843

## Type of change

- Bug fix (non-breaking change which fixes an issue)


## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
  • Loading branch information
kellyguo11 authored Aug 22, 2024
1 parent d906c4a commit ad4ec6e
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 11 deletions.
10 changes: 10 additions & 0 deletions source/extensions/omni.isaac.lab/omni/isaac/lab/utils/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,16 @@ def update_class_from_dict(obj, data: dict[str, Any], _ns: str = "") -> None:
)
if isinstance(obj_mem, tuple):
value = tuple(value)
else:
set_obj = True
# recursively call if iterable contains dictionaries
for i in range(len(obj_mem)):
if isinstance(value[i], dict):
update_class_from_dict(obj_mem[i], value[i], _ns=key_ns)
set_obj = False
# do not set value to obj, otherwise it overwrites the cfg class with the dict
if not set_obj:
continue
elif callable(obj_mem):
# update function name
value = string_to_callable(value)
Expand Down
35 changes: 28 additions & 7 deletions source/extensions/omni.isaac.lab/test/utils/test_configclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,10 +298,11 @@ def set_a(self, a: int):


@configclass
class NestedDictCfg:
"""Dummy configuration class with nested dictionaries."""
class NestedDictAndListCfg:
"""Dummy configuration class with nested dictionaries and lists."""

dict_1: dict = {"dict_2": {"func": dummy_function1}}
list_1: list[EnvCfg] = [EnvCfg(), EnvCfg()]


"""
Expand Down Expand Up @@ -341,10 +342,14 @@ class NestedDictCfg:
"device_id": 0,
}

basic_demo_cfg_nested_dict = {
basic_demo_cfg_nested_dict_and_list = {
"dict_1": {
"dict_2": {"func": dummy_function2},
},
"list_1": [
{"num_envs": 23, "episode_length": 3000, "viewer": {"eye": [5.0, 5.0, 5.0], "lookat": [0.0, 0.0, 0.0]}},
{"num_envs": 24, "episode_length": 2000, "viewer": {"eye": [6.0, 6.0, 6.0], "lookat": [0.0, 0.0, 0.0]}},
],
}

basic_demo_post_init_cfg_correct = {
Expand Down Expand Up @@ -456,6 +461,10 @@ def test_config_update_dict(self):
update_class_from_dict(cfg, cfg_dict)
self.assertDictEqual(asdict(cfg), basic_demo_cfg_change_correct)

# check types are also correct
self.assertIsInstance(cfg.env.viewer, ViewerCfg)
self.assertIsInstance(cfg.env.viewer.eye, tuple)

def test_config_update_dict_with_none(self):
"""Test updating configclass using a dictionary that contains None."""
cfg = BasicDemoCfg()
Expand All @@ -464,11 +473,23 @@ def test_config_update_dict_with_none(self):
self.assertDictEqual(asdict(cfg), basic_demo_cfg_change_with_none_correct)

def test_config_update_nested_dict(self):
"""Test updating configclass with sub-dictionnaries."""
cfg = NestedDictCfg()
cfg_dict = {"dict_1": {"dict_2": {"func": "__main__:dummy_function2"}}}
"""Test updating configclass with sub-dictionaries."""
cfg = NestedDictAndListCfg()
cfg_dict = {
"dict_1": {"dict_2": {"func": "__main__:dummy_function2"}},
"list_1": [
{"num_envs": 23, "episode_length": 3000, "viewer": {"eye": [5.0, 5.0, 5.0]}},
{"num_envs": 24, "viewer": {"eye": [6.0, 6.0, 6.0]}},
],
}
update_class_from_dict(cfg, cfg_dict)
self.assertDictEqual(asdict(cfg), basic_demo_cfg_nested_dict)
self.assertDictEqual(asdict(cfg), basic_demo_cfg_nested_dict_and_list)

# check types are also correct
self.assertIsInstance(cfg.list_1[0], EnvCfg)
self.assertIsInstance(cfg.list_1[1], EnvCfg)
self.assertIsInstance(cfg.list_1[0].viewer, ViewerCfg)
self.assertIsInstance(cfg.list_1[1].viewer, ViewerCfg)

def test_config_update_dict_using_internal(self):
"""Test updating configclass from a dictionary using configclass method."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def wrapper(*args, **kwargs):
# register the task to Hydra
env_cfg, agent_cfg = register_task_to_hydra(task_name, agent_cfg_entry_point)

# define thr new Hydra main function
# define the new Hydra main function
@hydra.main(config_path=None, config_name=task_name, version_base="1.3")
def hydra_main(hydra_env_cfg: DictConfig, env_cfg=env_cfg, agent_cfg=agent_cfg):
# convert to a native dictionary
Expand Down
62 changes: 59 additions & 3 deletions source/extensions/omni.isaac.lab_tasks/test/test_hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,50 @@

"""Rest everything follows."""


import functools
import unittest
from collections.abc import Callable

import hydra
from hydra import compose, initialize
from omegaconf import OmegaConf

from omni.isaac.lab.utils import replace_strings_with_slices

import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils.hydra import hydra_task_config
from omni.isaac.lab_tasks.utils.hydra import register_task_to_hydra


def hydra_task_config_test(task_name: str, agent_cfg_entry_point: str) -> Callable:
"""Copied from hydra.py hydra_task_config, since hydra.main requires a single point of entry,
which will not work with multiple tests. Here, we replace hydra.main with hydra initialize
and compose."""

def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
# register the task to Hydra
env_cfg, agent_cfg = register_task_to_hydra(task_name, agent_cfg_entry_point)

# replace hydra.main with initialize and compose
with initialize(config_path=None, version_base="1.3"):
hydra_env_cfg = compose(config_name=task_name, overrides=sys.argv[1:])
# convert to a native dictionary
hydra_env_cfg = OmegaConf.to_container(hydra_env_cfg, resolve=True)
# replace string with slices because OmegaConf does not support slices
hydra_env_cfg = replace_strings_with_slices(hydra_env_cfg)
# update the configs with the Hydra command line arguments
env_cfg.from_dict(hydra_env_cfg["env"])
if isinstance(agent_cfg, dict):
agent_cfg = hydra_env_cfg["agent"]
else:
agent_cfg.from_dict(hydra_env_cfg["agent"])
# call the original function
func(env_cfg, agent_cfg, *args, **kwargs)

return wrapper

return decorator


class TestHydra(unittest.TestCase):
Expand All @@ -39,7 +78,7 @@ def test_hydra(self):
"agent.max_iterations=3", # test simple agent modification
]

@hydra_task_config("Isaac-Velocity-Flat-H1-v0", "rsl_rl_cfg_entry_point")
@hydra_task_config_test("Isaac-Velocity-Flat-H1-v0", "rsl_rl_cfg_entry_point")
def main(env_cfg, agent_cfg, self):
# env
self.assertEqual(env_cfg.decimation, 42)
Expand All @@ -50,6 +89,23 @@ def main(env_cfg, agent_cfg, self):
self.assertEqual(agent_cfg.max_iterations, 3)

main(self)
# clean up
sys.argv = [sys.argv[0]]
hydra.core.global_hydra.GlobalHydra.instance().clear()

def test_nested_iterable_dict(self):
"""Test the hydra configuration system when dict is nested in an Iterable."""

@hydra_task_config_test("Isaac-Lift-Cube-Franka-v0", "rsl_rl_cfg_entry_point")
def main(env_cfg, agent_cfg, self):
# env
self.assertEqual(env_cfg.scene.ee_frame.target_frames[0].name, "end_effector")
self.assertEqual(env_cfg.scene.ee_frame.target_frames[0].offset.pos[2], 0.1034)

main(self)
# clean up
sys.argv = [sys.argv[0]]
hydra.core.global_hydra.GlobalHydra.instance().clear()


if __name__ == "__main__":
Expand Down

0 comments on commit ad4ec6e

Please sign in to comment.