
Update FrameTransformer to handle external frames
jsmith-bdai committed Aug 22, 2024
1 parent ad4ec6e commit 419f127
Showing 2 changed files with 146 additions and 17 deletions.
@@ -3,8 +3,14 @@
#
# SPDX-License-Identifier: BSD-3-Clause

# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import re
import torch
from collections.abc import Sequence
from typing import TYPE_CHECKING
@@ -136,9 +142,9 @@ def _initialize_impl(self):
self._source_frame_offset_pos = source_frame_offset_pos.unsqueeze(0).repeat(self._num_envs, 1)
self._source_frame_offset_quat = source_frame_offset_quat.unsqueeze(0).repeat(self._num_envs, 1)

# Keep track of mapping from the rigid body name to the desired frame, as there may be multiple frames
# Keep track of mapping from the rigid body name to the desired frames and prim path, as there may be multiple frames
# based upon the same body name and we don't want to create unnecessary views
body_names_to_frames: dict[str, set[str]] = {}
body_names_to_frames: dict[str, dict[str, set[str]]] = {}
# The offsets associated with each target frame
target_offsets: dict[str, dict[str, torch.Tensor]] = {}
# The frames whose offsets are not identity
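To make the new bookkeeping concrete, here is a minimal sketch of the reshaped mapping, with illustrative body and frame names. The value now carries the prim path resolved for each body, so frames on prims outside the articulation (such as a free-standing cube) can be tracked through the same view:

# Old shape: body name -> set of frame names, e.g. {"base": {"BASE_USER"}}
# New shape: body name -> frames plus the prim path resolved for that body
body_names_to_frames = {
    "base": {"frames": {"BASE_USER"}, "prim_path": "/World/envs/env_0/Robot/base"},
    "cube": {"frames": {"CUBE_USER"}, "prim_path": "/World/envs/env_0/cube"},
}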
@@ -180,9 +186,10 @@ def _initialize_impl(self):

# Keep track of which frames are associated with which bodies
if body_name in body_names_to_frames:
body_names_to_frames[body_name].add(frame_name)
body_names_to_frames[body_name]["frames"].add(frame_name)
else:
body_names_to_frames[body_name] = {frame_name}
# Store the first matching prim path
body_names_to_frames[body_name] = {"frames": {frame_name}, "prim_path": matching_prim_path}

if offset is not None:
offset_pos = torch.tensor(offset.pos, device=self.device)
@@ -206,29 +213,42 @@
)

# The names of bodies that RigidPrimView will be tracking to later extract transforms from
tracked_body_names = list(body_names_to_frames.keys())
# Construct regex expression for the body names
body_names_regex = r"(" + "|".join(tracked_body_names) + r")"
body_names_regex = f"{self.cfg.prim_path.rsplit('/', 1)[0]}/{body_names_regex}"
tracked_prim_paths = [body_names_to_frames[body_name]["prim_path"] for body_name in body_names_to_frames.keys()]
tracked_body_names = [body_name for body_name in body_names_to_frames.keys()]

body_names_regex = [tracked_prim_path.replace("env_0", "env_*") for tracked_prim_path in tracked_prim_paths]

# Create simulation view
self._physics_sim_view = physx.create_simulation_view(self._backend)
self._physics_sim_view.set_subspace_roots("/")
# Create a prim view for all frames and initialize it
# order of transforms coming out of view will be source frame followed by target frame(s)
self._frame_physx_view = self._physics_sim_view.create_rigid_body_view(body_names_regex.replace(".*", "*"))
# self._frame_physx_view = self._physics_sim_view.create_rigid_body_view('/World/envs/env_*/{Robot/base,Robot/LH_SHANK,Robot/LF_SHANK,cube}')
self._frame_physx_view = self._physics_sim_view.create_rigid_body_view(body_names_regex)

# Determine the order in which regex evaluated body names so we can later index into frame transforms
# by frame name correctly
all_prim_paths = self._frame_physx_view.prim_paths

def extract_env_num(item):
match = re.search(r"env_(\d+)(.*)", item)
return (int(match.group(1)), match.group(2))

# Find the indices that would reorganize output to be per environment. We want `env_1/blah` to come before `env_11/blah`
# so we need to use a custom key function
sorted_indexed_prim_paths = sorted(list(enumerate(all_prim_paths)), key=lambda x: extract_env_num(x[1]))
self._per_env_indices = [index for index, _ in sorted_indexed_prim_paths]
sorted_prim_paths = [all_prim_paths[i] for i in self._per_env_indices]

# Only need first env as the names and their ordering are the same across environments
first_env_prim_paths = all_prim_paths[0 : len(tracked_body_names)]
first_env_prim_paths = [prim_path for prim_path in sorted_prim_paths if "env_0" in prim_path]
first_env_body_names = [first_env_prim_path.split("/")[-1] for first_env_prim_path in first_env_prim_paths]

# Re-parse the list as it may have moved when resolving regex above
# -- source frame
self._source_frame_body_name = self.cfg.prim_path.split("/")[-1]
source_frame_index = first_env_body_names.index(self._source_frame_body_name)

# -- target frames
self._target_frame_body_names = first_env_body_names[:]
self._target_frame_body_names.remove(self._source_frame_body_name)
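The hunk above replaces the single body-name regex with one concrete path pattern per tracked prim, then sorts the view's output numerically by environment. A self-contained sketch of both steps, using hypothetical prim paths rather than ones from this diff:

import re

# Step 1: generalize the env_0 paths gathered during parsing to every environment
tracked_prim_paths = [
    "/World/envs/env_0/Robot/base",
    "/World/envs/env_0/cube",
]
body_names_regex = [p.replace("env_0", "env_*") for p in tracked_prim_paths]
# -> ["/World/envs/env_*/Robot/base", "/World/envs/env_*/cube"]

# Step 2: sort the returned prim paths by numeric env index, not lexicographically
def extract_env_num(item):
    match = re.search(r"env_(\d+)(.*)", item)
    return (int(match.group(1)), match.group(2))

paths = ["/World/envs/env_11/cube", "/World/envs/env_2/cube", "/World/envs/env_0/cube"]
print(sorted(paths))                       # env_0, env_11, env_2 (string order, wrong)
print(sorted(paths, key=extract_env_num))  # env_0, env_2, env_11 (numeric, correct)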
@@ -248,11 +268,15 @@ def _initialize_impl(self):
# when updating sensor in _update_buffers_impl
duplicate_frame_indices = []

# The position and rotation components of target frame offsets
self._target_frame_offset_pos = torch.zeros(0, 3, device=self.device)
self._target_frame_offset_quat = torch.zeros(0, 4, device=self.device)

# Go through each body name and determine the number of duplicates we need for that frame
# and extract the offsets. This is all done to handle the case where multiple frames
# reference the same body, but have different names and/or offsets
for i, body_name in enumerate(self._target_frame_body_names):
for frame in body_names_to_frames[body_name]:
for frame in body_names_to_frames[body_name]["frames"]:
target_frame_offset_pos.append(target_offsets[frame]["pos"])
target_frame_offset_quat.append(target_offsets[frame]["quat"])
self._target_frame_names.append(frame)
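The duplicate-index bookkeeping above exists because several user frames may sit on one rigid body, whose single transform from the view is then re-used once per frame. A hedged sketch with assumed names:

# Hypothetical: two user frames defined on the same rigid body
body_names_to_frames = {
    "LF_SHANK": {
        "frames": {"LF_FOOT_USER", "LF_TIP_USER"},
        "prim_path": "/World/envs/env_0/Robot/LF_SHANK",
    },
}
duplicate_frame_indices = []
for i, (body_name, info) in enumerate(body_names_to_frames.items()):
    duplicate_frame_indices.extend([i] * len(info["frames"]))
print(duplicate_frame_indices)  # [0, 0]: body 0's transform serves both frames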
@@ -288,6 +312,10 @@ def _update_buffers_impl(self, env_ids: Sequence[int]):
# Extract transforms from view - shape is:
# (the total number of source and target body frames being tracked * self._num_envs, 7)
transforms = self._frame_physx_view.get_transforms()

# Reorder the transforms to be per environment as is expected of SensorData
transforms = transforms[self._per_env_indices]

# Convert quaternions as PhysX uses xyzw form
transforms[:, 3:] = convert_quat(transforms[:, 3:], to="wxyz")
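A rough, self-contained illustration of these two post-processing steps; the tensor values and index list are assumptions, and the final roll reproduces what convert_quat(..., to="wxyz") does for an xyzw input:

import torch

# Hypothetical: 2 envs x 2 bodies; the view groups transforms by body, not by env
transforms = torch.tensor([
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],  # env_0/base  (pos xyz, quat xyzw)
    [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],  # env_1/base
    [2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],  # env_0/cube
    [3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],  # env_1/cube
])
per_env_indices = [0, 2, 1, 3]            # computed once from the sorted prim paths
transforms = transforms[per_env_indices]  # now env_0's bodies first, then env_1's

# PhysX quaternions are (x, y, z, w); move w to the front to get (w, x, y, z)
transforms[:, 3:] = transforms[:, 3:].roll(1, dims=-1)
print(transforms[:, 3:])  # identity quaternions now read (1, 0, 0, 0)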

@@ -26,6 +26,7 @@

import omni.isaac.lab.sim as sim_utils
import omni.isaac.lab.utils.math as math_utils
from omni.isaac.lab.assets import RigidObjectCfg
from omni.isaac.lab.scene import InteractiveScene, InteractiveSceneCfg
from omni.isaac.lab.sensors import FrameTransformerCfg, OffsetCfg
from omni.isaac.lab.terrains import TerrainImporterCfg
@@ -62,6 +63,19 @@ class MySceneCfg(InteractiveSceneCfg):
# sensors - frame transformer (filled inside unit test)
frame_transformer: FrameTransformerCfg = None

# block
cube: RigidObjectCfg = RigidObjectCfg(
prim_path="{ENV_REGEX_NS}/cube",
spawn=sim_utils.CuboidCfg(
size=(0.2, 0.2, 0.2),
rigid_props=sim_utils.RigidBodyPropertiesCfg(max_depenetration_velocity=1.0),
mass_props=sim_utils.MassPropertiesCfg(mass=1.0),
physics_material=sim_utils.RigidBodyMaterialCfg(),
visual_material=sim_utils.PreviewSurfaceCfg(diffuse_color=(0.5, 0.0, 0.0)),
),
init_state=RigidObjectCfg.InitialStateCfg(pos=(2.0, 0.0, 5)),
)


class TestFrameTransformer(unittest.TestCase):
"""Test for frame transformer sensor."""
@@ -71,7 +85,7 @@ def setUp(self):
# Create a new stage
stage_utils.create_new_stage()
# Load kit helper
self.sim = sim_utils.SimulationContext(sim_utils.SimulationCfg(dt=0.005))
self.sim = sim_utils.SimulationContext(sim_utils.SimulationCfg(dt=0.005, device="cpu"))
# Set main camera
self.sim.set_camera_view(eye=[5, 5, 5], target=[0.0, 0.0, 0.0])

@@ -90,8 +104,7 @@ def tearDown(self):
def test_frame_transformer_feet_wrt_base(self):
"""Test feet transformations w.r.t. base source frame.
In this test, the source frame is the robot base. This frame is at index 0, when
the frame bodies are sorted in the order of the regex matching in the frame transformer.
In this test, the source frame is the robot base.
"""
# Spawn things into stage
scene_cfg = MySceneCfg(num_envs=32, env_spacing=5.0, lazy_sensor_update=False)
@@ -141,9 +154,15 @@ def test_frame_transformer_feet_wrt_base(self):
feet_indices, feet_names = scene.articulations["robot"].find_bodies(
["LF_FOOT", "RF_FOOT", "LH_FOOT", "RH_FOOT"]
)
# Check names are parsed the same order
user_feet_names = [f"{name}_USER" for name in feet_names]
self.assertListEqual(scene.sensors["frame_transformer"].data.target_frame_names, user_feet_names)

target_frame_names = scene.sensors["frame_transformer"].data.target_frame_names

# Reorder the feet indices to match the order of the target frames with _USER suffix removed
target_frame_names = [name.split("_USER")[0] for name in target_frame_names]

# Find the indices of the feet in the order of the target frames
reordering_indices = [feet_names.index(name) for name in target_frame_names]
feet_indices = [feet_indices[i] for i in reordering_indices]

# default joint targets
default_actions = scene.articulations["robot"].data.default_joint_pos.clone()
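Since find_bodies and the sensor may order the feet differently, the test realigns the indices a few lines above. A small self-contained illustration with assumed names and indices:

feet_names = ["LF_FOOT", "RF_FOOT", "LH_FOOT", "RH_FOOT"]  # find_bodies order
feet_indices = [4, 8, 12, 16]                              # hypothetical body indices
# The sensor reports frames in its own (path-sorted) order, with the _USER suffix
target_frame_names = ["LF_FOOT_USER", "LH_FOOT_USER", "RF_FOOT_USER", "RH_FOOT_USER"]
target_frame_names = [name.split("_USER")[0] for name in target_frame_names]
reordering_indices = [feet_names.index(name) for name in target_frame_names]  # [0, 2, 1, 3]
feet_indices = [feet_indices[i] for i in reordering_indices]                  # [4, 12, 8, 16]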
@@ -185,6 +204,7 @@ def test_frame_transformer_feet_wrt_base(self):
source_quat_w_tf = scene.sensors["frame_transformer"].data.source_quat_w
feet_pos_w_tf = scene.sensors["frame_transformer"].data.target_pos_w
feet_quat_w_tf = scene.sensors["frame_transformer"].data.target_quat_w

# check if they are same
torch.testing.assert_close(root_pose_w[:, :3], source_pos_w_tf, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(root_pose_w[:, 3:], source_quat_w_tf, rtol=1e-3, atol=1e-3)
@@ -302,6 +322,87 @@ def test_frame_transformer_feet_wrt_thigh(self):
torch.testing.assert_close(feet_pos_source_tf[:, index], foot_pos_b, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(feet_quat_source_tf[:, index], foot_quat_b, rtol=1e-3, atol=1e-3)

def test_frame_transformer_body_wrt_cube(self):
"""Test body transformation w.r.t. base source frame.
In this test, the source frame is the robot base.
The target_frame is a cube in the scene.
"""
# Spawn things into stage
scene_cfg = MySceneCfg(num_envs=2, env_spacing=5.0, lazy_sensor_update=False)
scene_cfg.frame_transformer = FrameTransformerCfg(
prim_path="{ENV_REGEX_NS}/Robot/base",
target_frames=[
FrameTransformerCfg.FrameCfg(
name="CUBE_USER",
prim_path="{ENV_REGEX_NS}/cube",
),
],
)
scene = InteractiveScene(scene_cfg)

# Play the simulator
self.sim.reset()

# default joint targets
default_actions = scene.articulations["robot"].data.default_joint_pos.clone()
# Define simulation stepping
sim_dt = self.sim.get_physics_dt()
# Simulate physics
for count in range(100):
# reset
if count % 25 == 0:
# reset root state
root_state = scene.articulations["robot"].data.default_root_state.clone()
root_state[:, :3] += scene.env_origins
joint_pos = scene.articulations["robot"].data.default_joint_pos
joint_vel = scene.articulations["robot"].data.default_joint_vel
# -- set root state
# -- robot
scene.articulations["robot"].write_root_state_to_sim(root_state)
scene.articulations["robot"].write_joint_state_to_sim(joint_pos, joint_vel)
# reset buffers
scene.reset()

# set joint targets
robot_actions = default_actions + 0.5 * torch.randn_like(default_actions)
scene.articulations["robot"].set_joint_position_target(robot_actions)
# write data to sim
scene.write_data_to_sim()
# perform step
self.sim.step()
# read data from sim
scene.update(sim_dt)

# check absolute frame transforms in world frame
# -- ground-truth
root_pose_w = scene.articulations["robot"].data.root_state_w[:, :7]
cube_pos_w_gt = scene.rigid_objects["cube"].data.root_state_w[:, :3]
cube_quat_w_gt = scene.rigid_objects["cube"].data.root_state_w[:, 3:7]
# -- frame transformer
source_pos_w_tf = scene.sensors["frame_transformer"].data.source_pos_w
source_quat_w_tf = scene.sensors["frame_transformer"].data.source_quat_w
cube_pos_w_tf = scene.sensors["frame_transformer"].data.target_pos_w.squeeze()
cube_quat_w_tf = scene.sensors["frame_transformer"].data.target_quat_w.squeeze()

# check if they are same
torch.testing.assert_close(root_pose_w[:, :3], source_pos_w_tf, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(root_pose_w[:, 3:], source_quat_w_tf, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(cube_pos_w_gt, cube_pos_w_tf, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(cube_quat_w_gt, cube_quat_w_tf, rtol=1e-3, atol=1e-3)

# check if relative transforms are same
cube_pos_source_tf = scene.sensors["frame_transformer"].data.target_pos_source
cube_quat_source_tf = scene.sensors["frame_transformer"].data.target_quat_source
# ground-truth
cube_pos_b, cube_quat_b = math_utils.subtract_frame_transforms(
root_pose_w[:, :3], root_pose_w[:, 3:], cube_pos_w_tf, cube_quat_w_tf
)
# check if they are same
torch.testing.assert_close(cube_pos_source_tf[:, 0], cube_pos_b, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(cube_quat_source_tf[:, 0], cube_quat_b, rtol=1e-3, atol=1e-3)
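For reference, a minimal sketch of the relative-pose check: subtract_frame_transforms returns the second pose expressed in the frame of the first, i.e. T_st = T_ws^-1 * T_wt. The numbers below are assumptions, and the call only needs the pure-tensor math utilities (assuming omni.isaac.lab is importable outside the sim):

import torch
import omni.isaac.lab.utils.math as math_utils

# Source frame 1 m ahead on x, identity rotation; target 3 m ahead (wxyz quaternions)
pos_s_w = torch.tensor([[1.0, 0.0, 0.0]])
quat_s_w = torch.tensor([[1.0, 0.0, 0.0, 0.0]])
pos_t_w = torch.tensor([[3.0, 0.0, 0.0]])
quat_t_w = torch.tensor([[1.0, 0.0, 0.0, 0.0]])

pos_t_s, quat_t_s = math_utils.subtract_frame_transforms(pos_s_w, quat_s_w, pos_t_w, quat_t_w)
print(pos_t_s)   # tensor([[2., 0., 0.]]) -> the target sits 2 m in front of the source
print(quat_t_s)  # tensor([[1., 0., 0., 0.]])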


if __name__ == "__main__":
run_tests()
