Skip to content

Commit 419f127

Browse files
committed
Update FrameTransformer to handle external frames
1 parent ad4ec6e commit 419f127

File tree

2 files changed

+146
-17
lines changed

2 files changed

+146
-17
lines changed

source/extensions/omni.isaac.lab/omni/isaac/lab/sensors/frame_transformer/frame_transformer.py

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,14 @@
33
#
44
# SPDX-License-Identifier: BSD-3-Clause
55

6+
# i Copyright (c) 2022-2024, The Isaac Lab Project Developers.
7+
# All rights reserved.
8+
#
9+
# SPDX-License-Identifier: BSD-3-Clause
10+
611
from __future__ import annotations
712

13+
import re
814
import torch
915
from collections.abc import Sequence
1016
from typing import TYPE_CHECKING
@@ -136,9 +142,9 @@ def _initialize_impl(self):
136142
self._source_frame_offset_pos = source_frame_offset_pos.unsqueeze(0).repeat(self._num_envs, 1)
137143
self._source_frame_offset_quat = source_frame_offset_quat.unsqueeze(0).repeat(self._num_envs, 1)
138144

139-
# Keep track of mapping from the rigid body name to the desired frame, as there may be multiple frames
145+
# Keep track of mapping from the rigid body name to the desired frames and prim path, as there may be multiple frames
140146
# based upon the same body name and we don't want to create unnecessary views
141-
body_names_to_frames: dict[str, set[str]] = {}
147+
body_names_to_frames: dict[str, dict[str, set[str]]] = {}
142148
# The offsets associated with each target frame
143149
target_offsets: dict[str, dict[str, torch.Tensor]] = {}
144150
# The frames whose offsets are not identity
@@ -180,9 +186,10 @@ def _initialize_impl(self):
180186

181187
# Keep track of which frames are associated with which bodies
182188
if body_name in body_names_to_frames:
183-
body_names_to_frames[body_name].add(frame_name)
189+
body_names_to_frames[body_name]["frames"].add(frame_name)
184190
else:
185-
body_names_to_frames[body_name] = {frame_name}
191+
# Store the first matching prim path
192+
body_names_to_frames[body_name] = {"frames": {frame_name}, "prim_path": matching_prim_path}
186193

187194
if offset is not None:
188195
offset_pos = torch.tensor(offset.pos, device=self.device)
@@ -206,29 +213,42 @@ def _initialize_impl(self):
206213
)
207214

208215
# The names of bodies that RigidPrimView will be tracking to later extract transforms from
209-
tracked_body_names = list(body_names_to_frames.keys())
210-
# Construct regex expression for the body names
211-
body_names_regex = r"(" + "|".join(tracked_body_names) + r")"
212-
body_names_regex = f"{self.cfg.prim_path.rsplit('/', 1)[0]}/{body_names_regex}"
216+
tracked_prim_paths = [body_names_to_frames[body_name]["prim_path"] for body_name in body_names_to_frames.keys()]
217+
tracked_body_names = [body_name for body_name in body_names_to_frames.keys()]
218+
219+
body_names_regex = [tracked_prim_path.replace("env_0", "env_*") for tracked_prim_path in tracked_prim_paths]
220+
213221
# Create simulation view
214222
self._physics_sim_view = physx.create_simulation_view(self._backend)
215223
self._physics_sim_view.set_subspace_roots("/")
216224
# Create a prim view for all frames and initialize it
217225
# order of transforms coming out of view will be source frame followed by target frame(s)
218-
self._frame_physx_view = self._physics_sim_view.create_rigid_body_view(body_names_regex.replace(".*", "*"))
226+
# self._frame_physx_view = self._physics_sim_view.create_rigid_body_view('/World/envs/env_*/{Robot/base,Robot/LH_SHANK,Robot/LF_SHANK,cube}')
227+
self._frame_physx_view = self._physics_sim_view.create_rigid_body_view(body_names_regex)
219228

220229
# Determine the order in which regex evaluated body names so we can later index into frame transforms
221230
# by frame name correctly
222231
all_prim_paths = self._frame_physx_view.prim_paths
223232

233+
def extract_env_num(item):
234+
match = re.search(r"env_(\d+)(.*)", item)
235+
return (int(match.group(1)), match.group(2))
236+
237+
# Find the indices that would reorganize output to be per environment. We want `env_1/blah` to come before `env_11/blah`
238+
# so we need to use a custom key function
239+
sorted_indexed_prim_paths = sorted(list(enumerate(all_prim_paths)), key=lambda x: extract_env_num(x[1]))
240+
self._per_env_indices = [index for index, _ in sorted_indexed_prim_paths]
241+
sorted_prim_paths = [all_prim_paths[i] for i in self._per_env_indices]
242+
224243
# Only need first env as the names and their ordering are the same across environments
225-
first_env_prim_paths = all_prim_paths[0 : len(tracked_body_names)]
244+
first_env_prim_paths = [prim_path for prim_path in sorted_prim_paths if "env_0" in prim_path]
226245
first_env_body_names = [first_env_prim_path.split("/")[-1] for first_env_prim_path in first_env_prim_paths]
227246

228247
# Re-parse the list as it may have moved when resolving regex above
229248
# -- source frame
230249
self._source_frame_body_name = self.cfg.prim_path.split("/")[-1]
231250
source_frame_index = first_env_body_names.index(self._source_frame_body_name)
251+
232252
# -- target frames
233253
self._target_frame_body_names = first_env_body_names[:]
234254
self._target_frame_body_names.remove(self._source_frame_body_name)
@@ -248,11 +268,15 @@ def _initialize_impl(self):
248268
# when updating sensor in _update_buffers_impl
249269
duplicate_frame_indices = []
250270

271+
# The position and rotation components of target frame offsets
272+
self._target_frame_offset_pos = torch.zeros(0, 3, device=self.device)
273+
self._target_frame_offset_quat = torch.zeros(0, 4, device=self.device)
274+
251275
# Go through each body name and determine the number of duplicates we need for that frame
252276
# and extract the offsets. This is all done to handles the case where multiple frames
253277
# reference the same body, but have different names and/or offsets
254278
for i, body_name in enumerate(self._target_frame_body_names):
255-
for frame in body_names_to_frames[body_name]:
279+
for frame in body_names_to_frames[body_name]["frames"]:
256280
target_frame_offset_pos.append(target_offsets[frame]["pos"])
257281
target_frame_offset_quat.append(target_offsets[frame]["quat"])
258282
self._target_frame_names.append(frame)
@@ -288,6 +312,10 @@ def _update_buffers_impl(self, env_ids: Sequence[int]):
288312
# Extract transforms from view - shape is:
289313
# (the total number of source and target body frames being tracked * self._num_envs, 7)
290314
transforms = self._frame_physx_view.get_transforms()
315+
316+
# Reorder the transforms to be per environment as is expected of SensorData
317+
transforms = transforms[self._per_env_indices]
318+
291319
# Convert quaternions as PhysX uses xyzw form
292320
transforms[:, 3:] = convert_quat(transforms[:, 3:], to="wxyz")
293321

source/extensions/omni.isaac.lab/test/sensors/test_frame_transformer.py

Lines changed: 107 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
import omni.isaac.lab.sim as sim_utils
2828
import omni.isaac.lab.utils.math as math_utils
29+
from omni.isaac.lab.assets import RigidObjectCfg
2930
from omni.isaac.lab.scene import InteractiveScene, InteractiveSceneCfg
3031
from omni.isaac.lab.sensors import FrameTransformerCfg, OffsetCfg
3132
from omni.isaac.lab.terrains import TerrainImporterCfg
@@ -62,6 +63,19 @@ class MySceneCfg(InteractiveSceneCfg):
6263
# sensors - frame transformer (filled inside unit test)
6364
frame_transformer: FrameTransformerCfg = None
6465

66+
# block
67+
cube: RigidObjectCfg = RigidObjectCfg(
68+
prim_path="{ENV_REGEX_NS}/cube",
69+
spawn=sim_utils.CuboidCfg(
70+
size=(0.2, 0.2, 0.2),
71+
rigid_props=sim_utils.RigidBodyPropertiesCfg(max_depenetration_velocity=1.0),
72+
mass_props=sim_utils.MassPropertiesCfg(mass=1.0),
73+
physics_material=sim_utils.RigidBodyMaterialCfg(),
74+
visual_material=sim_utils.PreviewSurfaceCfg(diffuse_color=(0.5, 0.0, 0.0)),
75+
),
76+
init_state=RigidObjectCfg.InitialStateCfg(pos=(2.0, 0.0, 5)),
77+
)
78+
6579

6680
class TestFrameTransformer(unittest.TestCase):
6781
"""Test for frame transformer sensor."""
@@ -71,7 +85,7 @@ def setUp(self):
7185
# Create a new stage
7286
stage_utils.create_new_stage()
7387
# Load kit helper
74-
self.sim = sim_utils.SimulationContext(sim_utils.SimulationCfg(dt=0.005))
88+
self.sim = sim_utils.SimulationContext(sim_utils.SimulationCfg(dt=0.005, device="cpu"))
7589
# Set main camera
7690
self.sim.set_camera_view(eye=[5, 5, 5], target=[0.0, 0.0, 0.0])
7791

@@ -90,8 +104,7 @@ def tearDown(self):
90104
def test_frame_transformer_feet_wrt_base(self):
91105
"""Test feet transformations w.r.t. base source frame.
92106
93-
In this test, the source frame is the robot base. This frame is at index 0, when
94-
the frame bodies are sorted in the order of the regex matching in the frame transformer.
107+
In this test, the source frame is the robot base.
95108
"""
96109
# Spawn things into stage
97110
scene_cfg = MySceneCfg(num_envs=32, env_spacing=5.0, lazy_sensor_update=False)
@@ -141,9 +154,15 @@ def test_frame_transformer_feet_wrt_base(self):
141154
feet_indices, feet_names = scene.articulations["robot"].find_bodies(
142155
["LF_FOOT", "RF_FOOT", "LH_FOOT", "RH_FOOT"]
143156
)
144-
# Check names are parsed the same order
145-
user_feet_names = [f"{name}_USER" for name in feet_names]
146-
self.assertListEqual(scene.sensors["frame_transformer"].data.target_frame_names, user_feet_names)
157+
158+
target_frame_names = scene.sensors["frame_transformer"].data.target_frame_names
159+
160+
# Reorder the feet indices to match the order of the target frames with _USER suffix removed
161+
target_frame_names = [name.split("_USER")[0] for name in target_frame_names]
162+
163+
# Find the indices of the feet in the order of the target frames
164+
reordering_indices = [feet_names.index(name) for name in target_frame_names]
165+
feet_indices = [feet_indices[i] for i in reordering_indices]
147166

148167
# default joint targets
149168
default_actions = scene.articulations["robot"].data.default_joint_pos.clone()
@@ -185,6 +204,7 @@ def test_frame_transformer_feet_wrt_base(self):
185204
source_quat_w_tf = scene.sensors["frame_transformer"].data.source_quat_w
186205
feet_pos_w_tf = scene.sensors["frame_transformer"].data.target_pos_w
187206
feet_quat_w_tf = scene.sensors["frame_transformer"].data.target_quat_w
207+
188208
# check if they are same
189209
torch.testing.assert_close(root_pose_w[:, :3], source_pos_w_tf, rtol=1e-3, atol=1e-3)
190210
torch.testing.assert_close(root_pose_w[:, 3:], source_quat_w_tf, rtol=1e-3, atol=1e-3)
@@ -302,6 +322,87 @@ def test_frame_transformer_feet_wrt_thigh(self):
302322
torch.testing.assert_close(feet_pos_source_tf[:, index], foot_pos_b, rtol=1e-3, atol=1e-3)
303323
torch.testing.assert_close(feet_quat_source_tf[:, index], foot_quat_b, rtol=1e-3, atol=1e-3)
304324

325+
def test_frame_transformer_body_wrt_cube(self):
326+
"""Test body transformation w.r.t. base source frame.
327+
328+
In this test, the source frame is the robot base.
329+
330+
The target_frame is a cube in the scene.
331+
"""
332+
# Spawn things into stage
333+
scene_cfg = MySceneCfg(num_envs=2, env_spacing=5.0, lazy_sensor_update=False)
334+
scene_cfg.frame_transformer = FrameTransformerCfg(
335+
prim_path="{ENV_REGEX_NS}/Robot/base",
336+
target_frames=[
337+
FrameTransformerCfg.FrameCfg(
338+
name="CUBE_USER",
339+
prim_path="{ENV_REGEX_NS}/cube",
340+
),
341+
],
342+
)
343+
scene = InteractiveScene(scene_cfg)
344+
345+
# Play the simulator
346+
self.sim.reset()
347+
348+
# default joint targets
349+
default_actions = scene.articulations["robot"].data.default_joint_pos.clone()
350+
# Define simulation stepping
351+
sim_dt = self.sim.get_physics_dt()
352+
# Simulate physics
353+
for count in range(100):
354+
# # reset
355+
if count % 25 == 0:
356+
# reset root state
357+
root_state = scene.articulations["robot"].data.default_root_state.clone()
358+
root_state[:, :3] += scene.env_origins
359+
joint_pos = scene.articulations["robot"].data.default_joint_pos
360+
joint_vel = scene.articulations["robot"].data.default_joint_vel
361+
# -- set root state
362+
# -- robot
363+
scene.articulations["robot"].write_root_state_to_sim(root_state)
364+
scene.articulations["robot"].write_joint_state_to_sim(joint_pos, joint_vel)
365+
# reset buffers
366+
scene.reset()
367+
368+
# set joint targets
369+
robot_actions = default_actions + 0.5 * torch.randn_like(default_actions)
370+
scene.articulations["robot"].set_joint_position_target(robot_actions)
371+
# write data to sim
372+
scene.write_data_to_sim()
373+
# perform step
374+
self.sim.step()
375+
# read data from sim
376+
scene.update(sim_dt)
377+
378+
# check absolute frame transforms in world frame
379+
# -- ground-truth
380+
root_pose_w = scene.articulations["robot"].data.root_state_w[:, :7]
381+
cube_pos_w_gt = scene.rigid_objects["cube"].data.root_state_w[:, :3]
382+
cube_quat_w_gt = scene.rigid_objects["cube"].data.root_state_w[:, 3:7]
383+
# -- frame transformer
384+
source_pos_w_tf = scene.sensors["frame_transformer"].data.source_pos_w
385+
source_quat_w_tf = scene.sensors["frame_transformer"].data.source_quat_w
386+
cube_pos_w_tf = scene.sensors["frame_transformer"].data.target_pos_w.squeeze()
387+
cube_quat_w_tf = scene.sensors["frame_transformer"].data.target_quat_w.squeeze()
388+
389+
# check if they are same
390+
torch.testing.assert_close(root_pose_w[:, :3], source_pos_w_tf, rtol=1e-3, atol=1e-3)
391+
torch.testing.assert_close(root_pose_w[:, 3:], source_quat_w_tf, rtol=1e-3, atol=1e-3)
392+
torch.testing.assert_close(cube_pos_w_gt, cube_pos_w_tf, rtol=1e-3, atol=1e-3)
393+
torch.testing.assert_close(cube_quat_w_gt, cube_quat_w_tf, rtol=1e-3, atol=1e-3)
394+
395+
# check if relative transforms are same
396+
cube_pos_source_tf = scene.sensors["frame_transformer"].data.target_pos_source
397+
cube_quat_source_tf = scene.sensors["frame_transformer"].data.target_quat_source
398+
# ground-truth
399+
cube_pos_b, cube_quat_b = math_utils.subtract_frame_transforms(
400+
root_pose_w[:, :3], root_pose_w[:, 3:], cube_pos_w_tf, cube_quat_w_tf
401+
)
402+
# check if they are same
403+
torch.testing.assert_close(cube_pos_source_tf[:, 0], cube_pos_b, rtol=1e-3, atol=1e-3)
404+
torch.testing.assert_close(cube_quat_source_tf[:, 0], cube_quat_b, rtol=1e-3, atol=1e-3)
405+
305406

306407
if __name__ == "__main__":
307408
run_tests()

0 commit comments

Comments
 (0)