Skip to content

Commit

Permalink
Improve auto wrapper detection implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Jun 21, 2024
1 parent 159d39e commit b9b07eb
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 86 deletions.
80 changes: 37 additions & 43 deletions skrl/envs/wrappers/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,50 +63,44 @@ def wrap_env(env: Any, wrapper: str = "auto", verbose: bool = True) -> Union[Wra
:return: Wrapped environment
:rtype: Wrapper or MultiAgentEnvWrapper
"""
if verbose:
logger.info("Environment class: {}".format(", ".join([str(base).replace("<class '", "").replace("'>", "") \
for base in env.__class__.__bases__])))
if wrapper == "auto":
base_classes = [str(base) for base in env.__class__.__bases__]
if "<class 'omni.isaac.gym.vec_env.vec_env_base.VecEnvBase'>" in base_classes or \
"<class 'omni.isaac.gym.vec_env.vec_env_mt.VecEnvMT'>" in base_classes:
if verbose:
logger.info("Environment wrapper: Omniverse Isaac Gym")
return OmniverseIsaacGymWrapper(env)
elif isinstance(env, gym.core.Env) or isinstance(env, gym.core.Wrapper):
# isaaclab
if hasattr(env, "sim") and hasattr(env, "env_ns"):
if verbose:
logger.info("Environment wrapper: Isaac Lab")
return IsaacLabWrapper(env)
# gym
if verbose:
logger.info("Environment wrapper: Gym")
return GymWrapper(env)
elif isinstance(env, gymnasium.core.Env) or isinstance(env, gymnasium.core.Wrapper):
if verbose:
logger.info("Environment wrapper: Gymnasium")
return GymnasiumWrapper(env)
elif "<class 'pettingzoo.utils.env" in base_classes[0] or "<class 'pettingzoo.utils.wrappers" in base_classes[0]:
if verbose:
logger.info("Environment wrapper: Petting Zoo")
return PettingZooWrapper(env)
elif "<class 'dm_env._environment.Environment'>" in base_classes:
if verbose:
logger.info("Environment wrapper: DeepMind")
return DeepMindWrapper(env)
elif "<class 'robosuite.environments." in base_classes[0]:
if verbose:
logger.info("Environment wrapper: Robosuite")
return RobosuiteWrapper(env)
elif "<class 'rlgpu.tasks.base.vec_task.VecTask'>" in base_classes:
if verbose:
logger.info("Environment wrapper: Isaac Gym (preview 2)")
return IsaacGymPreview2Wrapper(env)
def _get_wrapper_name(env, verbose):
def _in(value, container):
for item in container:
if value in item:
return True
return False

base_classes = [str(base).replace("<class '", "").replace("'>", "") for base in env.__class__.__bases__]
try:
base_classes += [str(base).replace("<class '", "").replace("'>", "") for base in env.unwrapped.__class__.__bases__]
except:
pass
base_classes = sorted(list(set(base_classes)))
if verbose:
logger.info("Environment wrapper: Isaac Gym (preview 3/4)")
return IsaacGymPreview3Wrapper(env) # preview 4 is the same as 3
elif wrapper == "gym":
logger.info(f"Environment wrapper: 'auto' (class: {', '.join(base_classes)})")

if _in("omni.isaac.lab.envs.manager_based_env.ManagerBasedEnv", base_classes) or _in("omni.isaac.lab.envs.direct_rl_env.DirectRLEnv", base_classes):
return "isaaclab"
elif _in("omni.isaac.gym.vec_env.vec_env_base.VecEnvBase", base_classes) or _in("omni.isaac.gym.vec_env.vec_env_mt.VecEnvMT", base_classes):
return "omniverse-isaacgym"
elif _in("rlgpu.tasks.base.vec_task.VecTask", base_classes):
return "isaacgym-preview2"
elif _in("robosuite.environments.", base_classes):
return "robosuite"
elif _in("dm_env._environment.Environment.", base_classes):
return "dm"
elif _in("pettingzoo.utils.env", base_classes) or _in("pettingzoo.utils.wrappers", base_classes):
return "pettingzoo"
elif _in("gymnasium.core.Env", base_classes) or _in("gymnasium.core.Wrapper", base_classes):
return "gymnasium"
elif _in("gym.core.Env", base_classes) or _in("gym.core.Wrapper", base_classes):
return "gym"
return base_classes

if wrapper == "auto":
wrapper = _get_wrapper_name(env, verbose)

if wrapper == "gym":
if verbose:
logger.info("Environment wrapper: Gym")
return GymWrapper(env)
Expand Down
80 changes: 37 additions & 43 deletions skrl/envs/wrappers/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,50 +69,44 @@ def wrap_env(env: Any, wrapper: str = "auto", verbose: bool = True) -> Union[Wra
:return: Wrapped environment
:rtype: Wrapper or MultiAgentEnvWrapper
"""
if verbose:
logger.info("Environment class: {}".format(", ".join([str(base).replace("<class '", "").replace("'>", "") \
for base in env.__class__.__bases__])))
if wrapper == "auto":
base_classes = [str(base) for base in env.__class__.__bases__]
if "<class 'omni.isaac.gym.vec_env.vec_env_base.VecEnvBase'>" in base_classes or \
"<class 'omni.isaac.gym.vec_env.vec_env_mt.VecEnvMT'>" in base_classes:
if verbose:
logger.info("Environment wrapper: Omniverse Isaac Gym")
return OmniverseIsaacGymWrapper(env)
elif isinstance(env, gym.core.Env) or isinstance(env, gym.core.Wrapper):
# isaaclab
if hasattr(env, "sim") and hasattr(env, "env_ns"):
if verbose:
logger.info("Environment wrapper: Isaac Lab")
return IsaacLabWrapper(env)
# gym
if verbose:
logger.info("Environment wrapper: Gym")
return GymWrapper(env)
elif isinstance(env, gymnasium.core.Env) or isinstance(env, gymnasium.core.Wrapper):
if verbose:
logger.info("Environment wrapper: Gymnasium")
return GymnasiumWrapper(env)
elif "<class 'pettingzoo.utils.env" in base_classes[0] or "<class 'pettingzoo.utils.wrappers" in base_classes[0]:
if verbose:
logger.info("Environment wrapper: Petting Zoo")
return PettingZooWrapper(env)
elif "<class 'dm_env._environment.Environment'>" in base_classes:
if verbose:
logger.info("Environment wrapper: DeepMind")
return DeepMindWrapper(env)
elif "<class 'robosuite.environments." in base_classes[0]:
if verbose:
logger.info("Environment wrapper: Robosuite")
return RobosuiteWrapper(env)
elif "<class 'rlgpu.tasks.base.vec_task.VecTask'>" in base_classes:
if verbose:
logger.info("Environment wrapper: Isaac Gym (preview 2)")
return IsaacGymPreview2Wrapper(env)
def _get_wrapper_name(env, verbose):
def _in(value, container):
for item in container:
if value in item:
return True
return False

base_classes = [str(base).replace("<class '", "").replace("'>", "") for base in env.__class__.__bases__]
try:
base_classes += [str(base).replace("<class '", "").replace("'>", "") for base in env.unwrapped.__class__.__bases__]
except:
pass
base_classes = sorted(list(set(base_classes)))
if verbose:
logger.info("Environment wrapper: Isaac Gym (preview 3/4)")
return IsaacGymPreview3Wrapper(env) # preview 4 is the same as 3
elif wrapper == "gym":
logger.info(f"Environment wrapper: 'auto' (class: {', '.join(base_classes)})")

if _in("omni.isaac.lab.envs.manager_based_env.ManagerBasedEnv", base_classes) or _in("omni.isaac.lab.envs.direct_rl_env.DirectRLEnv", base_classes):
return "isaaclab"
elif _in("omni.isaac.gym.vec_env.vec_env_base.VecEnvBase", base_classes) or _in("omni.isaac.gym.vec_env.vec_env_mt.VecEnvMT", base_classes):
return "omniverse-isaacgym"
elif _in("rlgpu.tasks.base.vec_task.VecTask", base_classes):
return "isaacgym-preview2"
elif _in("robosuite.environments.", base_classes):
return "robosuite"
elif _in("dm_env._environment.Environment.", base_classes):
return "dm"
elif _in("pettingzoo.utils.env", base_classes) or _in("pettingzoo.utils.wrappers", base_classes):
return "pettingzoo"
elif _in("gymnasium.core.Env", base_classes) or _in("gymnasium.core.Wrapper", base_classes):
return "gymnasium"
elif _in("gym.core.Env", base_classes) or _in("gym.core.Wrapper", base_classes):
return "gym"
return base_classes

if wrapper == "auto":
wrapper = _get_wrapper_name(env, verbose)

if wrapper == "gym":
if verbose:
logger.info("Environment wrapper: Gym")
return GymWrapper(env)
Expand Down

0 comments on commit b9b07eb

Please sign in to comment.