diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/sensors/frame_transformer/frame_transformer.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/sensors/frame_transformer/frame_transformer.py index 320007cabd..79783fd0cb 100644 --- a/source/extensions/omni.isaac.lab/omni/isaac/lab/sensors/frame_transformer/frame_transformer.py +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/sensors/frame_transformer/frame_transformer.py @@ -3,8 +3,14 @@ # # SPDX-License-Identifier: BSD-3-Clause +# i Copyright (c) 2022-2024, The Isaac Lab Project Developers. +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + from __future__ import annotations +import re import torch from collections.abc import Sequence from typing import TYPE_CHECKING @@ -136,9 +142,9 @@ def _initialize_impl(self): self._source_frame_offset_pos = source_frame_offset_pos.unsqueeze(0).repeat(self._num_envs, 1) self._source_frame_offset_quat = source_frame_offset_quat.unsqueeze(0).repeat(self._num_envs, 1) - # Keep track of mapping from the rigid body name to the desired frame, as there may be multiple frames + # Keep track of mapping from the rigid body name to the desired frames and prim path, as there may be multiple frames # based upon the same body name and we don't want to create unnecessary views - body_names_to_frames: dict[str, set[str]] = {} + body_names_to_frames: dict[str, dict[str, set[str]]] = {} # The offsets associated with each target frame target_offsets: dict[str, dict[str, torch.Tensor]] = {} # The frames whose offsets are not identity @@ -180,9 +186,10 @@ def _initialize_impl(self): # Keep track of which frames are associated with which bodies if body_name in body_names_to_frames: - body_names_to_frames[body_name].add(frame_name) + body_names_to_frames[body_name]["frames"].add(frame_name) else: - body_names_to_frames[body_name] = {frame_name} + # Store the first matching prim path + body_names_to_frames[body_name] = {"frames": {frame_name}, "prim_path": matching_prim_path} if offset is not None: offset_pos = torch.tensor(offset.pos, device=self.device) @@ -206,29 +213,42 @@ def _initialize_impl(self): ) # The names of bodies that RigidPrimView will be tracking to later extract transforms from - tracked_body_names = list(body_names_to_frames.keys()) - # Construct regex expression for the body names - body_names_regex = r"(" + "|".join(tracked_body_names) + r")" - body_names_regex = f"{self.cfg.prim_path.rsplit('/', 1)[0]}/{body_names_regex}" + tracked_prim_paths = [body_names_to_frames[body_name]["prim_path"] for body_name in body_names_to_frames.keys()] + tracked_body_names = [body_name for body_name in body_names_to_frames.keys()] + + body_names_regex = [tracked_prim_path.replace("env_0", "env_*") for tracked_prim_path in tracked_prim_paths] + # Create simulation view self._physics_sim_view = physx.create_simulation_view(self._backend) self._physics_sim_view.set_subspace_roots("/") # Create a prim view for all frames and initialize it # order of transforms coming out of view will be source frame followed by target frame(s) - self._frame_physx_view = self._physics_sim_view.create_rigid_body_view(body_names_regex.replace(".*", "*")) + # self._frame_physx_view = self._physics_sim_view.create_rigid_body_view('/World/envs/env_*/{Robot/base,Robot/LH_SHANK,Robot/LF_SHANK,cube}') + self._frame_physx_view = self._physics_sim_view.create_rigid_body_view(body_names_regex) # Determine the order in which regex evaluated body names so we can later index into frame transforms # by frame name correctly all_prim_paths = self._frame_physx_view.prim_paths + def extract_env_num(item): + match = re.search(r"env_(\d+)(.*)", item) + return (int(match.group(1)), match.group(2)) + + # Find the indices that would reorganize output to be per environment. We want `env_1/blah` to come before `env_11/blah` + # so we need to use a custom key function + sorted_indexed_prim_paths = sorted(list(enumerate(all_prim_paths)), key=lambda x: extract_env_num(x[1])) + self._per_env_indices = [index for index, _ in sorted_indexed_prim_paths] + sorted_prim_paths = [all_prim_paths[i] for i in self._per_env_indices] + # Only need first env as the names and their ordering are the same across environments - first_env_prim_paths = all_prim_paths[0 : len(tracked_body_names)] + first_env_prim_paths = [prim_path for prim_path in sorted_prim_paths if "env_0" in prim_path] first_env_body_names = [first_env_prim_path.split("/")[-1] for first_env_prim_path in first_env_prim_paths] # Re-parse the list as it may have moved when resolving regex above # -- source frame self._source_frame_body_name = self.cfg.prim_path.split("/")[-1] source_frame_index = first_env_body_names.index(self._source_frame_body_name) + # -- target frames self._target_frame_body_names = first_env_body_names[:] self._target_frame_body_names.remove(self._source_frame_body_name) @@ -248,11 +268,15 @@ def _initialize_impl(self): # when updating sensor in _update_buffers_impl duplicate_frame_indices = [] + # The position and rotation components of target frame offsets + self._target_frame_offset_pos = torch.zeros(0, 3, device=self.device) + self._target_frame_offset_quat = torch.zeros(0, 4, device=self.device) + # Go through each body name and determine the number of duplicates we need for that frame # and extract the offsets. This is all done to handles the case where multiple frames # reference the same body, but have different names and/or offsets for i, body_name in enumerate(self._target_frame_body_names): - for frame in body_names_to_frames[body_name]: + for frame in body_names_to_frames[body_name]["frames"]: target_frame_offset_pos.append(target_offsets[frame]["pos"]) target_frame_offset_quat.append(target_offsets[frame]["quat"]) self._target_frame_names.append(frame) @@ -288,6 +312,10 @@ def _update_buffers_impl(self, env_ids: Sequence[int]): # Extract transforms from view - shape is: # (the total number of source and target body frames being tracked * self._num_envs, 7) transforms = self._frame_physx_view.get_transforms() + + # Reorder the transforms to be per environment as is expected of SensorData + transforms = transforms[self._per_env_indices] + # Convert quaternions as PhysX uses xyzw form transforms[:, 3:] = convert_quat(transforms[:, 3:], to="wxyz") diff --git a/source/extensions/omni.isaac.lab/test/sensors/test_frame_transformer.py b/source/extensions/omni.isaac.lab/test/sensors/test_frame_transformer.py index f379bc86b2..47e32a65cd 100644 --- a/source/extensions/omni.isaac.lab/test/sensors/test_frame_transformer.py +++ b/source/extensions/omni.isaac.lab/test/sensors/test_frame_transformer.py @@ -26,6 +26,7 @@ import omni.isaac.lab.sim as sim_utils import omni.isaac.lab.utils.math as math_utils +from omni.isaac.lab.assets import RigidObjectCfg from omni.isaac.lab.scene import InteractiveScene, InteractiveSceneCfg from omni.isaac.lab.sensors import FrameTransformerCfg, OffsetCfg from omni.isaac.lab.terrains import TerrainImporterCfg @@ -62,6 +63,19 @@ class MySceneCfg(InteractiveSceneCfg): # sensors - frame transformer (filled inside unit test) frame_transformer: FrameTransformerCfg = None + # block + cube: RigidObjectCfg = RigidObjectCfg( + prim_path="{ENV_REGEX_NS}/cube", + spawn=sim_utils.CuboidCfg( + size=(0.2, 0.2, 0.2), + rigid_props=sim_utils.RigidBodyPropertiesCfg(max_depenetration_velocity=1.0), + mass_props=sim_utils.MassPropertiesCfg(mass=1.0), + physics_material=sim_utils.RigidBodyMaterialCfg(), + visual_material=sim_utils.PreviewSurfaceCfg(diffuse_color=(0.5, 0.0, 0.0)), + ), + init_state=RigidObjectCfg.InitialStateCfg(pos=(2.0, 0.0, 5)), + ) + class TestFrameTransformer(unittest.TestCase): """Test for frame transformer sensor.""" @@ -71,7 +85,7 @@ def setUp(self): # Create a new stage stage_utils.create_new_stage() # Load kit helper - self.sim = sim_utils.SimulationContext(sim_utils.SimulationCfg(dt=0.005)) + self.sim = sim_utils.SimulationContext(sim_utils.SimulationCfg(dt=0.005, device="cpu")) # Set main camera self.sim.set_camera_view(eye=[5, 5, 5], target=[0.0, 0.0, 0.0]) @@ -90,8 +104,7 @@ def tearDown(self): def test_frame_transformer_feet_wrt_base(self): """Test feet transformations w.r.t. base source frame. - In this test, the source frame is the robot base. This frame is at index 0, when - the frame bodies are sorted in the order of the regex matching in the frame transformer. + In this test, the source frame is the robot base. """ # Spawn things into stage scene_cfg = MySceneCfg(num_envs=32, env_spacing=5.0, lazy_sensor_update=False) @@ -141,9 +154,15 @@ def test_frame_transformer_feet_wrt_base(self): feet_indices, feet_names = scene.articulations["robot"].find_bodies( ["LF_FOOT", "RF_FOOT", "LH_FOOT", "RH_FOOT"] ) - # Check names are parsed the same order - user_feet_names = [f"{name}_USER" for name in feet_names] - self.assertListEqual(scene.sensors["frame_transformer"].data.target_frame_names, user_feet_names) + + target_frame_names = scene.sensors["frame_transformer"].data.target_frame_names + + # Reorder the feet indices to match the order of the target frames with _USER suffix removed + target_frame_names = [name.split("_USER")[0] for name in target_frame_names] + + # Find the indices of the feet in the order of the target frames + reordering_indices = [feet_names.index(name) for name in target_frame_names] + feet_indices = [feet_indices[i] for i in reordering_indices] # default joint targets default_actions = scene.articulations["robot"].data.default_joint_pos.clone() @@ -185,6 +204,7 @@ def test_frame_transformer_feet_wrt_base(self): source_quat_w_tf = scene.sensors["frame_transformer"].data.source_quat_w feet_pos_w_tf = scene.sensors["frame_transformer"].data.target_pos_w feet_quat_w_tf = scene.sensors["frame_transformer"].data.target_quat_w + # check if they are same torch.testing.assert_close(root_pose_w[:, :3], source_pos_w_tf, rtol=1e-3, atol=1e-3) torch.testing.assert_close(root_pose_w[:, 3:], source_quat_w_tf, rtol=1e-3, atol=1e-3) @@ -302,6 +322,87 @@ def test_frame_transformer_feet_wrt_thigh(self): torch.testing.assert_close(feet_pos_source_tf[:, index], foot_pos_b, rtol=1e-3, atol=1e-3) torch.testing.assert_close(feet_quat_source_tf[:, index], foot_quat_b, rtol=1e-3, atol=1e-3) + def test_frame_transformer_body_wrt_cube(self): + """Test body transformation w.r.t. base source frame. + + In this test, the source frame is the robot base. + + The target_frame is a cube in the scene. + """ + # Spawn things into stage + scene_cfg = MySceneCfg(num_envs=2, env_spacing=5.0, lazy_sensor_update=False) + scene_cfg.frame_transformer = FrameTransformerCfg( + prim_path="{ENV_REGEX_NS}/Robot/base", + target_frames=[ + FrameTransformerCfg.FrameCfg( + name="CUBE_USER", + prim_path="{ENV_REGEX_NS}/cube", + ), + ], + ) + scene = InteractiveScene(scene_cfg) + + # Play the simulator + self.sim.reset() + + # default joint targets + default_actions = scene.articulations["robot"].data.default_joint_pos.clone() + # Define simulation stepping + sim_dt = self.sim.get_physics_dt() + # Simulate physics + for count in range(100): + # # reset + if count % 25 == 0: + # reset root state + root_state = scene.articulations["robot"].data.default_root_state.clone() + root_state[:, :3] += scene.env_origins + joint_pos = scene.articulations["robot"].data.default_joint_pos + joint_vel = scene.articulations["robot"].data.default_joint_vel + # -- set root state + # -- robot + scene.articulations["robot"].write_root_state_to_sim(root_state) + scene.articulations["robot"].write_joint_state_to_sim(joint_pos, joint_vel) + # reset buffers + scene.reset() + + # set joint targets + robot_actions = default_actions + 0.5 * torch.randn_like(default_actions) + scene.articulations["robot"].set_joint_position_target(robot_actions) + # write data to sim + scene.write_data_to_sim() + # perform step + self.sim.step() + # read data from sim + scene.update(sim_dt) + + # check absolute frame transforms in world frame + # -- ground-truth + root_pose_w = scene.articulations["robot"].data.root_state_w[:, :7] + cube_pos_w_gt = scene.rigid_objects["cube"].data.root_state_w[:, :3] + cube_quat_w_gt = scene.rigid_objects["cube"].data.root_state_w[:, 3:7] + # -- frame transformer + source_pos_w_tf = scene.sensors["frame_transformer"].data.source_pos_w + source_quat_w_tf = scene.sensors["frame_transformer"].data.source_quat_w + cube_pos_w_tf = scene.sensors["frame_transformer"].data.target_pos_w.squeeze() + cube_quat_w_tf = scene.sensors["frame_transformer"].data.target_quat_w.squeeze() + + # check if they are same + torch.testing.assert_close(root_pose_w[:, :3], source_pos_w_tf, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(root_pose_w[:, 3:], source_quat_w_tf, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(cube_pos_w_gt, cube_pos_w_tf, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(cube_quat_w_gt, cube_quat_w_tf, rtol=1e-3, atol=1e-3) + + # check if relative transforms are same + cube_pos_source_tf = scene.sensors["frame_transformer"].data.target_pos_source + cube_quat_source_tf = scene.sensors["frame_transformer"].data.target_quat_source + # ground-truth + cube_pos_b, cube_quat_b = math_utils.subtract_frame_transforms( + root_pose_w[:, :3], root_pose_w[:, 3:], cube_pos_w_tf, cube_quat_w_tf + ) + # check if they are same + torch.testing.assert_close(cube_pos_source_tf[:, 0], cube_pos_b, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(cube_quat_source_tf[:, 0], cube_quat_b, rtol=1e-3, atol=1e-3) + if __name__ == "__main__": run_tests()