Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize command resampling #224

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

"""Various command terms that can be used in the environment."""

from . import resampling
from .commands_cfg import (
NormalVelocityCommandCfg,
NullCommandCfg,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from __future__ import annotations

import math
from dataclasses import MISSING

from omni.isaac.orbit.managers import CommandTermCfg
Expand All @@ -27,11 +26,6 @@ class NullCommandCfg(CommandTermCfg):

class_type: type = NullCommand

def __post_init__(self):
"""Post initialization."""
# set the resampling time range to infinity to avoid resampling
self.resampling_time_range = (math.inf, math.inf)


"""
Locomotion-specific command generators.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class NullCommand(CommandTerm):
def __str__(self) -> str:
msg = "NullCommand:\n"
msg += "\tCommand dimension: N/A\n"
msg += f"\tResampling time range: {self.cfg.resampling_time_range}"
msg += f"{self.cfg.resampling}"
return msg

"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(self, cfg: UniformPoseCommandCfg, env: BaseEnv):
def __str__(self) -> str:
msg = "UniformPoseCommand:\n"
msg += f"\tCommand dimension: {tuple(self.command.shape[1:])}\n"
msg += f"\tResampling time range: {self.cfg.resampling_time_range}\n"
msg += f"{self.cfg.resampling}"
return msg

"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def __init__(self, cfg: TerrainBasedPositionCommandCfg, env: BaseEnv):
def __str__(self) -> str:
msg = "TerrainBasedPositionCommand:\n"
msg += f"\tCommand dimension: {tuple(self.command.shape[1:])}\n"
msg += f"\tResampling time range: {self.cfg.resampling_time_range}\n"
msg += f"\tStanding probability: {self.cfg.rel_standing_envs}"
msg += f"{self.cfg.resampling}"
return msg

"""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Copyright (c) 2022-2024, The ORBIT Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from .fixed_frequency import FixedFrequency
from .random_chance import RandomChance
from .resampling_cfg import FixedFrequencyCfg, RandomChanceCfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) 2022-2024, The ORBIT Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import torch
from typing import TYPE_CHECKING, Sequence

from omni.isaac.orbit.managers import ResamplingTerm

if TYPE_CHECKING:
from .resampling_cfg import FixedFrequencyCfg


class FixedFrequency(ResamplingTerm):
    """Resampling term driven by a per-environment countdown.

    Each environment has its own entry in :attr:`time_left`. :meth:`compute`
    decrements every entry by the elapsed time and reports the environments whose
    entry is at or below zero. :meth:`reset` draws a new value for the given
    environments uniformly from ``cfg.resampling_time_range``.
    """

    def __init__(self, cfg: FixedFrequencyCfg, env):
        super().__init__(cfg, env)

        # -- per-environment countdown until the next resampling
        self.time_left = torch.zeros(self.num_envs, device=self.device)

    def __str__(self) -> str:
        return f"\t\tResampling time range: {self.cfg.resampling_time_range}"

    def compute(self, dt: float):
        """Advance the countdowns and report which environments are due.

        Args:
            dt: The time step.

        Returns:
            A flattened tensor with the indices of the environments whose
            countdown is less than or equal to zero after the decrement.
        """
        # count down by the elapsed time
        self.time_left -= dt
        # every timer that reached (or passed) zero is due for resampling
        expired_mask = self.time_left <= 0.0
        return expired_mask.nonzero().flatten()

    def reset(self, env_ids: Sequence[int]):
        """Restart the countdown of the given environments.

        The new countdown values are drawn with ``uniform_`` using the bounds
        stored in ``cfg.resampling_time_range``.

        Args:
            env_ids: The environment ids to be reset.
        """
        new_durations = self.time_left[env_ids].uniform_(*self.cfg.resampling_time_range)
        self.time_left[env_ids] = new_durations
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (c) 2022-2024, The ORBIT Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import torch
from typing import TYPE_CHECKING

from omni.isaac.orbit.managers import ResamplingTerm

if TYPE_CHECKING:
from .resampling_cfg import RandomChanceCfg


class RandomChance(ResamplingTerm):
    """Resampling term that selects environments at random.

    On every call to :meth:`compute`, one uniform random number is drawn per
    environment and the environment is selected when that number is below
    ``cfg.resampling_probability``.
    """

    def __init__(self, cfg: RandomChanceCfg, env):
        super().__init__(cfg, env)

    def __str__(self) -> str:
        return f"\t\tResampling probability: {self.cfg.resampling_probability}"

    def compute(self, dt: float):
        """Draw the environment ids to be resampled.

        Args:
            dt: The time step. It is not used by this term.

        Returns:
            A flattened tensor with the indices of the selected environments.
        """
        # Note: uniform_(0, 1) samples from [0, 1), i.e. 0 is included and 1 is excluded.
        # The comparison therefore has to be strict (<) rather than <=.
        draws = torch.empty(self.num_envs, device=self.device).uniform_(0, 1)
        selected_mask = draws < self.cfg.resampling_probability
        return selected_mask.nonzero(as_tuple=False).flatten()
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (c) 2022-2024, The ORBIT Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

from dataclasses import MISSING

from omni.isaac.orbit.managers import ResamplingTermCfg
from omni.isaac.orbit.utils import configclass

from .fixed_frequency import FixedFrequency
from .random_chance import RandomChance

"""
Fixed frequency resampling term.
"""


@configclass
class FixedFrequencyCfg(ResamplingTermCfg):
    """Configuration for the fixed frequency resampling term."""

    # resampling term implementation associated with this configuration
    class_type: type = FixedFrequency

    # (lower, upper) bounds; FixedFrequency.reset unpacks them into ``uniform_`` when
    # drawing the time left before the next resampling. No default is provided (MISSING).
    resampling_time_range: tuple[float, float] = MISSING


"""
Random chance resampling term.
"""


@configclass
class RandomChanceCfg(ResamplingTermCfg):
    """Configuration for the random chance resampling term."""

    # resampling term implementation associated with this configuration
    class_type: type = RandomChance

    # threshold that RandomChance.compute compares each per-environment uniform
    # draw from [0, 1) against. No default is provided (MISSING).
    resampling_probability: float = MISSING
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@ def __str__(self) -> str:
"""Return a string representation of the command generator."""
msg = "UniformVelocityCommand:\n"
msg += f"\tCommand dimension: {tuple(self.command.shape[1:])}\n"
msg += f"\tResampling time range: {self.cfg.resampling_time_range}\n"
msg += f"\tHeading command: {self.cfg.heading_command}\n"
if self.cfg.heading_command:
msg += f"\tHeading probability: {self.cfg.rel_heading_envs}\n"
msg += f"\tStanding probability: {self.cfg.rel_standing_envs}"
msg += f"{self.cfg.resampling}"
return msg

"""
Expand Down Expand Up @@ -130,15 +130,15 @@ def _update_command(self):
self.vel_command_b[standing_env_ids, :] = 0.0

def _update_metrics(self):
# time for which the command was executed
max_command_time = self.cfg.resampling_time_range[1]
max_command_step = max_command_time / self._env.step_dt
# logs data
self.metrics["error_vel_xy"] += (
torch.norm(self.vel_command_b[:, :2] - self.robot.data.root_lin_vel_b[:, :2], dim=-1) / max_command_step
command_age_after_step = self.command_age + self._env.step_dt
weight_old_value = self.command_age / command_age_after_step
weight_new_value = self._env.step_dt / command_age_after_step
self.metrics["error_vel_xy"] = weight_old_value * self.metrics["error_vel_xy"] + weight_new_value * torch.norm(
self.vel_command_b[:, :2] - self.robot.data.root_lin_vel_b[:, :2], dim=-1
)
self.metrics["error_vel_yaw"] += (
torch.abs(self.vel_command_b[:, 2] - self.robot.data.root_ang_vel_b[:, 2]) / max_command_step
self.metrics["error_vel_yaw"] = weight_old_value * self.metrics["error_vel_yaw"] + weight_new_value * torch.abs(
self.vel_command_b[:, 2] - self.robot.data.root_ang_vel_b[:, 2]
)

def _set_debug_vis_impl(self, debug_vis: bool):
Expand Down Expand Up @@ -229,8 +229,8 @@ def __str__(self) -> str:
"""Return a string representation of the command generator."""
msg = "NormalVelocityCommand:\n"
msg += f"\tCommand dimension: {tuple(self.command.shape[1:])}\n"
msg += f"\tResampling time range: {self.cfg.resampling_time_range}\n"
msg += f"\tStanding probability: {self.cfg.rel_standing_envs}"
msg += f"{self.cfg.resampling}"
return msg

def _resample_command(self, env_ids):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@
ObservationGroupCfg,
ObservationTermCfg,
RandomizationTermCfg,
ResamplingTermCfg,
RewardTermCfg,
TerminationTermCfg,
)
from .observation_manager import ObservationManager
from .randomization_manager import RandomizationManager
from .resampling_manager import ResamplingManager, ResamplingTerm
from .reward_manager import RewardManager
from .scene_entity_cfg import SceneEntityCfg
from .termination_manager import TerminationManager
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from .manager_base import ManagerBase, ManagerTermBase
from .manager_term_cfg import CommandTermCfg
from .resampling_manager import ResamplingManager

if TYPE_CHECKING:
from omni.isaac.orbit.envs import RLTaskEnv
Expand All @@ -30,8 +31,8 @@ class CommandTerm(ManagerTermBase):
in the case of a goal-conditioned navigation task, the command term can be used to
generate a target position for the robot to navigate to.

It implements a resampling mechanism that allows the command to be resampled at a fixed
frequency. The resampling frequency can be specified in the configuration object.
The resampling mechanism is defined by resampling terms which are called at every
command update to determine which envs should be reset.
Additionally, it is possible to assign a visualization function to the command term
that can be used to visualize the command in the simulator.
"""
Expand All @@ -48,10 +49,12 @@ def __init__(self, cfg: CommandTermCfg, env: RLTaskEnv):
# create buffers to store the command
# -- metrics that can be used for logging
self.metrics = dict()
# -- time left before resampling
self.time_left = torch.zeros(self.num_envs, device=self.device)
# -- resampling manager
self.resampling_manager: ResamplingManager = ResamplingManager(cfg.resampling, env)
# -- counter for the number of times the command has been resampled within the current episode
self.command_counter = torch.zeros(self.num_envs, device=self.device, dtype=torch.long)
# -- timer for tracking the time since last resampling the command.
self.command_age = torch.zeros(self.num_envs, device=self.device, dtype=torch.float)

# add handle for debug visualization (this is set to a valid handle inside set_debug_vis)
self._debug_vis_handle = None
Expand Down Expand Up @@ -152,10 +155,10 @@ def compute(self, dt: float):
"""
# update the metrics based on current state
self._update_metrics()
# reduce the time left before resampling
self.time_left -= dt
# increase command age
self.command_age += dt
# resample the command if necessary
resample_env_ids = (self.time_left <= 0.0).nonzero().flatten()
resample_env_ids = self.resampling_manager.compute(dt)
if len(resample_env_ids) > 0:
self._resample(resample_env_ids)
# update the command
Expand All @@ -168,14 +171,16 @@ def compute(self, dt: float):
def _resample(self, env_ids: Sequence[int]):
"""Resample the command.

This function resamples the command and time for which the command is applied for the
specified environment indices.
This function resamples the command and notifies all sampling terms
that the commands have been resampled.

Args:
env_ids: The list of environment IDs to resample.
"""
# resample the time left before resampling
self.time_left[env_ids] = self.time_left[env_ids].uniform_(*self.cfg.resampling_time_range)
# notify resampling manager of env ids which are resampled.
self.resampling_manager.reset(env_ids)
# set the command age to zero
self.command_age[env_ids] = 0.0
# increment the command counter
self.command_counter[env_ids] += 1
# resample the command
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .action_manager import ActionTerm
from .command_manager import CommandTerm
from .manager_base import ManagerTermBase
from .resampling_manager import ResamplingTerm


@configclass
Expand Down Expand Up @@ -87,8 +88,8 @@ class CommandTermCfg:
The class should inherit from :class:`omni.isaac.orbit.managers.command_manager.CommandTerm`.
"""

resampling_time_range: tuple[float, float] = MISSING
"""Time before commands are changed [s]."""
resampling: dict[str, ResamplingTermCfg] | None = None
"""Terms used for resampling the command."""
debug_vis: bool = False
"""Whether to visualize debug information. Defaults to False."""

Expand Down Expand Up @@ -195,6 +196,22 @@ class RandomizationTermCfg(ManagerTermBaseCfg):
"""


##
# Resampling manager.
##


@configclass
class ResamplingTermCfg:
    """Configuration for a resampling term."""

    class_type: type[ResamplingTerm] = MISSING
    """The associated resampling term class to use.

    The class should inherit from :class:`omni.isaac.orbit.managers.resampling_manager.ResamplingTerm`.
    """


##
# Reward manager.
##
Expand Down
Loading